Skip to content

Commit

Permalink
Merge pull request #152 from sterrettm2/swizzle_cleanup
Browse files Browse the repository at this point in the history
Cleanup for single vector sort/bitonic merge (and minor cleanup for argsort/argselect)
  • Loading branch information
r-devulap authored Oct 22, 2024
2 parents d62f656 + 3b41715 commit 990ae6c
Show file tree
Hide file tree
Showing 13 changed files with 597 additions and 722 deletions.
54 changes: 18 additions & 36 deletions src/avx2-32bit-half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,6 @@

#include "avx2-emu-funcs.hpp"

/*
* Constants used in sorting 8 elements in a ymm registers. Based on Bitonic
* sorting network (see
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
*/

// ymm 7, 6, 5, 4, 3, 2, 1, 0
#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3
#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7
#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4

/*
 * Assumes ymm is random and performs a full sorting network defined in
 * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
 *
 * Sorts the 4 lanes of a single half-width register in-register using three
 * compare-exchange stages. NOTE(review): the sort direction (ascending vs
 * descending) is determined by cmp_merge, defined elsewhere -- confirm
 * against callers.
 */
template <typename vtype, typename reg_t = typename vtype::reg_t>
X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm)
{
    using swizzle = typename vtype::swizzle_ops;

    // Blend masks for cmp_merge: oxAA selects alternating lanes (3 and 1),
    // oxCC selects the upper lane pair (3 and 2).
    const typename vtype::opmask_t oxAA = vtype::seti(-1, 0, -1, 0);
    const typename vtype::opmask_t oxCC = vtype::seti(-1, -1, 0, 0);

    // Stage 1: compare-exchange adjacent lanes (swap within groups of 2).
    ymm = cmp_merge<vtype>(ymm, swizzle::template swap_n<vtype, 2>(ymm), oxAA);
    // Stage 2: compare-exchange against the fully lane-reversed register.
    ymm = cmp_merge<vtype>(ymm, vtype::reverse(ymm), oxCC);
    // Stage 3: final adjacent-lane pass completing the bitonic merge.
    ymm = cmp_merge<vtype>(ymm, swizzle::template swap_n<vtype, 2>(ymm), oxAA);
    return ymm;
}

struct avx2_32bit_half_swizzle_ops;

template <>
Expand Down Expand Up @@ -74,6 +44,10 @@ struct avx2_half_vector<int32_t> {
auto mask = ((0x1ull << num_to_read) - 0x1ull);
return convert_int_to_avx2_mask_half(mask);
}
// Converts a raw integer lane bitmask into this vector type's emulated AVX2
// opmask representation (presumably bit i = lane i, matching how the sibling
// get_partial_loadmask feeds the same converter -- confirm).
static opmask_t convert_int_to_mask(uint64_t intMask)
{
    return convert_int_to_avx2_mask_half(intMask);
}
static regi_t seti(int v1, int v2, int v3, int v4)
{
return _mm_set_epi32(v1, v2, v3, v4);
Expand Down Expand Up @@ -155,7 +129,7 @@ struct avx2_half_vector<int32_t> {
}
// Returns ymm with its 4 lanes in reverse order (via permutexvar).
static reg_t reverse(reg_t ymm)
{
    // NOTE(review): diff-rendering artifact -- the next two lines are the
    // removed and added versions of the same statement (hard-coded indices
    // replaced by the NETWORK_REVERSE_4LANES macro); only one exists in the
    // real file.
    const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
    const __m128i rev_index = _mm_set_epi32(NETWORK_REVERSE_4LANES);
    return permutexvar(rev_index, ymm);
}
static type_t reducemax(reg_t v)
Expand All @@ -181,7 +155,7 @@ struct avx2_half_vector<int32_t> {
}
// Sorts all lanes of x using the single-register bitonic network.
static reg_t sort_vec(reg_t x)
{
    // NOTE(review): diff-rendering artifact -- removed call followed by its
    // renamed replacement; only the second line exists post-commit.
    return sort_ymm_32bit_half<avx2_half_vector<type_t>>(x);
    return sort_reg_4lanes<avx2_half_vector<type_t>>(x);
}
static reg_t cast_from(__m128i v)
{
Expand Down Expand Up @@ -237,6 +211,10 @@ struct avx2_half_vector<uint32_t> {
auto mask = ((0x1ull << num_to_read) - 0x1ull);
return convert_int_to_avx2_mask_half(mask);
}
// Converts a raw integer lane bitmask into this vector type's emulated AVX2
// opmask representation (presumably bit i = lane i, matching how the sibling
// get_partial_loadmask feeds the same converter -- confirm).
static opmask_t convert_int_to_mask(uint64_t intMask)
{
    return convert_int_to_avx2_mask_half(intMask);
}
static regi_t seti(int v1, int v2, int v3, int v4)
{
return _mm_set_epi32(v1, v2, v3, v4);
Expand Down Expand Up @@ -309,7 +287,7 @@ struct avx2_half_vector<uint32_t> {
}
// Returns ymm with its 4 lanes in reverse order (via permutexvar).
static reg_t reverse(reg_t ymm)
{
    // NOTE(review): diff-rendering artifact -- the next two lines are the
    // removed and added versions of the same statement (hard-coded indices
    // replaced by the NETWORK_REVERSE_4LANES macro); only one exists in the
    // real file.
    const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
    const __m128i rev_index = _mm_set_epi32(NETWORK_REVERSE_4LANES);
    return permutexvar(rev_index, ymm);
}
static type_t reducemax(reg_t v)
Expand All @@ -335,7 +313,7 @@ struct avx2_half_vector<uint32_t> {
}
// Sorts all lanes of x using the single-register bitonic network.
static reg_t sort_vec(reg_t x)
{
    // NOTE(review): diff-rendering artifact -- removed call followed by its
    // renamed replacement; only the second line exists post-commit.
    return sort_ymm_32bit_half<avx2_half_vector<type_t>>(x);
    return sort_reg_4lanes<avx2_half_vector<type_t>>(x);
}
static reg_t cast_from(__m128i v)
{
Expand Down Expand Up @@ -411,6 +389,10 @@ struct avx2_half_vector<float> {
auto mask = ((0x1ull << num_to_read) - 0x1ull);
return convert_int_to_avx2_mask_half(mask);
}
// Converts a raw integer lane bitmask into this vector type's emulated AVX2
// opmask representation (presumably bit i = lane i, matching how the sibling
// get_partial_loadmask feeds the same converter -- confirm).
static opmask_t convert_int_to_mask(uint64_t intMask)
{
    return convert_int_to_avx2_mask_half(intMask);
}
static int32_t convert_mask_to_int(opmask_t mask)
{
return convert_avx2_mask_to_int_half(mask);
Expand Down Expand Up @@ -478,7 +460,7 @@ struct avx2_half_vector<float> {
}
// Returns ymm with its 4 lanes in reverse order (via permutexvar).
static reg_t reverse(reg_t ymm)
{
    // NOTE(review): diff-rendering artifact -- the next two lines are the
    // removed and added versions of the same statement (hard-coded indices
    // replaced by the NETWORK_REVERSE_4LANES macro); only one exists in the
    // real file.
    const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
    const __m128i rev_index = _mm_set_epi32(NETWORK_REVERSE_4LANES);
    return permutexvar(rev_index, ymm);
}
static type_t reducemax(reg_t v)
Expand All @@ -504,7 +486,7 @@ struct avx2_half_vector<float> {
}
// Sorts all lanes of x using the single-register bitonic network.
static reg_t sort_vec(reg_t x)
{
    // NOTE(review): diff-rendering artifact -- removed call followed by its
    // renamed replacement; only the second line exists post-commit.
    return sort_ymm_32bit_half<avx2_half_vector<type_t>>(x);
    return sort_reg_4lanes<avx2_half_vector<type_t>>(x);
}
static reg_t cast_from(__m128i v)
{
Expand Down
57 changes: 6 additions & 51 deletions src/avx2-32bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,6 @@

#include "avx2-emu-funcs.hpp"

/*
* Constants used in sorting 8 elements in a ymm registers. Based on Bitonic
* sorting network (see
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
*/

// ymm 7, 6, 5, 4, 3, 2, 1, 0
#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3
#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7
#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4

/*
 * Assumes ymm is random and performs a full sorting network defined in
 * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
 *
 * Sorts all 8 lanes of a ymm register with a 6-stage bitonic network.
 * NOTE(review): the sort direction (ascending vs descending) is determined
 * by cmp_merge, defined elsewhere -- confirm against callers.
 */
template <typename vtype, typename reg_t = typename vtype::reg_t>
X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm)
{
    // Blend masks for cmp_merge: oxAA = every other lane, oxCC = alternating
    // lane pairs, oxF0 = the upper half of the register.
    const typename vtype::opmask_t oxAA = _mm256_set_epi32(
            0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0);
    const typename vtype::opmask_t oxCC = _mm256_set_epi32(
            0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0);
    const typename vtype::opmask_t oxF0 = _mm256_set_epi32(
            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0);

    // NETWORK_32BIT_AVX2_2 reverses all 8 lanes.
    const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2);
    // Stage 1: compare-exchange adjacent lanes (0<->1, 2<->3, ...).
    ymm = cmp_merge<vtype>(
            ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
    // Stage 2: compare-exchange against groups of 4 lanes reversed
    // (NETWORK_32BIT_AVX2_1 = 4,5,6,7,0,1,2,3).
    ymm = cmp_merge<vtype>(
            ymm,
            vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm),
            oxCC);
    // Stage 3: adjacent-lane pass completing the 4-lane bitonic merge.
    ymm = cmp_merge<vtype>(
            ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
    // Stage 4: compare-exchange against the fully reversed register.
    ymm = cmp_merge<vtype>(ymm, vtype::permutexvar(rev_index, ymm), oxF0);
    // Stage 5: swap neighbouring lane pairs (distance-2 exchange;
    // NETWORK_32BIT_AVX2_3 = 5,4,7,6,1,0,3,2).
    ymm = cmp_merge<vtype>(
            ymm,
            vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm),
            oxCC);
    // Stage 6: final adjacent-lane pass.
    ymm = cmp_merge<vtype>(
            ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
    return ymm;
}

struct avx2_32bit_swizzle_ops;

template <>
Expand Down Expand Up @@ -180,7 +135,7 @@ struct avx2_vector<int32_t> {
}
// Returns ymm with its 8 lanes in reverse order (via permutexvar).
static reg_t reverse(reg_t ymm)
{
    // NOTE(review): diff-rendering artifact -- the next two lines are the
    // removed and added versions of the same statement (network macro renamed
    // to NETWORK_REVERSE_8LANES); only one exists in the real file.
    const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
    const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
    return permutexvar(rev_index, ymm);
}
static type_t reducemax(reg_t v)
Expand All @@ -206,7 +161,7 @@ struct avx2_vector<int32_t> {
}
// Sorts all lanes of x using the single-register bitonic network.
static reg_t sort_vec(reg_t x)
{
    // NOTE(review): diff-rendering artifact -- removed call followed by its
    // renamed replacement; only the second line exists post-commit.
    return sort_ymm_32bit<avx2_vector<type_t>>(x);
    return sort_reg_8lanes<avx2_vector<type_t>>(x);
}
static reg_t cast_from(__m256i v)
{
Expand Down Expand Up @@ -342,7 +297,7 @@ struct avx2_vector<uint32_t> {
}
// Returns ymm with its 8 lanes in reverse order (via permutexvar).
static reg_t reverse(reg_t ymm)
{
    // NOTE(review): diff-rendering artifact -- the next two lines are the
    // removed and added versions of the same statement (network macro renamed
    // to NETWORK_REVERSE_8LANES); only one exists in the real file.
    const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
    const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
    return permutexvar(rev_index, ymm);
}
static type_t reducemax(reg_t v)
Expand All @@ -368,7 +323,7 @@ struct avx2_vector<uint32_t> {
}
// Sorts all lanes of x using the single-register bitonic network.
static reg_t sort_vec(reg_t x)
{
    // NOTE(review): diff-rendering artifact -- removed call followed by its
    // renamed replacement; only the second line exists post-commit.
    return sort_ymm_32bit<avx2_vector<type_t>>(x);
    return sort_reg_8lanes<avx2_vector<type_t>>(x);
}
static reg_t cast_from(__m256i v)
{
Expand Down Expand Up @@ -520,7 +475,7 @@ struct avx2_vector<float> {
}
// Returns ymm with its 8 lanes in reverse order (via permutexvar).
static reg_t reverse(reg_t ymm)
{
    // NOTE(review): diff-rendering artifact -- the next two lines are the
    // removed and added versions of the same statement (network macro renamed
    // to NETWORK_REVERSE_8LANES); only one exists in the real file.
    const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
    const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
    return permutexvar(rev_index, ymm);
}
static type_t reducemax(reg_t v)
Expand All @@ -547,7 +502,7 @@ struct avx2_vector<float> {
}
// Sorts all lanes of x using the single-register bitonic network.
static reg_t sort_vec(reg_t x)
{
    // NOTE(review): diff-rendering artifact -- removed call followed by its
    // renamed replacement; only the second line exists post-commit.
    return sort_ymm_32bit<avx2_vector<type_t>>(x);
    return sort_reg_8lanes<avx2_vector<type_t>>(x);
}
static reg_t cast_from(__m256i v)
{
Expand Down
44 changes: 15 additions & 29 deletions src/avx2-64bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,6 @@

#include "avx2-emu-funcs.hpp"

/*
 * Assumes ymm is random and performs a full sorting network defined in
 * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
 *
 * Sorts the 4 x 64-bit lanes of a ymm register in three compare-exchange
 * stages. NOTE(review): the sort direction (ascending vs descending) is
 * determined by cmp_merge, defined elsewhere -- confirm against callers.
 */
template <typename vtype, typename reg_t = typename vtype::reg_t>
X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t ymm)
{
    // Blend masks for cmp_merge: oxAA selects alternating lanes (3 and 1),
    // oxCC selects the upper lane pair (3 and 2).
    const typename vtype::opmask_t oxAA
            = _mm256_set_epi64x(0xFFFFFFFFFFFFFFFF, 0, 0xFFFFFFFFFFFFFFFF, 0);
    const typename vtype::opmask_t oxCC
            = _mm256_set_epi64x(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0);
    // Stage 1: compare-exchange adjacent lanes (0<->1, 2<->3).
    ymm = cmp_merge<vtype>(
            ymm,
            vtype::template permutexvar<SHUFFLE_MASK(2, 3, 0, 1)>(ymm),
            oxAA);
    // Stage 2: compare-exchange against the fully reversed register
    // (0<->3, 1<->2).
    ymm = cmp_merge<vtype>(
            ymm,
            vtype::template permutexvar<SHUFFLE_MASK(0, 1, 2, 3)>(ymm),
            oxCC);
    // Stage 3: final adjacent-lane pass completing the bitonic merge.
    ymm = cmp_merge<vtype>(
            ymm,
            vtype::template permutexvar<SHUFFLE_MASK(2, 3, 0, 1)>(ymm),
            oxAA);
    return ymm;
}

struct avx2_64bit_swizzle_ops;

template <>
Expand Down Expand Up @@ -81,6 +55,10 @@ struct avx2_vector<int64_t> {
auto mask = ((0x1ull << num_to_read) - 0x1ull);
return convert_int_to_avx2_mask_64bit(mask);
}
// Converts a raw integer lane bitmask into this vector type's emulated AVX2
// opmask representation (presumably bit i = lane i, matching how the sibling
// get_partial_loadmask feeds the same converter -- confirm).
static opmask_t convert_int_to_mask(uint64_t intMask)
{
    return convert_int_to_avx2_mask_64bit(intMask);
}
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
{
return _mm256_set_epi64x(v1, v2, v3, v4);
Expand Down Expand Up @@ -207,7 +185,7 @@ struct avx2_vector<int64_t> {
}
// Sorts all lanes of x using the single-register bitonic network.
static reg_t sort_vec(reg_t x)
{
    // NOTE(review): diff-rendering artifact -- removed call followed by its
    // renamed replacement; only the second line exists post-commit.
    return sort_ymm_64bit<avx2_vector<type_t>>(x);
    return sort_reg_4lanes<avx2_vector<type_t>>(x);
}
static reg_t cast_from(__m256i v)
{
Expand Down Expand Up @@ -265,6 +243,10 @@ struct avx2_vector<uint64_t> {
auto mask = ((0x1ull << num_to_read) - 0x1ull);
return convert_int_to_avx2_mask_64bit(mask);
}
// Converts a raw integer lane bitmask into this vector type's emulated AVX2
// opmask representation (presumably bit i = lane i, matching how the sibling
// get_partial_loadmask feeds the same converter -- confirm).
static opmask_t convert_int_to_mask(uint64_t intMask)
{
    return convert_int_to_avx2_mask_64bit(intMask);
}
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
{
return _mm256_set_epi64x(v1, v2, v3, v4);
Expand Down Expand Up @@ -389,7 +371,7 @@ struct avx2_vector<uint64_t> {
}
// Sorts all lanes of x using the single-register bitonic network.
static reg_t sort_vec(reg_t x)
{
    // NOTE(review): diff-rendering artifact -- removed call followed by its
    // renamed replacement; only the second line exists post-commit.
    return sort_ymm_64bit<avx2_vector<type_t>>(x);
    return sort_reg_4lanes<avx2_vector<type_t>>(x);
}
static reg_t cast_from(__m256i v)
{
Expand Down Expand Up @@ -460,6 +442,10 @@ struct avx2_vector<double> {
auto mask = ((0x1ull << num_to_read) - 0x1ull);
return convert_int_to_avx2_mask_64bit(mask);
}
// Converts a raw integer lane bitmask into this vector type's emulated AVX2
// opmask representation (presumably bit i = lane i, matching how the sibling
// get_partial_loadmask feeds the same converter -- confirm).
static opmask_t convert_int_to_mask(uint64_t intMask)
{
    return convert_int_to_avx2_mask_64bit(intMask);
}
static int32_t convert_mask_to_int(opmask_t mask)
{
return convert_avx2_mask_to_int_64bit(mask);
Expand Down Expand Up @@ -593,7 +579,7 @@ struct avx2_vector<double> {
}
// Sorts all lanes of x using the single-register bitonic network.
static reg_t sort_vec(reg_t x)
{
    // NOTE(review): diff-rendering artifact -- removed call followed by its
    // renamed replacement; only the second line exists post-commit.
    return sort_ymm_64bit<avx2_vector<type_t>>(x);
    return sort_reg_4lanes<avx2_vector<type_t>>(x);
}
static reg_t cast_from(__m256i v)
{
Expand Down
Loading

0 comments on commit 990ae6c

Please sign in to comment.