Filtering negative numbers, fast: Scalar
While working deep on the guts of RavenDB, I found myself with a seemingly simple task. Given a list of longs, I need to filter out all negative numbers as quickly as possible.
The actual scenario is that we run a speculative algorithm: given a potentially large list of items, we check whether we can fulfill the request in an optimal fashion. If that isn't possible, we need to switch to a slower code path that does more work.
Conceptually, this looks something like this:
```csharp
void Run(Span<long> entries, long switchOverPoint, Aggregation state)
{
    long cost = 0;
    for (var i = 0; i < entries.Length; i++)
    {
        var (item, searchCost) = FetchCheaply(entries[i]);
        cost += searchCost;
        if (item is not null)
        {
            entries[i] = -entries[i]; // mark as processed
            if (state.Aggregate(item))
                return; // we have enough results to bail early
        }
        if (cost > switchOverPoint)
        {
            // speculative execution failed, we need to do this the hard way
            RunManually(entries, state);
            return;
        }
    }
}
```
That is the setup for this story. The problem is that we now need to filter the results we pass to the RunManually() method.
There is a problem here, however. We marked the entries that we already used in the list by negating them. The issue is that RunManually() does not allow negative values, and its internal implementation is not friendly to ignoring those values.
In other words, given a Span<long>, I need to write the code that would filter out all the negative numbers. Everything else about the list of numbers should remain the same (the order of elements, etc).
From a coding perspective, this is as simple as:
```csharp
public Span<long> FilterNegative(Span<long> entries)
{
    return entries.ToArray().Where(l => l >= 0).ToArray();
}
```
Please note, just looking at this code makes me cringe a lot. This does the work, but it has an absolutely horrible performance profile. It allocates multiple arrays, uses a lambda, etc.
We don’t actually care about the entries here, so we are free to modify them without allocating a new value. As such, let’s see what kind of code we can write to do this work in an efficient manner. Here is what I came up with:
```csharp
public static int FilterCmp(Span<long> items)
{
    int output = 0;
    for (int i = 0; i < items.Length; i++)
    {
        if (items[i] < 0)
            continue;
        items[output++] = items[i];
    }
    return output;
}
```
The way this works is that we scan through the list, skipping the negative values, so we effectively "move down" all the non-negative values on top of the negative ones. This has a cost of O(N) and will modify the entire array; the final output is the number of valid items remaining.
In order to test the performance, I wrote the following harness:
```csharp
[RPlotExporter]
[DisassemblyDiagnoser]
public class SimdVsScalar
{
    private long[] items;

    [Params(1024 + 23, 1024 * 1024 + 23, 32 * 1024 * 1024 + 23)]
    public int N;

    [GlobalSetup]
    public void Setup()
    {
        items = new long[N];
        var r = new Random(2391);
        for (int i = 0; i < items.Length; i++)
        {
            items[i] = r.NextInt64();
        }
    }

    private void SprinkleNegatives()
    {
        var r = new Random(13245);
        var negatives = Math.Max((int)(items.Length * 0.005), 1);
        for (int i = 0; i < negatives; i++)
        {
            var idx = r.Next(items.Length);
            items[idx] = -items[idx];
        }
    }

    [Benchmark]
    public int FilterOr()
    {
        SprinkleNegatives();
        return Filter.FilterOr(items);
    }

    [Benchmark]
    public int FilterCmp()
    {
        SprinkleNegatives();
        return Filter.FilterCmp(items);
    }

    [Benchmark]
    public int Base()
    {
        SprinkleNegatives();
        return items.Length;
    }
}
```
We compare 1K, 1M, and 32M element arrays, each of which has about 0.5% negative values, randomly spread across the range. Because we modify the values directly, we need to sprinkle the negatives across the array on each call. I'm testing two options for this task: one that uses a direct comparison (shown above) and one that uses a bitwise or, like so:
```csharp
public static int FilterOr(Span<long> items)
{
    int output = 0;
    for (int i = 0; i < items.Length; i++)
    {
        if ((items[i] & ~long.MaxValue) != 0)
            continue;
        items[output++] = items[i];
    }
    return output;
}
```
I'm testing the cost of sprinkling negatives as well, since that has to be done before each benchmark call: we modify the array during the call, so we need to "reset" its state for the next one.
Given the two options, before we discuss the results, what would you expect to be the faster option? How would the size of the array matter here?
I really like this example, because it is simple, there isn’t any real complexity in what we are trying to do. And there is a very straightforward implementation that we can use as our baseline. That also means that I get to analyze what is going on at a very deep level. You might have noticed the disassembler attribute on the benchmark code, we are going to dive deep into that. For the same reason, we aren’t using exactly 1K, 1M, or 32M arrays, but slightly higher than that, so we’ll have to deal with remainders later on.
Let’s first look at what the JIT actually did here. Here is the annotated assembly for the FilterCmp function:
```asm
; Filter.FilterCmp(System.Span`1<Int64>)
       sub rsp,28          ; reserve stack space
       mov rax,[rcx]       ; rax now holds the pointer from the span
       mov edx,[rcx+8]     ; edx now holds the length of the span
       xor ecx,ecx         ; zero ecx (output)
       xor r8d,r8d         ; zero r8d (i)
       test edx,edx        ; if items.Length <= 0
       jle short M02_L02   ; jump to the epilog & return
M02_L00:
       mov r9d,r8d
       mov r9,[rax+r9*8]   ; r9 = items[i]
       test r9,r9          ; r9 < 0
       jl short M02_L01    ; continue to next iteration
       ; items[output] -- range check
       lea r10d,[rcx+1]    ; r10d = output + 1
       cmp ecx,edx         ; check against items.Length
       jae short M02_L03   ; jump to out of range exception
       mov ecx,ecx         ; clear high bits in rcx (overflow from addition?)
       mov [rax+rcx*8],r9  ; items[output] = items[i]
       mov ecx,r10d        ; output = r10d
M02_L01:
       inc r8d             ; i++
       cmp r8d,edx         ; i < items.Length
       jl short M02_L00    ; back to start of the loop
M02_L02:
       mov eax,ecx         ; return output
       add rsp,28
       ret
M02_L03:                   ; range check failure
       call CORINFO_HELP_RNGCHKFAIL
       int 3
; Total bytes of code 69
```
For the FilterOr, the code is pretty much the same, except that the key part is:
```diff
- test r9,r9
- jl short M02_L01
+ mov r10,8000000000000000
+ test r10,r9
+ jne short M02_L01
```
As you can see, the cmp option is slightly smaller, in terms of code size. In terms of performance, we have:
| Method | N | Mean |
|---|---|---|
| FilterOr | 1047 | 745.6 ns |
| FilterCmp | 1047 | 745.8 ns |
| FilterOr | 1048599 | 497,463.6 ns |
| FilterCmp | 1048599 | 498,784.8 ns |
| FilterOr | 33554455 | 31,427,660.7 ns |
| FilterCmp | 33554455 | 30,024,102.9 ns |
The costs are very close to one another, with Or being very slightly faster on the small arrays and Cmp slightly faster on the larger ones. The difference between them is basically noise; for practical purposes, they have the same performance.
The question is, can we do better here?
Looking at the assembly, there is an extra range check in the main loop that the JIT couldn't elide (the write to items[output++]). Can we do something about it, and would it make any difference in performance? Here is how I can remove the range check:
```csharp
public static int FilterCmp_NoRangeCheck(Span<long> items)
{
    int outputIdx = 0;
    int i = 0;
    ref var output = ref items[i];
    for (; i < items.Length; i++)
    {
        if (items[i] < 0)
            continue;
        ref var outputDest = ref Unsafe.Add(ref output, outputIdx++);
        outputDest = items[i];
    }
    return outputIdx;
}
```
Here I’m telling the JIT: “I know what I’m doing”, and it shows.
Let’s look at the assembly changes between those two methods, first the prolog:
```asm
; Filter.FilterCmp_NoRangeCheck(System.Span`1<Int64>)
       sub rsp,28          ; setup stack space
       mov rax,[rcx]       ; rax is now the pointer for the span
       mov edx,[rcx+8]     ; edx is the length of the span
       xor ecx,ecx         ; zero ecx (i = 0)
       xor r8d,r8d         ; zero r8d (outputIdx = 0)
       test edx,edx        ; if items.Length == 0
       je short M02_L03    ; jump to CORINFO_HELP_RNGCHKFAIL
       mov r9,rax          ; output = items[0]
       test edx,edx        ; if items.Length <= 0
       jle short M02_L02   ; jump past the for loop
```
Here you can see what is actually going on. Note the last four instructions: we have a range check for the items, and then another check for the loop. The first will get you an exception; the second will just skip the loop. In both cases, we test exactly the same thing. The JIT had a chance to optimize that, but didn't.
Here is a funny scenario where adding code may reduce the amount of code generated. Let’s do another version of this method:
```csharp
public static int FilterCmp_NoRangeCheck(Span<long> items)
{
    if (items.Length == 0)
        return 0;

    int outputIdx = 0;
    int i = 0;
    ref var output = ref items[i];
    for (; i < items.Length; i++)
    {
        if (items[i] < 0)
            continue;
        ref var outputDest = ref Unsafe.Add(ref output, outputIdx++);
        outputDest = items[i];
    }
    return outputIdx;
}
```
In this case, I added a check to handle the scenario of items being empty. What can the JIT do with this now? It turns out, quite a lot. We dropped 10 bytes from the method, which is a nice result of our diet. Here is the annotated version of the assembly:
```asm
; Filter.FilterCmp_NoRangeCheck(System.Span`1<Int64>)
       mov rdx,[rcx]       ; rdx = pointer of items
       mov ecx,[rcx+8]     ; ecx = length of items
       test ecx,ecx        ; items.Length == 0
       jne short M02_L00   ; if not true, jump ahead
       xor eax,eax         ; return 0
       ret
M02_L00:
       xor eax,eax         ; outputIdx = 0
       xor r8d,r8d         ; i = 0
       mov r9,rdx          ; output = items.pointer
       test ecx,ecx        ; if items.Length <= 0
       jle short M02_L03   ; jump to the function exit
M02_L01:
       mov r10d,r8d
       mov r10,[r9+r10*8]  ; r10 = items[i]
       test r10,r10        ; if items[i] < 0
       jl short M02_L02
       lea r11d,[rax+1]    ; r11d = outputIdx + 1
       cdqe                ; sign extend eax (the 32-bit addition) into rax
       mov [rdx+rax*8],r10 ; items[outputIdx] = items[i]
       mov eax,r11d        ; outputIdx = r11d
M02_L02:
       inc r8d             ; i++
       cmp r8d,ecx         ; i < items.Length
       jl short M02_L01    ; do the loop again
M02_L03:
       ret                 ; return
; Total bytes of code 59
```
A lot of the space savings in this case come from just not having to do a range check, but you'll note that we still test the length again right before the loop (the second `test ecx,ecx` / `jle` pair), even though we already checked it. I think that the JIT knows that the value is not zero at this point, but has to consider that it may be negative.
If we change the initial guard clause to items.Length <= 0, what do you think will happen? At this point, the JIT is smart enough to just elide everything; we are at 55 bytes of code and it is a super clean assembly (not a sentence I ever thought I would use). I'll spare you going through more assembly listings, but you can find the output here.
And after all of that, where are we at?
| Method | N | Mean | Error | StdDev | Ratio | RatioSD | Code Size |
|---|---|---|---|---|---|---|---|
| FilterCmp | 23 | 274.5 ns | 1.91 ns | 1.70 ns | 1.00 | 0.00 | 411 B |
| FilterCmp_NoRangeCheck | 23 | 269.7 ns | 1.33 ns | 1.24 ns | 0.98 | 0.01 | 397 B |
| FilterCmp | 1047 | 744.5 ns | 4.88 ns | 4.33 ns | 1.00 | 0.00 | 411 B |
| FilterCmp_NoRangeCheck | 1047 | 745.8 ns | 3.44 ns | 3.22 ns | 1.00 | 0.00 | 397 B |
| FilterCmp | 1048599 | 502,608.6 ns | 3,890.38 ns | 3,639.06 ns | 1.00 | 0.00 | 411 B |
| FilterCmp_NoRangeCheck | 1048599 | 490,669.1 ns | 1,793.52 ns | 1,589.91 ns | 0.98 | 0.01 | 397 B |
| FilterCmp | 33554455 | 30,495,286.6 ns | 602,907.86 ns | 717,718.92 ns | 1.00 | 0.00 | 411 B |
| FilterCmp_NoRangeCheck | 33554455 | 29,952,221.2 ns | 442,176.37 ns | 391,977.84 ns | 0.99 | 0.02 | 397 B |
There is a very slight benefit to the NoRangeCheck, but even when we talk about 32M items, we aren’t talking about a lot of time.
The question is: what can we do better here?
More posts in "Filtering negative numbers, fast" series:
- (15 Sep 2023) Beating memcpy()
- (13 Sep 2023) AVX
- (12 Sep 2023) Unroll
- (11 Sep 2023) Scalar
Comments
This looks pretty well optimised already, but it also looks like a workload that should fare well with SIMD - guessing that's the next step?
Cocowalla,
Yep, wait for tomorrow's post :-)
Jason,
Unrolling is the next post in the series. And yes, that is the next natural step. What I find amazing is the difference between the original code and what we end up when we optimize.