Skip to content

Commit

Permalink
Merge pull request #168 from sterrettm2/kv-pivot
Browse files Browse the repository at this point in the history
Adds smart pivot selection to key-value sorting
  • Loading branch information
r-devulap authored Oct 7, 2024
2 parents f99c392 + 8d378c9 commit d62f656
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 12 deletions.
28 changes: 27 additions & 1 deletion src/avx2-32bit-half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ struct avx2_half_vector<int32_t> {
{
return _mm_set1_epi32(type_max());
} // TODO: this should broadcast bits as is?
// Logical NOT of an opmask: XORs x with a register built from four -1
// lanes (every bit set, assuming seti forwards its arguments unchanged to
// the 32-bit lanes), so each bit of x is flipped.
static opmask_t knot_opmask(opmask_t x)
{
    return _mm_xor_si128(x, seti(-1, -1, -1, -1));
}
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
auto mask = ((0x1ull << num_to_read) - 0x1ull);
Expand Down Expand Up @@ -186,6 +191,10 @@ struct avx2_half_vector<int32_t> {
{
return v;
}
// Reports whether the mask is entirely "false": collects the most
// significant bit of each 32-bit lane of k into an integer and returns true
// only when none of those bits is set.
static bool all_false(opmask_t k)
{
    const int lane_sign_bits = _mm_movemask_ps(_mm_castsi128_ps(k));
    return lane_sign_bits == 0;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -218,6 +227,11 @@ struct avx2_half_vector<uint32_t> {
{
return _mm_set1_epi32(type_max());
}
// Inverts an opmask. The second XOR operand is seti(-1, -1, -1, -1), which
// is presumably a register with all bits set, so the result is the bitwise
// complement of x.
static opmask_t knot_opmask(opmask_t x)
{
    const auto everyBitSet = seti(-1, -1, -1, -1);
    const auto flipped = _mm_xor_si128(x, everyBitSet);
    return flipped;
}
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
auto mask = ((0x1ull << num_to_read) - 0x1ull);
Expand Down Expand Up @@ -331,6 +345,10 @@ struct avx2_half_vector<uint32_t> {
{
return v;
}
// True when no 32-bit lane of k has its top bit set. The mask register is
// reinterpreted as packed floats only so that _mm_movemask_ps can extract
// the per-lane sign bits; no value conversion takes place.
static bool all_false(opmask_t k)
{
    const auto as_ps = _mm_castsi128_ps(k);
    return _mm_movemask_ps(as_ps) == 0;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -363,7 +381,11 @@ struct avx2_half_vector<float> {
{
return _mm_set1_ps(type_max());
}

// Complements an opmask by XOR-ing it against an all-ones integer register
// (four lanes of -1 produced by seti).
static opmask_t knot_opmask(opmask_t x)
{
    const auto ones = seti(-1, -1, -1, -1);
    return _mm_xor_si128(x, ones);
}
static regi_t seti(int v1, int v2, int v3, int v4)
{
return _mm_set_epi32(v1, v2, v3, v4);
Expand Down Expand Up @@ -492,6 +514,10 @@ struct avx2_half_vector<float> {
{
return _mm_castps_si128(v);
}
// Returns true if every lane of the mask is "false", i.e. the sign-bit
// summary produced by _mm_movemask_ps over the four 32-bit lanes is zero.
static bool all_false(opmask_t k)
{
    const int summary = _mm_movemask_ps(_mm_castsi128_ps(k));
    const bool none_set = (summary == 0);
    return none_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down
12 changes: 12 additions & 0 deletions src/avx512-64bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ struct ymm_vector<float> {
{
return _mm256_castps_si256(v);
}
// True when the mask compares equal to zero.
// NOTE(review): assumes opmask_t is an integer bitmask here, so zero means
// "no lane selected" -- confirm against the struct's typedef.
static bool all_false(opmask_t k)
{
    const bool none_set = (k == 0);
    return none_set;
}
static reg_t reverse(reg_t ymm)
{
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
Expand Down Expand Up @@ -394,6 +398,10 @@ struct ymm_vector<uint32_t> {
{
return v;
}
// Reports whether mask k is empty by comparing it with zero.
// NOTE(review): presumably opmask_t is an integral mask type in this
// struct -- verify.
static bool all_false(opmask_t k)
{
    const bool is_empty = (k == 0);
    return is_empty;
}
static reg_t reverse(reg_t ymm)
{
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
Expand Down Expand Up @@ -578,6 +586,10 @@ struct ymm_vector<int32_t> {
{
return v;
}
// Returns true when k equals zero, false otherwise.
// NOTE(review): relies on opmask_t being comparable with the literal 0
// (looks like an integer bitmask) -- confirm.
static bool all_false(opmask_t k)
{
    if (k == 0) { return true; }
    return false;
}
static reg_t reverse(reg_t ymm)
{
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
Expand Down
18 changes: 14 additions & 4 deletions src/xss-common-keyvaluesort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ X86_SIMD_SORT_INLINE arrsize_t kvpartition(type_t1 *keys,
for (int32_t i = (right - left) % vtype1::numlanes; i > 0; --i) {
*smallest = std::min(*smallest, keys[left]);
*biggest = std::max(*biggest, keys[left]);
if (keys[left] > pivot) {
if (keys[left] >= pivot) {
right--;
std::swap(keys[left], keys[right]);
std::swap(indexes[left], indexes[right]);
Expand Down Expand Up @@ -204,12 +204,13 @@ X86_SIMD_SORT_INLINE arrsize_t kvpartition_unrolled(type_t1 *keys,
return kvpartition<vtype1, vtype2>(
keys, indexes, left, right, pivot, smallest, biggest);
}

/* make array length divisible by num_unroll * vtype1::numlanes, shortening the array */
for (int32_t i = ((right - left) % (num_unroll * vtype1::numlanes)); i > 0;
--i) {
*smallest = std::min(*smallest, keys[left]);
*biggest = std::max(*biggest, keys[left]);
if (keys[left] > pivot) {
if (keys[left] >= pivot) {
right--;
std::swap(keys[left], keys[right]);
std::swap(indexes[left], indexes[right]);
Expand Down Expand Up @@ -386,18 +387,27 @@ X86_SIMD_SORT_INLINE void kvsort_(type1_t *keys,
* Base case: use bitonic networks to sort arrays <= 128
*/
if (right + 1 - left <= 128) {

kvsort_n<vtype1, vtype2, 128>(
keys + left, indexes + left, (int32_t)(right + 1 - left));
return;
}

type1_t pivot = get_pivot_blocks<vtype1>(keys, left, right);
// Ascending comparator for this vtype
using comparator = Comparator<vtype1, false>;
type1_t pivot;
auto pivot_result
= get_pivot_smart<vtype1, comparator, type1_t>(keys, left, right);
pivot = pivot_result.pivot;

if (pivot_result.result == pivot_result_t::Sorted) { return; }

type1_t smallest = vtype1::type_max();
type1_t biggest = vtype1::type_min();
arrsize_t pivot_index = kvpartition_unrolled<vtype1, vtype2, 4>(
keys, indexes, left, right + 1, pivot, &smallest, &biggest);

if (pivot_result.result == pivot_result_t::Only2Values) { return; }

#ifdef XSS_COMPILE_OPENMP
if (pivot != smallest) {
bool parallel_left = (pivot_index - left) > task_threshold;
Expand Down
5 changes: 0 additions & 5 deletions src/xss-pivot-selection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,7 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right)
return pivot_results<type_t>(
comparator::choosePivotMedianIsLargest(median));
}
else {
// Should be unreachable
return pivot_results<type_t>(median);
}

// Should be unreachable
return pivot_results<type_t>(median);
}

Expand Down
4 changes: 2 additions & 2 deletions utils/rand_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ static std::vector<T> get_array(std::string arrtype,
else {
val = std::numeric_limits<T>::max();
}
for (size_t ii = 1; ii <= arrsize; ++ii) {
if (rand() % 0x1) { arr[ii] = val; }
for (size_t ii = 0; ii < arrsize; ++ii) {
if (rand() & 0x1) { arr[ii] = val; }
}
}
else {
Expand Down

0 comments on commit d62f656

Please sign in to comment.