diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp
index 9e782bb..3646979 100644
--- a/src/avx2-32bit-half.hpp
+++ b/src/avx2-32bit-half.hpp
@@ -9,36 +9,6 @@
 
 #include "avx2-emu-funcs.hpp"
 
-/*
- * Constants used in sorting 8 elements in a ymm registers. Based on Bitonic
- * sorting network (see
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
- */
-
-// ymm 7, 6, 5, 4, 3, 2, 1, 0
-#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3
-#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7
-#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
-#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4
-
-/*
- * Assumes ymm is random and performs a full sorting network defined in
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
- */
-template <typename vtype, typename reg_t = typename vtype::reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm)
-{
-    using swizzle = typename vtype::swizzle_ops;
-
-    const typename vtype::opmask_t oxAA = vtype::seti(-1, 0, -1, 0);
-    const typename vtype::opmask_t oxCC = vtype::seti(-1, -1, 0, 0);
-
-    ymm = cmp_merge<vtype>(ymm, swizzle::template swap_n<vtype, 2>(ymm), oxAA);
-    ymm = cmp_merge<vtype>(ymm, vtype::reverse(ymm), oxCC);
-    ymm = cmp_merge<vtype>(ymm, swizzle::template swap_n<vtype, 2>(ymm), oxAA);
-    return ymm;
-}
-
 struct avx2_32bit_half_swizzle_ops;
 
 template <>
@@ -74,6 +44,10 @@ struct avx2_half_vector<int32_t> {
         auto mask = ((0x1ull << num_to_read) - 0x1ull);
         return convert_int_to_avx2_mask_half(mask);
     }
+    static opmask_t convert_int_to_mask(uint64_t intMask)
+    {
+        return convert_int_to_avx2_mask_half(intMask);
+    }
     static regi_t seti(int v1, int v2, int v3, int v4)
     {
         return _mm_set_epi32(v1, v2, v3, v4);
@@ -155,7 +129,7 @@ struct avx2_half_vector<int32_t> {
     }
     static reg_t reverse(reg_t ymm)
     {
-        const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
+        const __m128i rev_index = _mm_set_epi32(NETWORK_REVERSE_4LANES);
         return permutexvar(rev_index, ymm);
     }
     static type_t reducemax(reg_t v)
@@ -181,7 +155,7 @@ struct avx2_half_vector<int32_t> {
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_ymm_32bit_half<avx2_half_vector<int32_t>>(x);
+        return sort_reg_4lanes<avx2_half_vector<int32_t>>(x);
     }
     static reg_t cast_from(__m128i v)
     {
@@ -237,6 +211,10 @@ struct avx2_half_vector<uint32_t> {
         auto mask = ((0x1ull << num_to_read) - 0x1ull);
         return convert_int_to_avx2_mask_half(mask);
     }
+    static opmask_t convert_int_to_mask(uint64_t intMask)
+    {
+        return convert_int_to_avx2_mask_half(intMask);
+    }
     static regi_t seti(int v1, int v2, int v3, int v4)
     {
         return _mm_set_epi32(v1, v2, v3, v4);
@@ -309,7 +287,7 @@ struct avx2_half_vector<uint32_t> {
     }
     static reg_t reverse(reg_t ymm)
     {
-        const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
+        const __m128i rev_index = _mm_set_epi32(NETWORK_REVERSE_4LANES);
         return permutexvar(rev_index, ymm);
     }
     static type_t reducemax(reg_t v)
@@ -335,7 +313,7 @@ struct avx2_half_vector<uint32_t> {
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_ymm_32bit_half<avx2_half_vector<uint32_t>>(x);
+        return sort_reg_4lanes<avx2_half_vector<uint32_t>>(x);
     }
     static reg_t cast_from(__m128i v)
     {
@@ -411,6 +389,10 @@ struct avx2_half_vector<float> {
         auto mask = ((0x1ull << num_to_read) - 0x1ull);
         return convert_int_to_avx2_mask_half(mask);
     }
+    static opmask_t convert_int_to_mask(uint64_t intMask)
+    {
+        return convert_int_to_avx2_mask_half(intMask);
+    }
     static int32_t convert_mask_to_int(opmask_t mask)
     {
         return convert_avx2_mask_to_int_half(mask);
@@ -478,7 +460,7 @@ struct avx2_half_vector<float> {
     }
     static reg_t reverse(reg_t ymm)
     {
-        const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
+        const __m128i rev_index = _mm_set_epi32(NETWORK_REVERSE_4LANES);
         return permutexvar(rev_index, ymm);
     }
     static type_t reducemax(reg_t v)
@@ -504,7 +486,7 @@ struct avx2_half_vector<float> {
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_ymm_32bit_half<avx2_half_vector<float>>(x);
+        return sort_reg_4lanes<avx2_half_vector<float>>(x);
     }
     static reg_t cast_from(__m128i v)
     {
diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp
index 56ffc23..7c7218e 100644
--- a/src/avx2-32bit-qsort.hpp
+++ b/src/avx2-32bit-qsort.hpp
@@ -9,51 +9,6 @@
 
 #include "avx2-emu-funcs.hpp"
 
-/*
- * Constants used in sorting 8 elements in a ymm registers. Based on Bitonic
- * sorting network (see
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
- */
-
-// ymm 7, 6, 5, 4, 3, 2, 1, 0
-#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3
-#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7
-#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
-#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4
-
-/*
- * Assumes ymm is random and performs a full sorting network defined in
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
- */
-template <typename vtype, typename reg_t = typename vtype::reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm)
-{
-    const typename vtype::opmask_t oxAA = _mm256_set_epi32(
-            0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0);
-    const typename vtype::opmask_t oxCC = _mm256_set_epi32(
-            0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0);
-    const typename vtype::opmask_t oxF0 = _mm256_set_epi32(
-            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0);
-
-    const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2);
-    ymm = cmp_merge<vtype>(
-            ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
-    ymm = cmp_merge<vtype>(
-            ymm,
-            vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm),
-            oxCC);
-    ymm = cmp_merge<vtype>(
-            ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
-    ymm = cmp_merge<vtype>(ymm, vtype::permutexvar(rev_index, ymm), oxF0);
-    ymm = cmp_merge<vtype>(
-            ymm,
-            vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm),
-            oxCC);
-    ymm = cmp_merge<vtype>(
-            ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
-    return ymm;
-}
-
 struct avx2_32bit_swizzle_ops;
 
 template <>
@@ -180,7 +135,7 @@ struct avx2_vector<int32_t> {
     }
     static reg_t reverse(reg_t ymm)
     {
-        const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
+        const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
         return permutexvar(rev_index, ymm);
     }
     static type_t reducemax(reg_t v)
@@ -206,7 +161,7 @@ struct avx2_vector<int32_t> {
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_ymm_32bit<avx2_vector<int32_t>>(x);
+        return sort_reg_8lanes<avx2_vector<int32_t>>(x);
     }
     static reg_t cast_from(__m256i v)
     {
@@ -342,7 +297,7 @@ struct avx2_vector<uint32_t> {
     }
     static reg_t reverse(reg_t ymm)
     {
-        const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
+        const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
         return permutexvar(rev_index, ymm);
     }
     static type_t reducemax(reg_t v)
@@ -368,7 +323,7 @@ struct avx2_vector<uint32_t> {
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_ymm_32bit<avx2_vector<uint32_t>>(x);
+        return sort_reg_8lanes<avx2_vector<uint32_t>>(x);
     }
     static reg_t cast_from(__m256i v)
     {
@@ -520,7 +475,7 @@ struct avx2_vector<float> {
     }
     static reg_t reverse(reg_t ymm)
     {
-        const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
+        const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
         return permutexvar(rev_index, ymm);
     }
     static type_t reducemax(reg_t v)
@@ -547,7 +502,7 @@ struct avx2_vector<float> {
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_ymm_32bit<avx2_vector<float>>(x);
+        return sort_reg_8lanes<avx2_vector<float>>(x);
     }
     static reg_t cast_from(__m256i v)
     {
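
All of these networks bottom out in the comparator primitive cmp_merge, now forward-declared in xss-reg-networks.hpp: permute the register, take per-lane min/max, then blend by a mask. A minimal sketch of the same idea for one AVX2 float register (illustrative only; the library's generic cmp_merge<vtype> has a different shape):

#include <immintrin.h>

// Sketch of the comparator step: lanes where 'mask' is all-ones keep the
// maximum, the remaining lanes keep the minimum. AVX2 has no opmask
// registers, so the mask is itself a vector, which is what the new
// convert_int_to_mask wrappers produce on this ISA.
static inline __m256 cmp_merge_sketch(__m256 in1, __m256 in2, __m256 mask)
{
    __m256 lo = _mm256_min_ps(in1, in2); // per-lane minimum
    __m256 hi = _mm256_max_ps(in1, in2); // per-lane maximum
    return _mm256_blendv_ps(lo, hi, mask); // pick max where mask is set
}
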
"avx2-emu-funcs.hpp" -/* - * Assumes ymm is random and performs a full sorting network defined in - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg - */ -template -X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t ymm) -{ - const typename vtype::opmask_t oxAA - = _mm256_set_epi64x(0xFFFFFFFFFFFFFFFF, 0, 0xFFFFFFFFFFFFFFFF, 0); - const typename vtype::opmask_t oxCC - = _mm256_set_epi64x(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0); - ymm = cmp_merge( - ymm, - vtype::template permutexvar(ymm), - oxAA); - ymm = cmp_merge( - ymm, - vtype::template permutexvar(ymm), - oxCC); - ymm = cmp_merge( - ymm, - vtype::template permutexvar(ymm), - oxAA); - return ymm; -} - struct avx2_64bit_swizzle_ops; template <> @@ -81,6 +55,10 @@ struct avx2_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask_64bit(mask); } + static opmask_t convert_int_to_mask(uint64_t intMask) + { + return convert_int_to_avx2_mask_64bit(intMask); + } static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4) { return _mm256_set_epi64x(v1, v2, v3, v4); @@ -207,7 +185,7 @@ struct avx2_vector { } static reg_t sort_vec(reg_t x) { - return sort_ymm_64bit>(x); + return sort_reg_4lanes>(x); } static reg_t cast_from(__m256i v) { @@ -265,6 +243,10 @@ struct avx2_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask_64bit(mask); } + static opmask_t convert_int_to_mask(uint64_t intMask) + { + return convert_int_to_avx2_mask_64bit(intMask); + } static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4) { return _mm256_set_epi64x(v1, v2, v3, v4); @@ -389,7 +371,7 @@ struct avx2_vector { } static reg_t sort_vec(reg_t x) { - return sort_ymm_64bit>(x); + return sort_reg_4lanes>(x); } static reg_t cast_from(__m256i v) { @@ -460,6 +442,10 @@ struct avx2_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask_64bit(mask); } + static opmask_t convert_int_to_mask(uint64_t intMask) + { + return convert_int_to_avx2_mask_64bit(intMask); + } static int32_t convert_mask_to_int(opmask_t mask) { return convert_avx2_mask_to_int_64bit(mask); @@ -593,7 +579,7 @@ struct avx2_vector { } static reg_t sort_vec(reg_t x) { - return sort_ymm_64bit>(x); + return sort_reg_4lanes>(x); } static reg_t cast_from(__m256i v) { diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 76db872..524ce7a 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -7,89 +7,6 @@ #ifndef AVX512_16BIT_COMMON #define AVX512_16BIT_COMMON -/* - * Constants used in sorting 32 elements in a ZMM registers. 
diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h
index 76db872..524ce7a 100644
--- a/src/avx512-16bit-common.h
+++ b/src/avx512-16bit-common.h
@@ -7,89 +7,6 @@
 #ifndef AVX512_16BIT_COMMON
 #define AVX512_16BIT_COMMON
 
-/*
- * Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic
- * sorting network (see
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
- */
-// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0
-static const uint16_t network[6][32]
-        = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8,
-            23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24},
-           {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
-            31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16},
-           {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11,
-            20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27},
-           {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
-            15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
-           {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,
-            24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23},
-           {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
-            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}};
-
-/*
- * Assumes zmm is random and performs a full sorting network defined in
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
- */
-template <typename vtype, typename reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit(reg_t zmm)
-{
-    // Level 1
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    // Level 2
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(zmm),
-            0xCCCCCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    // Level 3
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(1), zmm), 0xF0F0F0F0);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCCCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    // Level 4
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(2), zmm), 0xFF00FF00);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCCCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    // Level 5
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(4), zmm), 0xFFFF0000);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(5), zmm), 0xFF00FF00);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::get_network(3), zmm), 0xF0F0F0F0);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCCCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAAAAAA);
-    return zmm;
-}
-
 struct avx512_16bit_swizzle_ops {
     template <typename vtype, int scale>
     X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg)
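
With the per-ISA network tables gone, every network is now phrased in terms of the swizzle ops reverse_n<N> (reverse each group of N lanes) and swap_n<N> (swap the two halves of each group). A scalar model of those assumed semantics, for reference:

#include <algorithm>
#include <cstddef>

// reverse_n<N>: reverse every contiguous group of N lanes.
template <size_t N, size_t LANES>
void reverse_n(int (&v)[LANES])
{
    for (size_t g = 0; g < LANES; g += N)
        std::reverse(v + g, v + g + N);
}

// swap_n<N>: exchange the lower and upper halves of every group of N lanes.
template <size_t N, size_t LANES>
void swap_n(int (&v)[LANES])
{
    for (size_t g = 0; g < LANES; g += N)
        std::swap_ranges(v + g, v + g + N / 2, v + g + N / 2);
}
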
diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp
index 18595ac..65a64fb 100644
--- a/src/avx512-16bit-qsort.hpp
+++ b/src/avx512-16bit-qsort.hpp
@@ -30,10 +30,6 @@ struct zmm_vector<float16> {
 
     using swizzle_ops = avx512_16bit_swizzle_ops;
 
-    static reg_t get_network(int index)
-    {
-        return _mm512_loadu_si512(&network[index - 1][0]);
-    }
     static type_t type_max()
     {
         return X86_SIMD_SORT_INFINITYH;
@@ -179,12 +175,12 @@ struct zmm_vector<float16> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const auto rev_index = get_network(4);
+        const auto rev_index = _mm512_set_epi16(NETWORK_REVERSE_32LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_16bit<zmm_vector<float16>>(x);
+        return sort_reg_32lanes<zmm_vector<float16>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
@@ -225,10 +221,6 @@ struct zmm_vector<int16_t> {
 
     using swizzle_ops = avx512_16bit_swizzle_ops;
 
-    static reg_t get_network(int index)
-    {
-        return _mm512_loadu_si512(&network[index - 1][0]);
-    }
     static type_t type_max()
     {
         return X86_SIMD_SORT_MAX_INT16;
@@ -328,12 +320,12 @@ struct zmm_vector<int16_t> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const auto rev_index = get_network(4);
+        const auto rev_index = _mm512_set_epi16(NETWORK_REVERSE_32LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_16bit<zmm_vector<int16_t>>(x);
+        return sort_reg_32lanes<zmm_vector<int16_t>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
@@ -373,10 +365,6 @@ struct zmm_vector<uint16_t> {
 
     using swizzle_ops = avx512_16bit_swizzle_ops;
 
-    static reg_t get_network(int index)
-    {
-        return _mm512_loadu_si512(&network[index - 1][0]);
-    }
     static type_t type_max()
     {
         return X86_SIMD_SORT_MAX_UINT16;
@@ -474,12 +462,12 @@ struct zmm_vector<uint16_t> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const auto rev_index = get_network(4);
+        const auto rev_index = _mm512_set_epi16(NETWORK_REVERSE_32LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_16bit<zmm_vector<uint16_t>>(x);
+        return sort_reg_32lanes<zmm_vector<uint16_t>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
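
A note on lane order: _mm512_set_epi16 lists the highest lane first, so NETWORK_REVERSE_32LANES (0, 1, ..., 31) places index 0 in lane 31; permutexvar then writes src[31] into lane 0, producing a full reversal. A sketch at 16-bit granularity (requires AVX512BW, e.g. -mavx512bw):

#include <immintrin.h>

// Reverse thirty-two 16-bit lanes of a zmm register via permutexvar.
static inline __m512i reverse_epi16(__m512i src)
{
    const __m512i idx = _mm512_set_epi16(
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31);
    return _mm512_permutexvar_epi16(idx, src); // dst[i] = src[idx[i]]
}
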
diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp
index eeaba51..ffcd85a 100644
--- a/src/avx512-32bit-qsort.hpp
+++ b/src/avx512-32bit-qsort.hpp
@@ -8,22 +8,6 @@
 #ifndef AVX512_QSORT_32BIT
 #define AVX512_QSORT_32BIT
 
-/*
- * Constants used in sorting 16 elements in a ZMM registers. Based on Bitonic
- * sorting network (see
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
- */
-#define NETWORK_32BIT_1 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1
-#define NETWORK_32BIT_2 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3
-#define NETWORK_32BIT_3 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7
-#define NETWORK_32BIT_4 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2
-#define NETWORK_32BIT_5 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
-#define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4
-#define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8
-
-template <typename vtype, typename reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm);
-
 struct avx512_32bit_swizzle_ops;
 
 template <>
@@ -185,12 +169,12 @@ struct zmm_vector<int32_t> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5);
+        const auto rev_index = _mm512_set_epi32(NETWORK_REVERSE_16LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_32bit<zmm_vector<int32_t>>(x);
+        return sort_reg_16lanes<zmm_vector<int32_t>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
@@ -372,12 +356,12 @@ struct zmm_vector<uint32_t> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5);
+        const auto rev_index = _mm512_set_epi32(NETWORK_REVERSE_16LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_32bit<zmm_vector<uint32_t>>(x);
+        return sort_reg_16lanes<zmm_vector<uint32_t>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
@@ -573,12 +557,12 @@ struct zmm_vector<float> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5);
+        const auto rev_index = _mm512_set_epi32(NETWORK_REVERSE_16LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_32bit<zmm_vector<float>>(x);
+        return sort_reg_16lanes<zmm_vector<float>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
@@ -602,56 +586,6 @@ struct zmm_vector<float> {
     }
 };
 
-/*
- * Assumes zmm is random and performs a full sorting network defined in
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
- */
-template <typename vtype, typename reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm)
-{
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAA);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(0, 1, 2, 3)>(zmm),
-            0xCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAA);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_3), zmm),
-            0xF0F0);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAA);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_5), zmm),
-            0xFF00);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_6), zmm),
-            0xF0F0);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCC);
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAA);
-    return zmm;
-}
-
 struct avx512_32bit_swizzle_ops {
     template <typename vtype, int scale>
     X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg)
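
The hexadecimal masks name the comparator outputs: after a reverse_n<2>, lane i faces partner i ^ 1, and bit i of the mask decides whether that lane keeps the max or the min. A small program that prints the pairs encoded by the first 0xAAAA step of the 16-lane network:

#include <cstdio>

int main()
{
    const unsigned mask = 0xAAAA; // odd lanes keep the maximum
    for (int i = 0; i < 16; ++i) {
        int partner = i ^ 1; // reverse_n<2> pairs lane i with i^1
        if (mask >> i & 1)
            std::printf("lane %2d keeps max(%d, %d)\n", i, i, partner);
        else
            std::printf("lane %2d keeps min(%d, %d)\n", i, i, partner);
    }
    return 0;
}
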
diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h
index 14201d1..5d55196 100644
--- a/src/avx512-64bit-common.h
+++ b/src/avx512-64bit-common.h
@@ -9,19 +9,8 @@
 
 #include "avx2-32bit-qsort.hpp"
 
-/*
- * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic
- * sorting network (see
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
- */
-// ZMM 7, 6, 5, 4, 3, 2, 1, 0
-#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3
-#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7
-#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2
-#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4
-
 template <typename vtype, typename reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm);
+X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t zmm);
 
 struct avx512_64bit_swizzle_ops;
 struct avx512_ymm_64bit_swizzle_ops;
@@ -196,7 +185,7 @@ struct ymm_vector<float> {
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_64bit<ymm_vector<float>>(x);
+        return sort_reg_8lanes<ymm_vector<float>>(x);
     }
     static void storeu(void *mem, reg_t x)
     {
@@ -216,7 +205,7 @@ struct ymm_vector<float> {
     }
     static reg_t reverse(reg_t ymm)
     {
-        const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
+        const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
         return permutexvar(rev_index, ymm);
     }
     static int double_compressstore(type_t *left_addr,
@@ -384,7 +373,7 @@ struct ymm_vector<uint32_t> {
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_64bit<ymm_vector<uint32_t>>(x);
+        return sort_reg_8lanes<ymm_vector<uint32_t>>(x);
     }
     static void storeu(void *mem, reg_t x)
     {
@@ -404,7 +393,7 @@ struct ymm_vector<uint32_t> {
     }
     static reg_t reverse(reg_t ymm)
     {
-        const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
+        const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
         return permutexvar(rev_index, ymm);
     }
     static int double_compressstore(type_t *left_addr,
@@ -572,7 +561,7 @@ struct ymm_vector<int32_t> {
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_64bit<ymm_vector<int32_t>>(x);
+        return sort_reg_8lanes<ymm_vector<int32_t>>(x);
     }
     static void storeu(void *mem, reg_t x)
     {
@@ -592,7 +581,7 @@ struct ymm_vector<int32_t> {
     }
     static reg_t reverse(reg_t ymm)
     {
-        const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
+        const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
         return permutexvar(rev_index, ymm);
     }
     static int double_compressstore(type_t *left_addr,
@@ -763,12 +752,12 @@ struct zmm_vector<int64_t> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const regi_t rev_index = seti(NETWORK_64BIT_2);
+        const regi_t rev_index = seti(NETWORK_REVERSE_8LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_64bit<zmm_vector<int64_t>>(x);
+        return sort_reg_8lanes<zmm_vector<int64_t>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
@@ -942,12 +931,12 @@ struct zmm_vector<uint64_t> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const regi_t rev_index = seti(NETWORK_64BIT_2);
+        const regi_t rev_index = seti(NETWORK_REVERSE_8LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_64bit<zmm_vector<uint64_t>>(x);
+        return sort_reg_8lanes<zmm_vector<uint64_t>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
@@ -1140,12 +1129,12 @@ struct zmm_vector<double> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const regi_t rev_index = seti(NETWORK_64BIT_2);
+        const regi_t rev_index = seti(NETWORK_REVERSE_8LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_64bit<zmm_vector<double>>(x);
+        return sort_reg_8lanes<zmm_vector<double>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
@@ -1169,28 +1158,6 @@ struct zmm_vector<double> {
     }
 };
 
-/*
- * Assumes zmm is random and performs a full sorting network defined in
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
- */
-template <typename vtype, typename reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm)
-{
-    const typename vtype::regi_t rev_index = vtype::seti(NETWORK_64BIT_2);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), 0xCC);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
-    zmm = cmp_merge<vtype>(zmm, vtype::permutexvar(rev_index, zmm), 0xF0);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), 0xCC);
-    zmm = cmp_merge<vtype>(
-            zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
-    return zmm;
-}
-
 struct avx512_64bit_swizzle_ops {
     template <typename vtype, int scale>
     X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg)
diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp
index 6d2cca6..7de26a0 100644
--- a/src/avx512fp16-16bit-qsort.hpp
+++ b/src/avx512fp16-16bit-qsort.hpp
@@ -27,10 +27,6 @@ struct zmm_vector<_Float16> {
 
     using swizzle_ops = avx512_16bit_swizzle_ops;
 
-    static __m512i get_network(int index)
-    {
-        return _mm512_loadu_si512(&network[index - 1][0]);
-    }
     static type_t type_max()
     {
         Fp16Bits val;
@@ -143,12 +139,12 @@ struct zmm_vector<_Float16> {
     }
     static reg_t reverse(reg_t zmm)
     {
-        const auto rev_index = get_network(4);
+        const auto rev_index = _mm512_set_epi16(NETWORK_REVERSE_32LANES);
         return permutexvar(rev_index, zmm);
     }
     static reg_t sort_vec(reg_t x)
     {
-        return sort_zmm_16bit<zmm_vector<_Float16>>(x);
+        return sort_reg_32lanes<zmm_vector<_Float16>>(x);
     }
     static reg_t cast_from(__m512i v)
     {
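
The same reversal trick applies at 64-bit granularity: seti(NETWORK_REVERSE_8LANES) builds the index vector {0, ..., 7} with 0 in the top lane. An AVX512F sketch:

#include <immintrin.h>

// Reverse eight 64-bit lanes of a zmm register; dst[i] = src[idx[i]].
static inline __m512i reverse_epi64(__m512i src)
{
    const __m512i idx = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
    return _mm512_permutexvar_epi64(idx, src);
}
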
diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h
index 4fa5041..cf02b30 100644
--- a/src/xss-common-argsort.h
+++ b/src/xss-common-argsort.h
@@ -468,11 +468,11 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
 }
 
 template <typename vectype, typename argtype, typename type_t>
-X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
-                                         arrsize_t *arg,
-                                         arrsize_t left,
-                                         arrsize_t right,
-                                         arrsize_t max_iters)
+X86_SIMD_SORT_INLINE void argsort_(type_t *arr,
+                                   arrsize_t *arg,
+                                   arrsize_t left,
+                                   arrsize_t right,
+                                   arrsize_t max_iters)
 {
     /*
      * Resort to std::sort if quicksort isnt making any progress
@@ -495,20 +495,19 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
     arrsize_t pivot_index = argpartition_unrolled<vectype, argtype, 4>(
             arr, arg, left, right + 1, pivot, &smallest, &biggest);
     if (pivot != smallest)
-        argsort_64bit_<vectype, argtype>(
+        argsort_<vectype, argtype>(
                 arr, arg, left, pivot_index - 1, max_iters - 1);
     if (pivot != biggest)
-        argsort_64bit_<vectype, argtype>(
-                arr, arg, pivot_index, right, max_iters - 1);
+        argsort_<vectype, argtype>(arr, arg, pivot_index, right, max_iters - 1);
 }
 
 template <typename vectype, typename argtype, typename type_t>
-X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
-                                           arrsize_t *arg,
-                                           arrsize_t pos,
-                                           arrsize_t left,
-                                           arrsize_t right,
-                                           arrsize_t max_iters)
+X86_SIMD_SORT_INLINE void argselect_(type_t *arr,
+                                     arrsize_t *arg,
+                                     arrsize_t pos,
+                                     arrsize_t left,
+                                     arrsize_t right,
+                                     arrsize_t max_iters)
 {
     /*
      * Resort to std::sort if quicksort isnt making any progress
@@ -531,30 +530,34 @@
     arrsize_t pivot_index = argpartition_unrolled<vectype, argtype, 4>(
             arr, arg, left, right + 1, pivot, &smallest, &biggest);
     if ((pivot != smallest) && (pos < pivot_index))
-        argselect_64bit_<vectype, argtype>(
+        argselect_<vectype, argtype>(
                 arr, arg, pos, left, pivot_index - 1, max_iters - 1);
     else if ((pivot != biggest) && (pos >= pivot_index))
-        argselect_64bit_<vectype, argtype>(
+        argselect_<vectype, argtype>(
                 arr, arg, pos, pivot_index, right, max_iters - 1);
 }
 
 /* argsort methods for 32-bit and 64-bit dtypes */
-template <typename T>
-X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
-                                         arrsize_t *arg,
-                                         arrsize_t arrsize,
-                                         bool hasnan = false,
-                                         bool descending = false)
+template <typename T,
+          template <typename...>
+          typename full_vector,
+          template <typename...>
+          typename half_vector>
+X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
+                                      arrsize_t *arg,
+                                      arrsize_t arrsize,
+                                      bool hasnan = false,
+                                      bool descending = false)
 {
-    /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
+    /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */
     using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
-                                              ymm_vector<T>,
-                                              zmm_vector<T>>::type;
+                                              half_vector<T>,
+                                              full_vector<T>>::type;
 
     using argtype =
             typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
-                                      ymm_vector<uint32_t>,
-                                      zmm_vector<int64_t>>::type;
+                                      half_vector<uint32_t>,
+                                      full_vector<int64_t>>::type;
 
     if (arrsize > 1) {
         if constexpr (xss::fp::is_floating_point_v<T>) {
@@ -567,14 +570,24 @@
             }
         }
         UNUSED(hasnan);
-        argsort_64bit_<vectype, argtype>(
+        argsort_<vectype, argtype>(
                 arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
 
         if (descending) { std::reverse(arg, arg + arrsize); }
     }
 }
 
-/* argsort methods for 32-bit and 64-bit dtypes */
+template <typename T>
+X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
+                                         arrsize_t *arg,
+                                         arrsize_t arrsize,
+                                         bool hasnan = false,
+                                         bool descending = false)
+{
+    xss_argsort<T, zmm_vector, ymm_vector>(
+            arr, arg, arrsize, hasnan, descending);
+}
+
 template <typename T>
 X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
                                        arrsize_t *arg,
@@ -582,33 +595,45 @@ X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
                                        bool hasnan = false,
                                        bool descending = false)
 {
+    xss_argsort<T, avx2_vector, avx2_half_vector>(
+            arr, arg, arrsize, hasnan, descending);
+}
+
+/* argselect methods for 32-bit and 64-bit dtypes */
+template <typename T,
+          template <typename...>
+          typename full_vector,
+          template <typename...>
+          typename half_vector>
+X86_SIMD_SORT_INLINE void xss_argselect(T *arr,
+                                        arrsize_t *arg,
+                                        arrsize_t k,
+                                        arrsize_t arrsize,
+                                        bool hasnan = false)
+{
+    /* TODO optimization: on 32-bit, use full_vector for 32-bit dtype */
     using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
-                                              avx2_half_vector<T>,
-                                              avx2_vector<T>>::type;
+                                              half_vector<T>,
+                                              full_vector<T>>::type;
 
     using argtype =
             typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
-                                      avx2_half_vector<uint32_t>,
-                                      avx2_vector<int64_t>>::type;
+                                      half_vector<uint32_t>,
+                                      full_vector<int64_t>>::type;
+
     if (arrsize > 1) {
         if constexpr (xss::fp::is_floating_point_v<T>) {
             if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
-                std_argsort_withnan(arr, arg, 0, arrsize);
-
-                if (descending) { std::reverse(arg, arg + arrsize); }
-
+                std_argselect_withnan(arr, arg, k, 0, arrsize);
                 return;
             }
         }
         UNUSED(hasnan);
-        argsort_64bit_<vectype, argtype>(
-                arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
-
-        if (descending) { std::reverse(arg, arg + arrsize); }
+        argselect_<vectype, argtype>(
+                arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
     }
 }
 
-/* argselect methods for 32-bit and 64-bit dtypes */
 template <typename T>
 X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
                                            arrsize_t *arg,
@@ -616,30 +641,9 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
                                            arrsize_t k,
                                            arrsize_t arrsize,
                                            bool hasnan = false)
 {
-    /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
-    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
-                                              ymm_vector<T>,
-                                              zmm_vector<T>>::type;
-
-    using argtype =
-            typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
-                                      ymm_vector<uint32_t>,
-                                      zmm_vector<int64_t>>::type;
-
-    if (arrsize > 1) {
-        if constexpr (xss::fp::is_floating_point_v<T>) {
-            if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
-                std_argselect_withnan(arr, arg, k, 0, arrsize);
-                return;
-            }
-        }
-        UNUSED(hasnan);
-        argselect_64bit_<vectype, argtype>(
-                arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
-    }
+    xss_argselect<T, zmm_vector, ymm_vector>(arr, arg, k, arrsize, hasnan);
 }
 
-/* argselect methods for 32-bit and 64-bit dtypes */
 template <typename T>
 X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
                                          arrsize_t *arg,
@@ -647,25 +651,8 @@ X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
                                          arrsize_t arrsize,
                                          bool hasnan = false)
 {
-    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
-                                              avx2_half_vector<T>,
-                                              avx2_vector<T>>::type;
-
-    using argtype =
-            typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
-                                      avx2_half_vector<uint32_t>,
-                                      avx2_vector<int64_t>>::type;
-
-    if (arrsize > 1) {
-        if constexpr (xss::fp::is_floating_point_v<T>) {
-            if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
-                std_argselect_withnan(arr, arg, k, 0, arrsize);
-                return;
-            }
-        }
-        UNUSED(hasnan);
-        argselect_64bit_<vectype, argtype>(
-                arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
-    }
+    xss_argselect<T, avx2_vector, avx2_half_vector>(
+            arr, arg, k, arrsize, hasnan);
 }
+
 #endif // XSS_COMMON_ARGSORT
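
The template-template parameters let a single xss_argsort body serve both ISAs; the wrappers bind zmm_vector/ymm_vector or avx2_vector/avx2_half_vector. A toy sketch of the selection logic (toy_full/toy_half are hypothetical stand-ins, not library types):

#include <cstdint>
#include <type_traits>

template <typename T> struct toy_full {};
template <typename T> struct toy_half {};

// Mirrors the std::conditional dispatch in xss_argsort: 32-bit dtypes use the
// half-width vector family, everything else the full-width family.
template <typename T,
          template <typename...> typename full_vector,
          template <typename...> typename half_vector>
void pick_vectype()
{
    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
                                              half_vector<T>,
                                              full_vector<T>>::type;
    static_assert(sizeof(T) != sizeof(int32_t)
                          || std::is_same<vectype, half_vector<T>>::value,
                  "32-bit dtypes map to the half-width vector family");
}

// usage: pick_vectype<float, toy_full, toy_half>();
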
diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h
index 83a5471..386ca86 100644
--- a/src/xss-common-includes.h
+++ b/src/xss-common-includes.h
@@ -74,25 +74,17 @@
 #define X86_SIMD_SORT_UNROLL_LOOP(num)
 #endif
 
+#define NETWORK_REVERSE_4LANES 0, 1, 2, 3
+#define NETWORK_REVERSE_8LANES 0, 1, 2, 3, 4, 5, 6, 7
+#define NETWORK_REVERSE_16LANES \
+    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+#define NETWORK_REVERSE_32LANES \
+    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, \
+            21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
+
 template <typename type>
 constexpr bool always_false = false;
 
-/*
- * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic
- * sorting network (see
- * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
- */
-// ZMM 7, 6, 5, 4, 3, 2, 1, 0
-#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3
-#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7
-#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2
-#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4
-#define NETWORK_32BIT_1 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1
-#define NETWORK_32BIT_3 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7
-#define NETWORK_32BIT_5 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
-#define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4
-#define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8
-
 typedef size_t arrsize_t;
 
 template <typename type>
diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h
index 6401194..2bd66d2 100644
--- a/src/xss-common-qsort.h
+++ b/src/xss-common-qsort.h
@@ -37,6 +37,7 @@
 #include "xss-pivot-selection.hpp"
 #include "xss-network-qsort.hpp"
 #include "xss-common-comparators.hpp"
+#include "xss-reg-networks.hpp"
 
 template <typename T>
 bool is_a_nan(T elem)
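
Since the NETWORK_REVERSE_* macros are just comma-separated lane lists, they can feed any set-intrinsic of matching lane count. A self-contained AVX2 sketch (the macro is redefined here only so the snippet compiles on its own):

#include <immintrin.h>

#define NETWORK_REVERSE_8LANES 0, 1, 2, 3, 4, 5, 6, 7

// Reverse eight 32-bit lanes of a ymm register.
static inline __m256i reverse8_epi32(__m256i src)
{
    const __m256i idx = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
    return _mm256_permutevar8x32_epi32(src, idx);
}
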
diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp
index 3cb3037..771d7a3 100644
--- a/src/xss-network-keyvaluesort.hpp
+++ b/src/xss-network-keyvaluesort.hpp
@@ -78,294 +78,21 @@ X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1,
     return tmp_keys; // 0 -> min, 1 -> max
 }
 
-template <typename vtype1,
-          typename vtype2,
-          typename reg_t = typename vtype1::reg_t,
-          typename index_type = typename vtype2::reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_reg_16lanes(reg_t key_zmm,
-                                            index_type &index_zmm)
-{
-    using key_swizzle = typename vtype1::swizzle_ops;
-    using index_swizzle = typename vtype2::swizzle_ops;
-
-    const auto oxAAAA = convert_int_to_mask<vtype1>(0xAAAA);
-    const auto oxCCCC = convert_int_to_mask<vtype1>(0xCCCC);
-    const auto oxFF00 = convert_int_to_mask<vtype1>(0xFF00);
-    const auto oxF0F0 = convert_int_to_mask<vtype1>(0xF0F0);
-
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
-            oxAAAA);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template reverse_n<vtype1, 4>(key_zmm),
-            index_zmm,
-            index_swizzle::template reverse_n<vtype2, 4>(index_zmm),
-            oxCCCC);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
-            oxAAAA);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_3), key_zmm),
-            index_zmm,
-            vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_3), index_zmm),
-            oxF0F0);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 4>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 4>(index_zmm),
-            oxCCCC);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
-            oxAAAA);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_5), key_zmm),
-            index_zmm,
-            vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_5), index_zmm),
-            oxFF00);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_6), key_zmm),
-            index_zmm,
-            vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_6), index_zmm),
-            oxF0F0);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 4>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 4>(index_zmm),
-            oxCCCC);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
-            oxAAAA);
-    return key_zmm;
-}
-
-// Assumes zmm is bitonic and performs a recursive half cleaner
-template <typename vtype1,
-          typename vtype2,
-          typename reg_t = typename vtype1::reg_t,
-          typename index_type = typename vtype2::reg_t>
-X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_16lanes(reg_t key_zmm,
-                                                     index_type &index_zmm)
-{
-    using key_swizzle = typename vtype1::swizzle_ops;
-    using index_swizzle = typename vtype2::swizzle_ops;
-
-    const auto oxAAAA = convert_int_to_mask<vtype1>(0xAAAA);
-    const auto oxCCCC = convert_int_to_mask<vtype1>(0xCCCC);
-    const auto oxFF00 = convert_int_to_mask<vtype1>(0xFF00);
-    const auto oxF0F0 = convert_int_to_mask<vtype1>(0xF0F0);
-
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_7), key_zmm),
-            index_zmm,
-            vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_7), index_zmm),
-            oxFF00);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::permutexvar(vtype1::seti(NETWORK_32BIT_6), key_zmm),
-            index_zmm,
-            vtype2::permutexvar(vtype2::seti(NETWORK_32BIT_6), index_zmm),
-            oxF0F0);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 4>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 4>(index_zmm),
-            oxCCCC);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template reverse_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template reverse_n<vtype2, 2>(index_zmm),
-            oxAAAA);
-    return key_zmm;
-}
-
-template <typename vtype1,
-          typename vtype2,
-          typename reg_t = typename vtype1::reg_t,
-          typename index_type = typename vtype2::reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t key_zmm, index_type &index_zmm)
-{
-    using key_swizzle = typename vtype1::swizzle_ops;
-    using index_swizzle = typename vtype2::swizzle_ops;
-
-    const auto oxAA = convert_int_to_mask<vtype1>(0xAA);
-    const auto oxCC = convert_int_to_mask<vtype1>(0xCC);
-    const auto oxF0 = convert_int_to_mask<vtype1>(0xF0);
-
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
-            oxAA);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_1), key_zmm),
-            index_zmm,
-            vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_1), index_zmm),
-            oxCC);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
-            oxAA);
-    key_zmm = cmp_merge<vtype1, vtype2>(key_zmm,
-                                        vtype1::reverse(key_zmm),
-                                        index_zmm,
-                                        vtype2::reverse(index_zmm),
-                                        oxF0);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_3), key_zmm),
-            index_zmm,
-            vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_3), index_zmm),
-            oxCC);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
-            oxAA);
-    return key_zmm;
-}
-
-template <typename vtype1,
-          typename vtype2,
-          typename reg_t = typename vtype1::reg_t,
-          typename index_type = typename vtype2::reg_t>
-X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t key_zmm, index_type &index_zmm)
-{
-    using key_swizzle = typename vtype1::swizzle_ops;
-    using index_swizzle = typename vtype2::swizzle_ops;
-
-    const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0);
-    const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0);
-
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
-            oxAA);
-    key_zmm = cmp_merge<vtype1, vtype2>(key_zmm,
-                                        vtype1::reverse(key_zmm),
-                                        index_zmm,
-                                        vtype2::reverse(index_zmm),
-                                        oxCC);
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
-            oxAA);
-    return key_zmm;
-}
-
-// Assumes zmm is bitonic and performs a recursive half cleaner
-template <typename vtype1,
-          typename vtype2,
-          typename reg_t = typename vtype1::reg_t,
-          typename index_type = typename vtype2::reg_t>
-X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_8lanes(reg_t key_zmm,
-                                                    index_type &index_zmm)
-{
-    using key_swizzle = typename vtype1::swizzle_ops;
-    using index_swizzle = typename vtype2::swizzle_ops;
-
-    const auto oxAA = convert_int_to_mask<vtype1>(0xAA);
-    const auto oxCC = convert_int_to_mask<vtype1>(0xCC);
-    const auto oxF0 = convert_int_to_mask<vtype1>(0xF0);
-
-    // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_4), key_zmm),
-            index_zmm,
-            vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_4), index_zmm),
-            oxF0);
-    // 2) half_cleaner[4]
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            vtype1::permutexvar(vtype1::seti(NETWORK_64BIT_3), key_zmm),
-            index_zmm,
-            vtype2::permutexvar(vtype2::seti(NETWORK_64BIT_3), index_zmm),
-            oxCC);
-    // 3) half_cleaner[1]
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
-            oxAA);
-    return key_zmm;
-}
-
-template <typename vtype1,
-          typename vtype2,
-          typename reg_t = typename vtype1::reg_t,
-          typename index_type = typename vtype2::reg_t>
-X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_64bit(reg_t key_zmm,
-                                                   index_type &index_zmm)
-{
-    using key_swizzle = typename vtype1::swizzle_ops;
-    using index_swizzle = typename vtype2::swizzle_ops;
-
-    const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0);
-    const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0);
-
-    // 2) half_cleaner[4]
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 4>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 4>(index_zmm),
-            oxCC);
-    // 3) half_cleaner[1]
-    key_zmm = cmp_merge<vtype1, vtype2>(
-            key_zmm,
-            key_swizzle::template swap_n<vtype1, 2>(key_zmm),
-            index_zmm,
-            index_swizzle::template swap_n<vtype2, 2>(index_zmm),
-            oxAA);
-    return key_zmm;
-}
-
 template <typename keyType, typename valueType>
 X86_SIMD_SORT_INLINE void bitonic_merge_dispatch(typename keyType::reg_t &key,
                                                  typename valueType::reg_t &value)
 {
     constexpr int numlanes = keyType::numlanes;
-    if constexpr (numlanes == 8) {
+    if constexpr (numlanes == 4) {
+        key = bitonic_merge_reg_4lanes<keyType, valueType>(key, value);
+    }
+    else if constexpr (numlanes == 8) {
         key = bitonic_merge_reg_8lanes<keyType, valueType>(key, value);
     }
     else if constexpr (numlanes == 16) {
         key = bitonic_merge_reg_16lanes<keyType, valueType>(key, value);
     }
-    else if constexpr (numlanes == 4) {
-        key = bitonic_merge_ymm_64bit<keyType, valueType>(key, value);
-    }
     else {
         static_assert(always_false<keyType>,
                       "bitonic_merge_dispatch: No implementation");
@@ -379,15 +106,15 @@ X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key,
                                             typename valueType::reg_t &value)
 {
     constexpr int numlanes = keyType::numlanes;
-    if constexpr (numlanes == 8) {
+    if constexpr (numlanes == 4) {
+        key = sort_reg_4lanes<keyType, valueType>(key, value);
+    }
+    else if constexpr (numlanes == 8) {
         key = sort_reg_8lanes<keyType, valueType>(key, value);
     }
     else if constexpr (numlanes == 16) {
         key = sort_reg_16lanes<keyType, valueType>(key, value);
     }
-    else if constexpr (numlanes == 4) {
-        key = sort_ymm_64bit<keyType, valueType>(key, value);
-    }
     else {
         static_assert(always_false<keyType>,
                       "sort_vec_dispatch: No implementation");
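
To see why the three steps of sort_reg_4lanes (masks 0xA, 0xC, 0xA) sort four lanes, here is a scalar trace of the comparator schedule (illustrative only):

#include <algorithm>

// Scalar trace of sort_reg_4lanes: each step is "compare with partner; the
// mask bit says which lane of the pair keeps the max".
void sort4_scalar(int v[4])
{
    auto cswap = [&](int a, int b) { if (v[a] > v[b]) std::swap(v[a], v[b]); };
    cswap(0, 1); cswap(2, 3); // reverse_n<2> + 0xA: lanes 1, 3 keep the max
    cswap(0, 3); cswap(1, 2); // reverse_n<4> + 0xC: lanes 2, 3 keep the max
    cswap(0, 1); cswap(2, 3); // swap_n<2>    + 0xA: final half cleaner
}
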
diff --git a/src/xss-reg-networks.hpp b/src/xss-reg-networks.hpp
new file mode 100644
index 0000000..727cccf
--- /dev/null
+++ b/src/xss-reg-networks.hpp
@@ -0,0 +1,443 @@
+#ifndef XSS_REG_NETWORKS
+#define XSS_REG_NETWORKS
+
+#include "xss-common-includes.h"
+
+template <typename vtype, typename maskType>
+typename vtype::opmask_t convert_int_to_mask(maskType mask);
+
+template <typename vtype,
+          typename reg_t = typename vtype::reg_t,
+          typename opmask_t = typename vtype::opmask_t>
+X86_SIMD_SORT_INLINE reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask);
+
+template <typename vtype1,
+          typename vtype2,
+          typename reg_t1 = typename vtype1::reg_t,
+          typename reg_t2 = typename vtype2::reg_t,
+          typename opmask_t = typename vtype1::opmask_t>
+X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1,
+                                      reg_t1 in2,
+                                      reg_t2 &indexes1,
+                                      reg_t2 indexes2,
+                                      opmask_t mask);
+
+// Single vector functions
+
+/*
+ * Assumes reg is random and performs a full sorting network defined in
+ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
+ */
+template <typename vtype, typename reg_t = typename vtype::reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_reg_4lanes(reg_t reg)
+{
+    using swizzle = typename vtype::swizzle_ops;
+
+    const typename vtype::opmask_t oxA = convert_int_to_mask<vtype>(0xA);
+    const typename vtype::opmask_t oxC = convert_int_to_mask<vtype>(0xC);
+
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 2>(reg), oxA);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 4>(reg), oxC);
+    reg = cmp_merge<vtype>(reg, swizzle::template swap_n<vtype, 2>(reg), oxA);
+    return reg;
+}
+
+/*
+ * Assumes reg is random and performs a full sorting network defined in
+ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
+ */
+template <typename vtype, typename reg_t = typename vtype::reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t reg)
+{
+    using swizzle = typename vtype::swizzle_ops;
+
+    const typename vtype::opmask_t oxAA = convert_int_to_mask<vtype>(0xAA);
+    const typename vtype::opmask_t oxCC = convert_int_to_mask<vtype>(0xCC);
+    const typename vtype::opmask_t oxF0 = convert_int_to_mask<vtype>(0xF0);
+
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 2>(reg), oxAA);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 4>(reg), oxCC);
+    reg = cmp_merge<vtype>(reg, swizzle::template swap_n<vtype, 2>(reg), oxAA);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 8>(reg), oxF0);
+    reg = cmp_merge<vtype>(reg, swizzle::template swap_n<vtype, 4>(reg), oxCC);
+    reg = cmp_merge<vtype>(reg, swizzle::template swap_n<vtype, 2>(reg), oxAA);
+    return reg;
+}
+
+/*
+ * Assumes reg is random and performs a full sorting network defined in
+ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
+ */
+template <typename vtype, typename reg_t = typename vtype::reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_reg_16lanes(reg_t reg)
+{
+    using swizzle = typename vtype::swizzle_ops;
+
+    const typename vtype::opmask_t oxAAAA = convert_int_to_mask<vtype>(0xAAAA);
+    const typename vtype::opmask_t oxCCCC = convert_int_to_mask<vtype>(0xCCCC);
+    const typename vtype::opmask_t oxF0F0 = convert_int_to_mask<vtype>(0xF0F0);
+    const typename vtype::opmask_t oxFF00 = convert_int_to_mask<vtype>(0xFF00);
+
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 2>(reg), oxAAAA);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 4>(reg), oxCCCC);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 2>(reg), oxAAAA);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 8>(reg), oxF0F0);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 4>(reg), oxCCCC);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 2>(reg), oxAAAA);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 16>(reg), oxFF00);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 8>(reg), oxF0F0);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 4>(reg), oxCCCC);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 2>(reg), oxAAAA);
+    return reg;
+}
+
+/*
+ * Assumes reg is random and performs a full sorting network defined in
+ * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
+ */
+template <typename vtype, typename reg_t = typename vtype::reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_reg_32lanes(reg_t reg)
+{
+    using swizzle = typename vtype::swizzle_ops;
+
+    const typename vtype::opmask_t oxAAAAAAAA
+            = convert_int_to_mask<vtype>(0xAAAAAAAA);
+    const typename vtype::opmask_t oxCCCCCCCC
+            = convert_int_to_mask<vtype>(0xCCCCCCCC);
+    const typename vtype::opmask_t oxF0F0F0F0
+            = convert_int_to_mask<vtype>(0xF0F0F0F0);
+    const typename vtype::opmask_t oxFF00FF00
+            = convert_int_to_mask<vtype>(0xFF00FF00);
+    const typename vtype::opmask_t oxFFFF0000
+            = convert_int_to_mask<vtype>(0xFFFF0000);
+
+    // Level 1
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 2>(reg), oxAAAAAAAA);
+    // Level 2
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 4>(reg), oxCCCCCCCC);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 2>(reg), oxAAAAAAAA);
+    // Level 3
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 8>(reg), oxF0F0F0F0);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 4>(reg), oxCCCCCCCC);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 2>(reg), oxAAAAAAAA);
+    // Level 4
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 16>(reg), oxFF00FF00);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 8>(reg), oxF0F0F0F0);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 4>(reg), oxCCCCCCCC);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 2>(reg), oxAAAAAAAA);
+    // Level 5
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template reverse_n<vtype, 32>(reg), oxFFFF0000);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 16>(reg), oxFF00FF00);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 8>(reg), oxF0F0F0F0);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 4>(reg), oxCCCCCCCC);
+    reg = cmp_merge<vtype>(
+            reg, swizzle::template swap_n<vtype, 2>(reg), oxAAAAAAAA);
+    return reg;
+}
+
+// Key-index functions for kv-sort
+
+template <typename vtype1,
+          typename vtype2,
+          typename reg_t = typename vtype1::reg_t,
+          typename index_type = typename vtype2::reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_reg_4lanes(reg_t key_reg, index_type &index_reg)
+{
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const typename vtype1::opmask_t oxA = convert_int_to_mask<vtype1>(0xA);
+    const typename vtype1::opmask_t oxC = convert_int_to_mask<vtype1>(0xC);
+
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template reverse_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template reverse_n<vtype2, 2>(index_reg),
+            oxA);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template reverse_n<vtype1, 4>(key_reg),
+            index_reg,
+            index_swizzle::template reverse_n<vtype2, 4>(index_reg),
+            oxC);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 2>(index_reg),
+            oxA);
+    return key_reg;
+}
+
+template <typename vtype1,
+          typename vtype2,
+          typename reg_t = typename vtype1::reg_t,
+          typename index_type = typename vtype2::reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t key_reg, index_type &index_reg)
+{
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const auto oxAA = convert_int_to_mask<vtype1>(0xAA);
+    const auto oxCC = convert_int_to_mask<vtype1>(0xCC);
+    const auto oxF0 = convert_int_to_mask<vtype1>(0xF0);
+
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template reverse_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template reverse_n<vtype2, 2>(index_reg),
+            oxAA);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template reverse_n<vtype1, 4>(key_reg),
+            index_reg,
+            index_swizzle::template reverse_n<vtype2, 4>(index_reg),
+            oxCC);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 2>(index_reg),
+            oxAA);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template reverse_n<vtype1, 8>(key_reg),
+            index_reg,
+            index_swizzle::template reverse_n<vtype2, 8>(index_reg),
+            oxF0);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 4>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 4>(index_reg),
+            oxCC);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 2>(index_reg),
+            oxAA);
+    return key_reg;
+}
+
+template <typename vtype1,
+          typename vtype2,
+          typename reg_t = typename vtype1::reg_t,
+          typename index_type = typename vtype2::reg_t>
+X86_SIMD_SORT_INLINE reg_t sort_reg_16lanes(reg_t key_reg,
+                                            index_type &index_reg)
+{
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const auto oxAAAA = convert_int_to_mask<vtype1>(0xAAAA);
+    const auto oxCCCC = convert_int_to_mask<vtype1>(0xCCCC);
+    const auto oxFF00 = convert_int_to_mask<vtype1>(0xFF00);
+    const auto oxF0F0 = convert_int_to_mask<vtype1>(0xF0F0);
+
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template reverse_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template reverse_n<vtype2, 2>(index_reg),
+            oxAAAA);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template reverse_n<vtype1, 4>(key_reg),
+            index_reg,
+            index_swizzle::template reverse_n<vtype2, 4>(index_reg),
+            oxCCCC);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 2>(index_reg),
+            oxAAAA);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template reverse_n<vtype1, 8>(key_reg),
+            index_reg,
+            index_swizzle::template reverse_n<vtype2, 8>(index_reg),
+            oxF0F0);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 4>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 4>(index_reg),
+            oxCCCC);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 2>(index_reg),
+            oxAAAA);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template reverse_n<vtype1, 16>(key_reg),
+            index_reg,
+            index_swizzle::template reverse_n<vtype2, 16>(index_reg),
+            oxFF00);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 8>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 8>(index_reg),
+            oxF0F0);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 4>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 4>(index_reg),
+            oxCCCC);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 2>(index_reg),
+            oxAAAA);
+    return key_reg;
+}
+
+// Assumes reg is bitonic and performs a recursive half cleaner
+template <typename vtype1,
+          typename vtype2,
+          typename reg_t = typename vtype1::reg_t,
+          typename index_type = typename vtype2::reg_t>
+X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_4lanes(reg_t key_reg,
+                                                    index_type &index_reg)
+{
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const typename vtype1::opmask_t oxA = convert_int_to_mask<vtype1>(0xA);
+    const typename vtype1::opmask_t oxC = convert_int_to_mask<vtype1>(0xC);
+
+    // 2) half_cleaner[4]
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 4>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 4>(index_reg),
+            oxC);
+    // 3) half_cleaner[1]
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 2>(index_reg),
+            oxA);
+    return key_reg;
+}
+
+// Assumes reg is bitonic and performs a recursive half cleaner
+template <typename vtype1,
+          typename vtype2,
+          typename reg_t = typename vtype1::reg_t,
+          typename index_type = typename vtype2::reg_t>
+X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_8lanes(reg_t key_reg,
+                                                    index_type &index_reg)
+{
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const auto oxAA = convert_int_to_mask<vtype1>(0xAA);
+    const auto oxCC = convert_int_to_mask<vtype1>(0xCC);
+    const auto oxF0 = convert_int_to_mask<vtype1>(0xF0);
+
+    // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 8>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 8>(index_reg),
+            oxF0);
+    // 2) half_cleaner[4]
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 4>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 4>(index_reg),
+            oxCC);
+    // 3) half_cleaner[1]
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 2>(index_reg),
+            oxAA);
+    return key_reg;
+}
+
+// Assumes reg is bitonic and performs a recursive half cleaner
+template <typename vtype1,
+          typename vtype2,
+          typename reg_t = typename vtype1::reg_t,
+          typename index_type = typename vtype2::reg_t>
+X86_SIMD_SORT_INLINE reg_t bitonic_merge_reg_16lanes(reg_t key_reg,
+                                                     index_type &index_reg)
+{
+    using key_swizzle = typename vtype1::swizzle_ops;
+    using index_swizzle = typename vtype2::swizzle_ops;
+
+    const auto oxAAAA = convert_int_to_mask<vtype1>(0xAAAA);
+    const auto oxCCCC = convert_int_to_mask<vtype1>(0xCCCC);
+    const auto oxFF00 = convert_int_to_mask<vtype1>(0xFF00);
+    const auto oxF0F0 = convert_int_to_mask<vtype1>(0xF0F0);
+
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 16>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 16>(index_reg),
+            oxFF00);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 8>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 8>(index_reg),
+            oxF0F0);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 4>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 4>(index_reg),
+            oxCCCC);
+    key_reg = cmp_merge<vtype1, vtype2>(
+            key_reg,
+            key_swizzle::template swap_n<vtype1, 2>(key_reg),
+            index_reg,
+            index_swizzle::template swap_n<vtype2, 2>(index_reg),
+            oxAAAA);
+    return key_reg;
+}
+
+#endif // XSS_REG_NETWORKS
\ No newline at end of file
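
A quick way to convince yourself the refactored schedules are still valid sorting networks is the 0-1 principle. This standalone scalar check (illustrative, not part of the patch) exercises the comparator schedule of sort_reg_8lanes over all 256 binary inputs:

#include <algorithm>
#include <array>
#include <cassert>

int main()
{
    auto cswap = [](std::array<int, 8> &v, int a, int b) {
        if (v[a] > v[b]) std::swap(v[a], v[b]);
    };
    for (unsigned bits = 0; bits < 256; ++bits) {
        std::array<int, 8> v;
        for (int i = 0; i < 8; ++i)
            v[i] = (bits >> i) & 1;
        // comparator pairs implied by each swizzle + mask step
        for (int g = 0; g < 8; g += 2) cswap(v, g, g + 1);   // reverse_n<2>/0xAA
        for (int g = 0; g < 8; g += 4) {                     // reverse_n<4>/0xCC
            cswap(v, g, g + 3); cswap(v, g + 1, g + 2);
        }
        for (int g = 0; g < 8; g += 2) cswap(v, g, g + 1);   // swap_n<2>/0xAA
        for (int i = 0; i < 4; ++i) cswap(v, i, 7 - i);      // reverse_n<8>/0xF0
        for (int g = 0; g < 8; g += 4) {                     // swap_n<4>/0xCC
            cswap(v, g, g + 2); cswap(v, g + 1, g + 3);
        }
        for (int g = 0; g < 8; g += 2) cswap(v, g, g + 1);   // swap_n<2>/0xAA
        assert(std::is_sorted(v.begin(), v.end()));
    }
    return 0;
}
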