Skip to content

Commit

Permalink
Changed quicksort and quickselect to use template based sorting networks
Browse files Browse the repository at this point in the history
  • Loading branch information
sterrettm2 committed Aug 16, 2023
1 parent 42c672f commit 70424a6
Show file tree
Hide file tree
Showing 7 changed files with 310 additions and 1,001 deletions.
102 changes: 3 additions & 99 deletions src/avx512-16bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define AVX512_16BIT_COMMON

#include "avx512-common-qsort.h"
#include "xss-network-qsort.hpp"

/*
* Constants used in sorting 32 elements in a ZMM register. Based on Bitonic
Expand Down Expand Up @@ -118,103 +119,6 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
return zmm;
}

/*
 * Merge step for two registers that the caller has already sorted
 * individually. zmm2 is permuted with network 4 (the reversal network),
 * compare-exchanged lane by lane against zmm1, and each half is then passed
 * through the single-register bitonic merge. On return zmm1 holds the merged
 * minima and zmm2 the merged maxima.
 */
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1, zmm_t &zmm2)
{
    // Reverse the second operand in place before the compare-exchange.
    zmm2 = vtype::permutexvar(vtype::get_network(4), zmm2);

    // Both halves must be computed before either input is overwritten.
    const zmm_t lower_half = vtype::min(zmm1, zmm2);
    const zmm_t upper_half = vtype::max(zmm1, zmm2);

    // Finish each half with the single-register merge network.
    zmm1 = bitonic_merge_zmm_16bit<vtype>(lower_half);
    zmm2 = bitonic_merge_zmm_16bit<vtype>(upper_half);
}

/*
 * Merge step for four registers, where the caller guarantees that the pairs
 * [zmm[0], zmm[1]] and [zmm[2], zmm[3]] are each already sorted. The second
 * pair is reversed, compare-exchanged against the first pair, and every
 * resulting register is finished with the single-register bitonic merge.
 * The result is written back into zmm[0..3].
 */
template <typename vtype, typename zmm_t = typename vtype::zmm_t>
X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
{
    // Index vector used for every reversal below.
    const auto rev_idx = vtype::get_network(4);

    const zmm_t rev2 = vtype::permutexvar(rev_idx, zmm[2]);
    const zmm_t rev3 = vtype::permutexvar(rev_idx, zmm[3]);

    // Compare-exchange the outer pair (0 vs. reversed 3) and the inner pair
    // (1 vs. reversed 2); the maxima are reversed again afterwards.
    const zmm_t outer_lo = vtype::min(zmm[0], rev3);
    const zmm_t inner_lo = vtype::min(zmm[1], rev2);
    const zmm_t inner_hi
            = vtype::permutexvar(rev_idx, vtype::max(zmm[1], rev2));
    const zmm_t outer_hi
            = vtype::permutexvar(rev_idx, vtype::max(zmm[0], rev3));

    // Second compare-exchange level within the low and the high halves.
    const zmm_t staged[4] = {vtype::min(outer_lo, inner_lo),
                             vtype::max(outer_lo, inner_lo),
                             vtype::min(inner_hi, outer_hi),
                             vtype::max(inner_hi, outer_hi)};

    // Finish each register with the single-register merge network.
    for (int i = 0; i < 4; ++i) {
        zmm[i] = bitonic_merge_zmm_16bit<vtype>(staged[i]);
    }
}

/*
 * Sorts the first N elements of arr (callers in this file pass N <= 32) using
 * a single register. The low N bits of the mask select the lanes that are
 * loaded and stored; vtype::zmm_max() is handed to mask_loadu as the source
 * for the remaining lanes.
 */
template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE void sort_32_16bit(type_t *arr, int32_t N)
{
    using zmm_t = typename vtype::zmm_t;
    using opmask_t = typename vtype::opmask_t;

    const opmask_t active_lanes = ((0x1ull << N) - 0x1ull) & 0xFFFFFFFF;
    const zmm_t loaded = vtype::mask_loadu(vtype::zmm_max(), active_lanes, arr);
    const zmm_t sorted = sort_zmm_16bit<vtype>(loaded);
    vtype::mask_storeu(arr, active_lanes, sorted);
}

/*
 * Sorts the first N elements of arr using at most two registers. Inputs with
 * N <= 32 are delegated to sort_32_16bit. Otherwise the first 32 elements are
 * loaded in full and the remaining N - 32 through a mask, both registers are
 * sorted individually, merged, and written back with the same mask.
 */
template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE void sort_64_16bit(type_t *arr, int32_t N)
{
    if (N > 32) {
        using zmm_t = typename vtype::zmm_t;
        using opmask_t = typename vtype::opmask_t;

        // Low (N - 32) bits set: lanes of the second register backed by data.
        const opmask_t tail_mask
                = ((0x1ull << (N - 32)) - 0x1ull) & 0xFFFFFFFF;

        zmm_t front = vtype::loadu(arr);
        zmm_t back = vtype::mask_loadu(vtype::zmm_max(), tail_mask, arr + 32);

        front = sort_zmm_16bit<vtype>(front);
        back = sort_zmm_16bit<vtype>(back);
        bitonic_merge_two_zmm_16bit<vtype>(front, back);

        vtype::storeu(arr, front);
        vtype::mask_storeu(arr + 32, tail_mask, back);
    }
    else {
        sort_32_16bit<vtype>(arr, N);
    }
}

/*
 * Sorts the first N elements of arr using at most four registers. Inputs with
 * N <= 64 are delegated to sort_64_16bit. Otherwise the first 64 elements are
 * loaded in full, the remaining N - 64 through two masks, each register is
 * sorted, the registers are merged pairwise and then all four together, and
 * the result is stored back with the same masks.
 */
template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE void sort_128_16bit(type_t *arr, int32_t N)
{
    if (N <= 64) {
        sort_64_16bit<vtype>(arr, N);
        return;
    }

    using zmm_t = typename vtype::zmm_t;
    using opmask_t = typename vtype::opmask_t;

    // Masks for the third and fourth registers; all lanes active by default.
    opmask_t third_mask = 0xFFFFFFFF;
    opmask_t fourth_mask = 0xFFFFFFFF;
    if (N != 128) {
        // N == 128 is excluded above because it would need a shift by 64.
        const uint64_t tail_bits = (0x1ull << (N - 64)) - 0x1ull;
        third_mask = tail_bits & 0xFFFFFFFF;
        fourth_mask = (tail_bits >> 32) & 0xFFFFFFFF;
    }

    zmm_t regs[4] = {
            vtype::loadu(arr),
            vtype::loadu(arr + 32),
            vtype::mask_loadu(vtype::zmm_max(), third_mask, arr + 64),
            vtype::mask_loadu(vtype::zmm_max(), fourth_mask, arr + 96)};

    // Sort every register on its own first.
    for (zmm_t &reg : regs) {
        reg = sort_zmm_16bit<vtype>(reg);
    }

    // Merge pairwise, then merge the two sorted pairs.
    bitonic_merge_two_zmm_16bit<vtype>(regs[0], regs[1]);
    bitonic_merge_two_zmm_16bit<vtype>(regs[2], regs[3]);
    bitonic_merge_four_zmm_16bit<vtype>(regs);

    vtype::storeu(arr, regs[0]);
    vtype::storeu(arr + 32, regs[1]);
    vtype::mask_storeu(arr + 64, third_mask, regs[2]);
    vtype::mask_storeu(arr + 96, fourth_mask, regs[3]);
}

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
const int64_t left,
Expand Down Expand Up @@ -274,7 +178,7 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
* Base case: use bitonic networks to sort arrays <= 128
*/
if (right + 1 - left <= 128) {
sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
xss::sort_n<vtype, 128>(arr + left, (int32_t)(right + 1 - left));
return;
}

Expand Down Expand Up @@ -307,7 +211,7 @@ static void qselect_16bit_(type_t *arr,
* Base case: use bitonic networks to sort arrays <= 128
*/
if (right + 1 - left <= 128) {
sort_128_16bit<vtype>(arr + left, (int32_t)(right + 1 - left));
xss::sort_n<vtype, 128>(arr + left, (int32_t)(right + 1 - left));
return;
}

Expand Down
40 changes: 40 additions & 0 deletions src/avx512-16bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define AVX512_QSORT_16BIT

#include "avx512-16bit-common.h"
#include "xss-network-qsort.hpp"

struct float16 {
uint16_t val;
Expand Down Expand Up @@ -152,6 +153,19 @@ struct zmm_vector<float16> {
{
return _mm512_storeu_si512(mem, x);
}
// Permutes the lanes of zmm with the index vector from get_network(4).
static zmm_t reverse(zmm_t zmm)
{
    return permutexvar(get_network(4), zmm);
}
// Forwards x to bitonic_merge_zmm_16bit instantiated for this vector type.
static zmm_t bitonic_merge(zmm_t x)
{
    using self_t = zmm_vector<float16>;
    return bitonic_merge_zmm_16bit<self_t>(x);
}
// Forwards x to sort_zmm_16bit instantiated for this vector type.
static zmm_t sort_vec(zmm_t x)
{
    using self_t = zmm_vector<float16>;
    return sort_zmm_16bit<self_t>(x);
}
};

template <>
Expand Down Expand Up @@ -251,6 +265,19 @@ struct zmm_vector<int16_t> {
{
return _mm512_storeu_si512(mem, x);
}
// Permutes the lanes of zmm with the index vector from get_network(4).
static zmm_t reverse(zmm_t zmm)
{
    return permutexvar(get_network(4), zmm);
}
// Forwards x to bitonic_merge_zmm_16bit instantiated for this vector type.
static zmm_t bitonic_merge(zmm_t x)
{
    using self_t = zmm_vector<type_t>;
    return bitonic_merge_zmm_16bit<self_t>(x);
}
// Forwards x to sort_zmm_16bit instantiated for this vector type.
static zmm_t sort_vec(zmm_t x)
{
    using self_t = zmm_vector<type_t>;
    return sort_zmm_16bit<self_t>(x);
}
};
template <>
struct zmm_vector<uint16_t> {
Expand Down Expand Up @@ -347,6 +374,19 @@ struct zmm_vector<uint16_t> {
{
return _mm512_storeu_si512(mem, x);
}
// Permutes the lanes of zmm with the index vector from get_network(4).
static zmm_t reverse(zmm_t zmm)
{
    return permutexvar(get_network(4), zmm);
}
// Forwards x to bitonic_merge_zmm_16bit instantiated for this vector type.
static zmm_t bitonic_merge(zmm_t x)
{
    using self_t = zmm_vector<type_t>;
    return bitonic_merge_zmm_16bit<self_t>(x);
}
// Forwards x to sort_zmm_16bit instantiated for this vector type.
static zmm_t sort_vec(zmm_t x)
{
    using self_t = zmm_vector<type_t>;
    return sort_zmm_16bit<self_t>(x);
}
};

template <>
Expand Down
Loading

0 comments on commit 70424a6

Please sign in to comment.