Filtering negative numbers, fast: AVX


In the previous post I discussed how we can optimize the filtering of negative numbers by unrolling the loop, looked into branchless code, and in general was able to improve performance by up to 15% over the initial version we started with. We pushed scalar code about as far as it could go. Now it is time to open a whole new world and see what we can do when we implement this challenge using vector instructions.

The key problem with such tasks is that SIMD, AVX and their friends were designed by… an interesting process using a perspective that makes sense if you can see in a couple of additional dimensions. I assume that at least some of that is implementation constraints, but the key issue is that when you start using SIMD, you realize that you don’t have general-purpose instructions. Instead, you have a lot of dedicated instructions that are doing one thing, hopefully well, and it is your role to compose them into something that would make sense. Oftentimes, you need to turn the solution on its head in order to successfully solve it using SIMD. The benefit, of course, is that you can get quite an amazing boost in speed when you do this.

The algorithm we use is basically to scan the list of entries and copy to the start of the list only those items that are non-negative. How can we do that using SIMD? The whole point here is that we want to be able to operate on multiple data at once, but this particular task isn’t trivial. I’m going to show the code first, then discuss what it does in detail:

We start with the usual check (if you’ll recall, it ensures that the JIT knows it can elide some range checks), then we load the PermuteTable. For now, just assume that this is magic (and it is). The first interesting thing happens when we start iterating over the loop. Unlike before, we now process the input in chunks of 4 int64 elements at a time. Inside the loop, we start by loading a vector of int64 and then we do the first odd thing: we call ExtractMostSignificantBits(). The most significant bit of an int64 is its sign bit, which is set only for negative numbers, so a single instruction gives us an integer with a bit set for each negative element in the vector. That is particularly juicy for what we need, since there is no need for comparisons, etc.
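To make the mask concrete, here is a small scalar model in Python. It is an illustration only, not the author’s C# code, and the helper name sign_mask4 is hypothetical. It mimics what ExtractMostSignificantBits() returns for a vector of four int64 values: bit i of the result is the sign bit of element i.

```python
def sign_mask4(v):
    """Scalar model of ExtractMostSignificantBits() on four int64 values:
    bit i of the result is the most significant (sign) bit of element i."""
    mask = 0
    for i, x in enumerate(v):
        sign_bit = (x & 0xFFFF_FFFF_FFFF_FFFF) >> 63  # view x as a 64-bit word, keep bit 63
        mask |= sign_bit << i
    return mask

print(sign_mask4([5, 6, 7, 8]))    # 0: nothing to filter
print(sign_mask4([-1, 2, -3, 4]))  # 5 (0b0101): elements 0 and 2 are negative
print(sign_mask4([0, -1, 0, -1]))  # 10 (0b1010): elements 1 and 3 are negative
```

Zero has a clear sign bit, so it counts as non-negative and is kept, which is exactly the behavior the filter wants.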

If the mask we got is all zeroes, it means that all the numbers we loaded into the vector are non-negative, so we can write them as-is to the output and move on to the next chunk. Things get interesting when that isn’t the case.

We load a permute value using some shenanigans (we’ll touch on that shortly) and call the PermuteVar8x32() method. The idea here is that we pack all the non-negative numbers to the start of the vector, then we write the whole vector to the output. The key is that although we write all four elements, we increment the output index only by the number of valid (non-negative) values, so the leftover lanes are overwritten by the next write or simply ignored at the end. The rest of this method just handles the remainder that does not fit into a vector.
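Putting the pieces together, the loop described above can be modelled in scalar Python. This is a sketch under assumptions, not the author’s implementation: the function and table names are hypothetical, Python lists stand in for vectors, and struct is used to reinterpret four int64 values as eight 32-bit lanes the way PermuteVar8x32() sees them.

```python
import struct

def _to_lanes(v):
    # Reinterpret 4 x int64 as 8 x uint32 (little-endian, as on x86).
    return struct.unpack("<8I", struct.pack("<4q", *v))

def _from_lanes(lanes):
    # Reinterpret 8 x uint32 back into 4 x int64.
    return list(struct.unpack("<4q", struct.pack("<8I", *lanes)))

def _build_permute_table():
    # Entry for each 4-bit sign mask: the 32-bit lane pairs of the kept
    # (non-negative) elements first, unused lanes padded with 0.
    table = []
    for mask in range(16):
        entry = [lane for j in range(4) if not (mask >> j) & 1
                 for lane in (2 * j, 2 * j + 1)]
        table.append(entry + [0] * (8 - len(entry)))
    return table

PERMUTE_TABLE = _build_permute_table()

def filter_non_negative(items):
    """Move the non-negative values to the front of `items` (in place) and
    return how many there are, processing 4 elements per iteration."""
    out = 0
    i = 0
    while i + 4 <= len(items):
        chunk = items[i:i + 4]                          # "load" a 4 x int64 vector
        mask = sum(1 << j for j, x in enumerate(chunk) if x < 0)
        if mask == 0:
            items[out:out + 4] = chunk                  # all non-negative: store as-is
            out += 4
        else:
            lanes = _to_lanes(chunk)
            packed = [lanes[c] for c in PERMUTE_TABLE[mask]]  # PermuteVar8x32
            items[out:out + 4] = _from_lanes(packed)    # store all 4 elements...
            out += 4 - bin(mask).count("1")             # ...but keep only the valid ones
        i += 4
    for x in items[i:]:                                 # scalar loop for the tail
        if x >= 0:
            items[out] = x
            out += 1
    return out

data = [1, -2, 3, -4, 5, 6, 7, 8, -9, 0, 10]
n = filter_non_negative(data)
print(data[:n])   # [1, 3, 5, 6, 7, 8, 0, 10]
```

Writing all four elements even when some are invalid is safe because the output index never gets ahead of the input index: the extra lanes only land on values that were already loaded, and the returned count tells the caller where the valid data ends.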

The hard part in this implementation was figuring out how to handle the scenario where we loaded some negative numbers. We need a way to filter them out, after all, but AVX2 has no instruction that does that directly. Luckily, we have the Avx2.PermuteVar8x32() method to help here. To confuse things, we don’t actually want to deal with 8x32 values; we want to deal with 4x64 values. There is an Avx2.Permute4x64() method, and it would work quite nicely, with a single caveat: it expects its control argument to be a constant. We don’t have such a constant; the permutation we need depends on whatever the mask bits turn out to be at runtime.
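To see why a runtime mask doesn’t fit Permute4x64(), consider the control byte it takes, an immediate operand of the underlying vpermq instruction: two bits per destination element, each selecting a source element. The Python sketch below (the helper name is hypothetical) computes the byte that would pack the non-negative elements for each mask.

```python
def permute4x64_control(mask):
    """Hypothetical helper: the control byte that would make Permute4x64()
    move the non-negative elements to the front for a given 4-bit sign mask.
    Destination element k takes source element (control >> (2 * k)) & 3."""
    kept = [j for j in range(4) if not (mask >> j) & 1]
    kept += [0] * (4 - len(kept))          # don't-care slots: select element 0
    control = 0
    for k, j in enumerate(kept):
        control |= j << (2 * k)
    return control

for mask in range(16):
    print(f"{mask:04b} -> 0x{permute4x64_control(mask):02X}")
```

Mask 0000 gives 0xE4, the identity permutation, while 0101 gives 0x0D. Since the required byte changes with the mask, a constant-only control would presumably force a branch per mask value, which a variable permute driven by a lookup table avoids.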

So how do we deal with this issue of filtering with SIMD? We need to move all the values we care about to the front of the vector. We have a method that can do that, PermuteVar8x32(), and we just need to figure out how to actually make use of it. PermuteVar8x32() accepts an input vector as well as a vector describing the permutation you want to make. In this case, the permutation depends on the 4 sign bits of the 4-element int64 vector, so there are a total of 16 (2^4) options available to us. We have to deal with 32-bit values rather than 64-bit ones, but that isn’t much of a problem: each int64 element maps to two adjacent 32-bit lanes.
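The 16 entries follow a simple rule, so they can be generated rather than typed by hand. Here is a sketch in Python, assuming (as the entry at index #5 discussed in this post suggests) that int64 element j occupies 32-bit lanes 2j and 2j+1 and that unused lanes are padded with 0. The function name is hypothetical and the author’s table may differ in its padding values.

```python
def build_permute_table():
    """Control vectors for PermuteVar8x32(), indexed by the 4-bit sign mask.
    A set bit j means element j is negative and is dropped; the kept elements'
    32-bit lane pairs (2j, 2j + 1) are moved to the front, the rest padded with 0."""
    table = []
    for mask in range(16):
        entry = []
        for j in range(4):
            if not (mask >> j) & 1:      # bit clear: element j is non-negative
                entry += [2 * j, 2 * j + 1]
        entry += [0] * (8 - len(entry))
        table.append(entry)
    return table

for mask, entry in enumerate(build_permute_table()):
    print(",".join(map(str, entry)) + f", // {mask:04b}")
```

The padding value does not matter for correctness, since those lanes are never counted as output.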

Here is the permutation table that we’ll be using:

What you can see here is that a 1 bit in the mask (shown in the comments) marks a negative element, which we won’t copy to the vector. Bit i of the mask is the sign bit of element i, so let’s take a look at the entry for 0101, which may be caused by values such as [-1,2,-3,4], where elements 0 and 2 are negative.

The matching entry, at index #5 in the table, is: 2,3,6,7,0,0,0,0

What does this mean? Each number selects a 32-bit word of the source vector, and int64 element j occupies words 2j and 2j+1. So 2,3 takes the 2nd int64 element of the source (index 1) and moves it to the first element of the destination vector, 6,7 takes the 4th element of the source (index 3) as the second element of the destination, and the rest is discarded (marked as 0,0,0,0 in the table).
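This walk-through can be checked with a scalar model of PermuteVar8x32() (which compiles to vpermd): destination lane i receives source lane control[i], using only the low 3 bits of each index. The Python below is an illustration with hypothetical helper names, and it reinterprets the int64 values as little-endian 32-bit words, as on x86.

```python
import struct

def permute_var8x32(lanes, control):
    # Destination lane i gets source lane control[i]; only the low 3 bits are used.
    return [lanes[c & 7] for c in control]

def to_lanes(v):
    # 4 x int64 -> 8 x uint32 (little-endian: low word first).
    return list(struct.unpack("<8I", struct.pack("<4q", *v)))

def from_lanes(lanes):
    # 8 x uint32 -> 4 x int64.
    return list(struct.unpack("<4q", struct.pack("<8I", *lanes)))

src = [-1, 2, -3, 4]                 # sign mask 0b0101 = 5
entry5 = [2, 3, 6, 7, 0, 0, 0, 0]
result = from_lanes(permute_var8x32(to_lanes(src), entry5))
print(result)   # [2, 4, -1, -1]
```

Only the first two elements (2 and 4) are meaningful. The last two are built from two copies of word 0 (the low half of -1, so they read back as -1), and the output index simply skips past them.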

This is a bit hard to follow because we have to compose each int64 value out of two individual 32-bit words, but it works quite well. Or, at least, it would work, just not as efficiently as it could. This is because we would need to load the PermuteTableInts into a variable and access it, and there are better ways to deal with it: we can ask the JIT to embed the value directly. The problem is that the pattern the JIT recognizes is limited to ReadOnlySpan<byte>, which means that the already non-trivial int32 table got turned into this:

This is the exact same data as before, but using ReadOnlySpan<byte> means that the JIT can package that inside the data section and treat it as a constant value.
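To make the byte layout concrete: with 8 int32 values of 4 bytes each, every entry takes 32 bytes, so the entry for a given mask starts at byte offset mask * (8 * 4). The Python sketch below (hypothetical helper names; it assumes a little-endian layout, as on x86) flattens the table into bytes and reads an entry back from that offset.

```python
import struct

def build_permute_table_bytes():
    # 16 entries x 8 int32 values, flattened to little-endian bytes (32 bytes per entry).
    out = bytearray()
    for mask in range(16):
        entry = [lane for j in range(4) if not (mask >> j) & 1
                 for lane in (2 * j, 2 * j + 1)]
        entry += [0] * (8 - len(entry))
        out += struct.pack("<8i", *entry)
    return bytes(out)

def load_entry(table_bytes, mask):
    # Each entry is 8 ints * 4 bytes = 32 bytes long.
    offset = mask * (8 * 4)
    return list(struct.unpack_from("<8i", table_bytes, offset))

TABLE = build_permute_table_bytes()
print(len(TABLE))              # 512
print(load_entry(TABLE, 5))    # [2, 3, 6, 7, 0, 0, 0, 0]
```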

The code was heavily optimized, to the point where I noticed a JIT bug, in which these two versions of the code produce different assembly output:

Here is what we get out:

This looks like an unintended consequence of Roslyn and the JIT each doing their (separate) jobs, but not reaching the end goal. Constant folding appears to be done mostly by Roslyn, and multiplication is evaluated from the left, so $A * 4 * 8 is parsed as ($A * 4) * 8. Neither multiplication has two constant operands, so nothing gets folded into $A * 32. When we add parentheses, as in $A * (4 * 8), we isolate the constant subexpression, which can then be folded to 32.

Speaking of assembly, here is the annotated assembly version of the code:

And after all of this work, where are we standing?

| Method                 |        N |            Mean |        Error |       StdDev | Ratio | RatioSD | Code Size |
|------------------------|---------:|----------------:|-------------:|-------------:|------:|--------:|----------:|
| FilterCmp              |       23 |        285.7 ns |      3.84 ns |      3.59 ns |  1.00 |    0.00 |     411 B |
| FilterCmp_NoRangeCheck |       23 |        272.6 ns |      3.98 ns |      3.53 ns |  0.95 |    0.01 |     397 B |
| FilterCmp_Unroll_8     |       23 |        261.4 ns |      1.27 ns |      1.18 ns |  0.91 |    0.01 |     672 B |
| FilterCmp_Avx          |       23 |        261.6 ns |      1.37 ns |      1.28 ns |  0.92 |    0.01 |     521 B |
| FilterCmp              |     1047 |        758.7 ns |      1.51 ns |      1.42 ns |  1.00 |    0.00 |     411 B |
| FilterCmp_NoRangeCheck |     1047 |        756.8 ns |      1.83 ns |      1.53 ns |  1.00 |    0.00 |     397 B |
| FilterCmp_Unroll_8     |     1047 |        640.4 ns |      1.94 ns |      1.82 ns |  0.84 |    0.00 |     672 B |
| FilterCmp_Avx          |     1047 |        426.0 ns |      1.62 ns |      1.52 ns |  0.56 |    0.00 |     521 B |
| FilterCmp              |  1048599 |    502,681.4 ns |  3,732.37 ns |  3,491.26 ns |  1.00 |    0.00 |     411 B |
| FilterCmp_NoRangeCheck |  1048599 |    499,472.7 ns |  6,082.44 ns |  5,689.52 ns |  0.99 |    0.01 |     397 B |
| FilterCmp_Unroll_8     |  1048599 |    425,800.3 ns |    352.45 ns |    312.44 ns |  0.85 |    0.01 |     672 B |
| FilterCmp_Avx          |  1048599 |    218,075.1 ns |    212.40 ns |    188.29 ns |  0.43 |    0.00 |     521 B |
| FilterCmp              | 33554455 | 29,820,978.8 ns | 73,461.68 ns | 61,343.83 ns |  1.00 |    0.00 |     411 B |
| FilterCmp_NoRangeCheck | 33554455 | 29,471,229.2 ns | 73,805.56 ns | 69,037.77 ns |  0.99 |    0.00 |     397 B |
| FilterCmp_Unroll_8     | 33554455 | 29,234,413.8 ns | 67,597.45 ns | 63,230.70 ns |  0.98 |    0.00 |     672 B |
| FilterCmp_Avx          | 33554455 | 28,498,115.4 ns | 71,661.94 ns | 67,032.62 ns |  0.96 |    0.00 |     521 B |

So it seems that the idea of using SIMD instructions has a lot of merit. Moving from the original code to the final version, we can complete the same task in less than half the time (a ratio of 0.43 at 1,048,599 items).

I’m not quite sure why we aren’t seeing the same sort of improvement on the 32M-item run, but I suspect that it is because 33,554,455 int64 values take roughly 256 MB, far exceeding the CPU cache. We have to fetch everything from memory, so that is as fast as it can go.

If you are interested in learning more, Lemire solves the same problem in SVE (SIMD for ARM) and Paul has a similar approach in Rust.

If you can think of further optimizations, I would love to hear your ideas.

More posts in "Filtering negative numbers, fast" series:

  1. (15 Sep 2023) Beating memcpy()
  2. (13 Sep 2023) AVX
  3. (12 Sep 2023) Unroll
  4. (11 Sep 2023) Scalar