Numpy doesn't yet have a radix sort, so I wondered whether it was possible to write one using pre-existing numpy functions. So far I have the following, which does work, but is about 10 times slower than numpy's quicksort.
Test and benchmark:
    import numpy as np

    a = np.random.randint(0, 10**8, 10**6)  # integer arguments; a float size is rejected by recent numpy
    assert np.all(radix_sort(a) == np.sort(a))
    %timeit np.sort(a)
    %timeit radix_sort(a)
The mask_b loop can be at least partially vectorized, broadcasting out across masks from &, and using cumsum with the axis arg, but that ends up being a pessimization, presumably due to the increased memory footprint.
If anyone can see a way to improve on what I have I'd be interested to hear, even if it's still slower than np.sort. This is more a case of intellectual curiosity and interest in numpy tricks.
Note that you can implement a fast counting sort easily enough, though that's only relevant for small integer data.
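For reference, a minimal sketch of the kind of counting sort meant here, assuming a 1-D array of small non-negative integers. The name counting_sort is only illustrative, and this is one possible way to write it rather than a tuned implementation:

```python
import numpy as np

def counting_sort(a):
    """Sort small non-negative integers by counting occurrences (hypothetical helper)."""
    counts = np.bincount(a)                          # counts[v] = number of times v occurs
    values = np.arange(len(counts), dtype=a.dtype)   # every candidate value 0..max(a)
    return np.repeat(values, counts)                 # emit each value counts[v] times

a_small = np.random.randint(0, 256, 10**5)
assert np.array_equal(counting_sort(a_small), np.sort(a_small))
```

The memory cost grows with max(a) rather than len(a), which is why this only pays off when the value range is small.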
Edit 1: Taking np.arange(n) out of the loop helps a little, but that's not very exciting.
Edit 2: The cumsum was actually redundant (oops!) but this simpler version only helps marginally with performance.
    import numpy as np

    def radix_sort(a):
        # convert to a Python int so that bit_length() is available
        bit_len = int(np.max(a)).bit_length()
        n = len(a)
        cached_arange = np.arange(n)
        idx = np.empty(n, dtype=int)  # fully overwritten each iteration
        for mask_b in range(bit_len):
            is_one = (a & 2**mask_b).astype(bool)
            n_ones = np.sum(is_one)
            n_zeros = n - n_ones
            idx[~is_one] = cached_arange[:n_zeros]
            idx[is_one] = cached_arange[:n_ones] + n_zeros
            # next three lines just do: a[idx] = a, but correctly
            new_a = np.empty(n, dtype=a.dtype)
            new_a[idx] = a
            a = new_a
        return a
Edit 3: rather than loop over single bits, you can loop over two or more at a time, if you construct idx in multiple steps. Using 2 bits helps a little, I've not tried more:
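    # each bucket starts where the previous one ended
    idx[is_zero] = np.arange(n_zeros)
    idx[is_one] = np.arange(n_ones) + n_zeros
    idx[is_two] = np.arange(n_twos) + n_zeros + n_ones
    idx[is_three] = np.arange(n_threes) + n_zeros + n_ones + n_twos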
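To make the two-bit idea concrete, here is a self-contained sketch of how a full pass might look. The function name radix_sort_2bit and the inner loop over the four digit values are my own illustrative choices, and it assumes non-negative integer input:

```python
import numpy as np

def radix_sort_2bit(a):
    """LSD radix sort, two bits per pass (illustrative sketch, non-negative ints only)."""
    a = np.asarray(a)
    n = len(a)
    if n == 0:
        return a
    bit_len = int(a.max()).bit_length()
    idx = np.empty(n, dtype=np.intp)        # destination of each element, rebuilt every pass
    for shift in range(0, bit_len, 2):
        digit = (a >> shift) & 3            # current two-bit digit of every element
        offset = 0
        for d in range(4):
            sel = digit == d
            count = int(np.count_nonzero(sel))
            # elements with digit d keep their relative order (stable)
            idx[sel] = np.arange(offset, offset + count)
            offset += count
        new_a = np.empty_like(a)
        new_a[idx] = a                      # scatter into the new order
        a = new_a
    return a

a_test = np.random.randint(0, 10**8, 10**4)
assert np.array_equal(radix_sort_2bit(a_test), np.sort(a_test))
```

The running offset plays the role of the n_zeros, n_zeros + n_ones, ... terms, so the same loop would extend to more bits per pass.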
Edits 4 and 5: going to 4 bits seems best for the input I'm testing. Also, you can get rid of the idx step entirely. Now only about 5 times, rather than 10 times, slower than np.sort (source available as gist).
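As a rough illustration of what dropping the idx step can look like (this is not the code from the gist, just a sketch under the same assumptions: non-negative integers, hypothetical name radix_sort_concat), each pass can gather the buckets directly with boolean masks and concatenate them:

```python
import numpy as np

def radix_sort_concat(a, bits_per_pass=4):
    """LSD radix sort without an explicit index array (illustrative sketch)."""
    a = np.asarray(a)
    if len(a) == 0:
        return a
    bit_len = int(a.max()).bit_length()
    n_buckets = 1 << bits_per_pass
    for shift in range(0, bit_len, bits_per_pass):
        digit = (a >> shift) & (n_buckets - 1)
        # boolean indexing preserves order within each bucket, so the pass is stable
        a = np.concatenate([a[digit == d] for d in range(n_buckets)])
    return a

a_test = np.random.randint(0, 10**8, 10**4)
assert np.array_equal(radix_sort_concat(a_test), np.sort(a_test))
```

Each pass still scans the array once per bucket, so larger bits_per_pass trades fewer passes for more mask comparisons per pass.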
Edit 6: This is a tidied up version of the above, but it's also a tiny bit slower. 80% of the time is spent on repeat and extract. If only there was a way to broadcast the extract :(
    import numpy as np

    nax = np.newaxis

    def radix_sort(a, batch_m_bits=3):
        # convert to a Python int so that bit_length() is available
        bit_len = int(np.max(a)).bit_length()
        batch_m = 2**batch_m_bits
        mask = 2**batch_m_bits - 1
        val_set = np.arange(batch_m, dtype=a.dtype)[:, nax]
        for _ in range((bit_len-1)//batch_m_bits + 1):  # ceil-division
            a = np.extract((a & mask)[nax, :] == val_set,
                           np.repeat(a[nax, :], batch_m, axis=0))
            val_set <<= batch_m_bits
            mask <<= batch_m_bits
        return a
Edits 7 & 8: Actually, you can broadcast the extract using as_strided from numpy.lib.stride_tricks, but it doesn't seem to help much performance-wise.
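A sketch of what broadcasting via as_strided might look like, assuming non-negative integer input. The name radix_sort_strided is hypothetical, and the zero stride on the first axis is what stands in for np.repeat:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def radix_sort_strided(a, batch_m_bits=3):
    """Radix sort where the tiled array is a zero-stride view (illustrative sketch)."""
    a = np.ascontiguousarray(a)
    n = len(a)
    if n == 0:
        return a
    bit_len = int(a.max()).bit_length()
    batch_m = 1 << batch_m_bits
    val_set = np.arange(batch_m)[:, np.newaxis]
    for shift in range(0, bit_len, batch_m_bits):
        digit = (a >> shift) & (batch_m - 1)
        cond = digit[np.newaxis, :] == val_set      # shape (batch_m, n)
        # stride 0 on axis 0: all batch_m rows alias the same n elements
        tiled = as_strided(a, shape=(batch_m, n),
                           strides=(0, a.strides[0]), writeable=False)
        a = np.extract(cond, tiled)                 # row-major order keeps the pass stable
    return a

a_test = np.random.randint(0, 10**8, 10**4)
assert np.array_equal(radix_sort_strided(a_test), np.sort(a_test))
```

The view itself costs no extra memory, but whether the consumer can use it without copying is a separate question.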
Initially this made sense to me on the grounds that extract will be iterating over the whole array batch_m times, so the total number of cache lines requested by the CPU will be the same as before (it's just that by the end of the process it has requested each cache line batch_m times). However, the reality is that extract is not clever enough to iterate over arbitrary strided arrays, and has to expand out the array before beginning, i.e. the repeat ends up being done anyway.
In fact, having looked at the source for extract, I now see that the best we can do with this approach is:
    a = a[np.flatnonzero((a & mask)[nax, :] == val_set) % len(a)]
which is marginally slower than extract. However, if len(a) is a power of two we can replace the expensive mod operation with & (len(a) - 1), which does end up being a bit faster than the extract version (now about 4.9x the runtime of np.sort for a=randint(0, 1e8, 2**20)). I suppose we could make this work for non-power-of-two lengths by zero-padding, and then cropping the extra zeros at the end of the sort. However, this would be a pessimisation unless the length was already close to being a power of two.
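The zero-padding idea could be sketched roughly as follows. This is untimed and only meant to show the mechanics; the name radix_sort_pow2 is hypothetical, and it relies on the input being non-negative so that the padding zeros sort to the front:

```python
import numpy as np

def radix_sort_pow2(a, batch_m_bits=3):
    """Radix sort using & (size - 1) instead of %, padding to a power of two (sketch)."""
    a = np.asarray(a)
    n = len(a)
    if n == 0:
        return a
    size = 1 << (n - 1).bit_length()        # smallest power of two >= n
    pad = size - n
    a = np.concatenate((np.zeros(pad, dtype=a.dtype), a))
    bit_len = int(a.max()).bit_length()
    batch_m = 1 << batch_m_bits
    val_set = np.arange(batch_m, dtype=a.dtype)[:, np.newaxis]
    for shift in range(0, bit_len, batch_m_bits):
        digit = (a >> shift) & (batch_m - 1)
        flat = np.flatnonzero(digit[np.newaxis, :] == val_set)
        a = a[flat & (size - 1)]            # flat index -> column index, valid since size is 2**k
    return a[pad:]                          # the padding zeros are the smallest values; crop them

a_test = np.random.randint(0, 10**8, 1000)  # deliberately not a power of two
assert np.array_equal(radix_sort_pow2(a_test), np.sort(a_test))
```

For a length just above a power of two the padded array is nearly twice as long, which is the case where this would be expected to lose.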
CodeReview because there are a lot more numpy eyes on Stackoverflow. But for coding style and pure Python issues, CR is good. – hpaulj