From 9e4717c600deba3f164f5fa8c23234f368ec30fd Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 9 Feb 2024 16:12:29 +0800 Subject: [PATCH 01/71] [SYCL] Impl of fallback group sort in device code Signed-off-by: jinge90 --- libdevice/atomic.hpp | 73 +--------------- libdevice/cmake/modules/SYCLLibdevice.cmake | 6 +- libdevice/fallback-gsort.cpp | 38 +++++++++ libdevice/group_helper.hpp | 33 ++++++++ libdevice/sort_helper.hpp | 93 +++++++++++++++++++++ libdevice/spirv_decls.hpp | 86 +++++++++++++++++++ 6 files changed, 256 insertions(+), 73 deletions(-) create mode 100644 libdevice/fallback-gsort.cpp create mode 100644 libdevice/group_helper.hpp create mode 100644 libdevice/sort_helper.hpp create mode 100644 libdevice/spirv_decls.hpp diff --git a/libdevice/atomic.hpp b/libdevice/atomic.hpp index 3b6d1cf71f441..1a77fa69c5ad6 100644 --- a/libdevice/atomic.hpp +++ b/libdevice/atomic.hpp @@ -7,78 +7,9 @@ //===----------------------------------------------------------------------===// #pragma once -#include - -#include "device.h" - -#ifdef __SPIR__ - -#define SPIR_GLOBAL __attribute__((opencl_global)) - -namespace __spv { -struct Scope { - - enum Flag : uint32_t { - CrossDevice = 0, - Device = 1, - Workgroup = 2, - Subgroup = 3, - Invocation = 4, - }; - - constexpr Scope(Flag flag) : flag_value(flag) {} - - constexpr operator uint32_t() const { return flag_value; } - - Flag flag_value; -}; - -struct MemorySemanticsMask { - - enum Flag : uint32_t { - None = 0x0, - Acquire = 0x2, - Release = 0x4, - AcquireRelease = 0x8, - SequentiallyConsistent = 0x10, - UniformMemory = 0x40, - SubgroupMemory = 0x80, - WorkgroupMemory = 0x100, - CrossWorkgroupMemory = 0x200, - AtomicCounterMemory = 0x400, - ImageMemory = 0x800, - }; - - constexpr MemorySemanticsMask(Flag flag) : flag_value(flag) {} - - constexpr operator uint32_t() const { return flag_value; } - - Flag flag_value; -}; -} // namespace __spv - -extern DEVICE_EXTERNAL int -__spirv_AtomicCompareExchange(int SPIR_GLOBAL *, 
__spv::Scope::Flag, - __spv::MemorySemanticsMask::Flag, - __spv::MemorySemanticsMask::Flag, int, int); - -extern DEVICE_EXTERNAL int -__spirv_AtomicCompareExchange(int *, __spv::Scope::Flag, - __spv::MemorySemanticsMask::Flag, - __spv::MemorySemanticsMask::Flag, int, int); - -extern DEVICE_EXTERNAL int __spirv_AtomicLoad(const int SPIR_GLOBAL *, - __spv::Scope::Flag, - __spv::MemorySemanticsMask::Flag); - -extern DEVICE_EXTERNAL void -__spirv_AtomicStore(int SPIR_GLOBAL *, __spv::Scope::Flag, - __spv::MemorySemanticsMask::Flag, int); - -extern DEVICE_EXTERNAL void -__spirv_AtomicStore(int *, __spv::Scope::Flag, __spv::MemorySemanticsMask::Flag, - int); +#include "spirv_decls.hpp" +#if defined(__SPIR__) /// Atomically set the value in *Ptr with Desired if and only if it is Expected /// Return the value which already was in *Ptr static inline int atomicCompareAndSet(SPIR_GLOBAL int *Ptr, int Desired, diff --git a/libdevice/cmake/modules/SYCLLibdevice.cmake b/libdevice/cmake/modules/SYCLLibdevice.cmake index 1d2e1b4de64f5..6a73b415d4535 100644 --- a/libdevice/cmake/modules/SYCLLibdevice.cmake +++ b/libdevice/cmake/modules/SYCLLibdevice.cmake @@ -102,15 +102,16 @@ function(add_fallback_devicelib fallback_filename) add_devicelib_obj(${fallback_filename} SRC ${FB_SRC} DEP ${FB_DEP} EXTRA_ARGS ${FB_EXTRA_ARGS}) endfunction() -set(crt_obj_deps wrapper.h device.h spirv_vars.h sycl-compiler) +set(crt_obj_deps wrapper.h device.h atomic.hpp spirv_decls.hpp spirv_vars.h sycl-compiler) set(complex_obj_deps device_complex.h device.h sycl-compiler) set(cmath_obj_deps device_math.h device.h sycl-compiler) set(imf_obj_deps device_imf.hpp imf_half.hpp imf_bf16.hpp imf_rounding_op.hpp imf_impl_utils.hpp device.h sycl-compiler) set(itt_obj_deps device_itt.h spirv_vars.h device.h sycl-compiler) set(bfloat16_obj_deps sycl-headers sycl-compiler) if (NOT MSVC) - set(sanitizer_obj_deps device.h atomic.hpp spirv_vars.h include/sanitizer_device_utils.hpp include/spir_global_var.hpp 
sycl-compiler) + set(sanitizer_obj_deps device.h atomic.hpp spirv_decls.hpp spirv_vars.h include/sanitizer_device_utils.hpp include/spir_global_var.hpp sycl-compiler) endif() +set(gsort_obj_deps device.h spirv_decls.hpp spirv_vars.h group_helper.hpp sort_helper.hpp sycl-compiler) add_devicelib_obj(libsycl-itt-stubs SRC itt_stubs.cpp DEP ${itt_obj_deps}) add_devicelib_obj(libsycl-itt-compiler-wrappers SRC itt_compiler_wrappers.cpp DEP ${itt_obj_deps}) @@ -139,6 +140,7 @@ add_fallback_devicelib(libsycl-fallback-cmath SRC fallback-cmath.cpp DEP ${cmath add_fallback_devicelib(libsycl-fallback-cmath-fp64 SRC fallback-cmath-fp64.cpp DEP ${cmath_obj_deps}) add_fallback_devicelib(libsycl-fallback-bfloat16 SRC fallback-bfloat16.cpp DEP ${bfloat16_obj_deps}) add_fallback_devicelib(libsycl-native-bfloat16 SRC bfloat16_wrapper.cpp DEP ${bfloat16_obj_deps}) +add_fallback_devicelib(libsycl-fallback-gsort SRC fallback-gsort.cpp DEP ${gsort_obj_deps}) file(MAKE_DIRECTORY ${obj_binary_dir}/libdevice) set(imf_fallback_src_dir ${obj_binary_dir}/libdevice) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp new file mode 100644 index 0000000000000..0b0099ff82339 --- /dev/null +++ b/libdevice/fallback-gsort.cpp @@ -0,0 +1,38 @@ + +//==--- fallback_gsort_fp32.cpp - fallback implementation of group sort +//-----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "device.h" +#include "sort_helper.hpp" +#include +#if defined(__SPIR__) + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch); +} +#endif diff --git a/libdevice/group_helper.hpp b/libdevice/group_helper.hpp new file mode 100644 index 0000000000000..59d3319508472 --- /dev/null +++ b/libdevice/group_helper.hpp @@ -0,0 +1,33 @@ +//==------- group_helper.hpp - utils related to work-group operations-------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//==------------------------------------------------------------------------==// +#pragma once +#include "spirv_vars.h" +#include "spirv_decls.hpp" +#include +#if defined(__SPIR__) + +static inline size_t __get_wg_local_range() { + return __spirv_BuiltInWorkgroupSize.x * __spirv_BuiltInWorkgroupSize.y * + __spirv_BuiltInWorkgroupSize.z; +} + +static inline size_t __get_wg_local_linear_id() { + return (__spirv_BuiltInLocalInvocationId.x * __spirv_BuiltInWorkgroupSize.y * + __spirv_BuiltInWorkgroupSize.z) + + (__spirv_BuiltInLocalInvocationId.y * __spirv_BuiltInWorkgroupSize.z) + + __spirv_BuiltInLocalInvocationId.z; +} + +static inline void group_barrier() { + __spirv_ControlBarrier(__spv::Scope::Workgroup, __spv::Scope::Workgroup, + __spv::MemorySemanticsMask::SequentiallyConsistent | + __spv::MemorySemanticsMask::SubgroupMemory | + __spv::MemorySemanticsMask::WorkgroupMemory | + __spv::MemorySemanticsMask::CrossWorkgroupMemory); +} +#endif diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp new file mode 100644 index 0000000000000..21b675e2be32c --- /dev/null +++ b/libdevice/sort_helper.hpp @@ -0,0 +1,93 @@ +//==------- sort_helper.hpp - helper functions to do group sorting----------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//==------------------------------------------------------------------------==// + +#ifndef __LIBDEVICE_SORT_H__ +#define __LIBDEVICE_SORT_H__ +#include "group_helper.hpp" +#include + +#if defined(__SPIR__) +void bubble_sort(int32_t *first, const size_t beg, const size_t end) { + if (beg < end) { + for (size_t i = beg; i < end; ++i) + for (size_t j = i + 1; j < end; ++j) { + if (first[i] > first[j]) { + first[i] = first[i] ^ first[j]; + first[j] = first[i] ^ first[j]; + first[i] = first[i] ^ first[j]; + } + } + } +} + +void merge(int32_t *din, int32_t *dout, size_t widx, size_t msize, + size_t chunks, size_t n) { + if (2 * widx >= chunks) + return; + size_t beg1 = 2 * widx * msize; + size_t end1 = beg1 + msize; + size_t beg2, end2; + if (end1 >= n) { + end1 = beg2 = end2 = n; + } else { + beg2 = end1; + end2 = beg2 + msize; + if (end2 >= n) + end2 = n; + } + size_t output_idx = 2 * widx * msize; + while ((beg1 != end1) && (beg2 != end2)) { + if (din[beg1] < din[beg2]) + dout[output_idx++] = din[beg1++]; + else + dout[output_idx++] = din[beg2++]; + } + + while (beg1 != end1) + dout[output_idx++] = din[beg1++]; + while (beg2 != end2) + dout[output_idx++] = din[beg2++]; +} + +void merge_sort(int32_t *first, uint32_t n, uint8_t *scratch) { + const size_t idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + const size_t chunk_size = (n - 1) / wg_size + 1; + + const size_t bubble_beg = (idx * chunk_size) >= n ? n : idx * chunk_size; + const size_t bubble_end = + ((idx + 1) * chunk_size) > n ? n : (idx + 1) * chunk_size; + bubble_sort(first, bubble_beg, bubble_end); + group_barrier(); + int32_t *scratch1 = reinterpret_cast(scratch); + bool data_in_scratch = false; + // We have wg_size chunks here, each chunk has chunk_size elements which + // are sorted. The last chunck's element number may be smaller. 
+ size_t chunks_to_merge = (n - 1) / chunk_size + 1; + size_t merge_size = chunk_size; + while (chunks_to_merge > 1) { + // workitem 0 will merge chunk 0, 1. + // workitem 1 will merge chunk 2, 3. + // workitem idx will merge chunk 2 * idx and 2 * idx + 1 + int32_t *data_in = data_in_scratch ? scratch1 : first; + int32_t *data_out = data_in_scratch ? first : scratch1; + merge(data_in, data_out, idx, merge_size, chunks_to_merge, n); + group_barrier(); + chunks_to_merge = (chunks_to_merge - 1) / 2 + 1; + merge_size <<= 1; + data_in_scratch = !data_in_scratch; + } + if (data_in_scratch) { + for (size_t i = idx * chunk_size; i < bubble_end; ++i) + first[i] = scratch1[i]; + group_barrier(); + } +} +#endif + +#endif diff --git a/libdevice/spirv_decls.hpp b/libdevice/spirv_decls.hpp new file mode 100644 index 0000000000000..d05d5007b9b9d --- /dev/null +++ b/libdevice/spirv_decls.hpp @@ -0,0 +1,86 @@ +//==-------------- atomic.hpp - support of atomic operations ---------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#pragma once + +#include + +#include "device.h" + +#ifdef __SPIR__ + +#define SPIR_GLOBAL __attribute__((opencl_global)) + +namespace __spv { +struct Scope { + + enum Flag : uint32_t { + CrossDevice = 0, + Device = 1, + Workgroup = 2, + Subgroup = 3, + Invocation = 4, + }; + + constexpr Scope(Flag flag) : flag_value(flag) {} + + constexpr operator uint32_t() const { return flag_value; } + + Flag flag_value; +}; + +struct MemorySemanticsMask { + + enum Flag : uint32_t { + None = 0x0, + Acquire = 0x2, + Release = 0x4, + AcquireRelease = 0x8, + SequentiallyConsistent = 0x10, + UniformMemory = 0x40, + SubgroupMemory = 0x80, + WorkgroupMemory = 0x100, + CrossWorkgroupMemory = 0x200, + AtomicCounterMemory = 0x400, + ImageMemory = 0x800, + }; + + constexpr MemorySemanticsMask(Flag flag) : flag_value(flag) {} + + constexpr operator uint32_t() const { return flag_value; } + + Flag flag_value; +}; +} // namespace __spv + +extern DEVICE_EXTERNAL void +__spirv_ControlBarrier(__spv::Scope::Flag, __spv::Scope::Flag Memory, + uint32_t Semantics); + +extern DEVICE_EXTERNAL int +__spirv_AtomicCompareExchange(int SPIR_GLOBAL *, __spv::Scope::Flag, + __spv::MemorySemanticsMask::Flag, + __spv::MemorySemanticsMask::Flag, int, int); + +extern DEVICE_EXTERNAL int +__spirv_AtomicCompareExchange(int *, __spv::Scope::Flag, + __spv::MemorySemanticsMask::Flag, + __spv::MemorySemanticsMask::Flag, int, int); + +extern DEVICE_EXTERNAL int __spirv_AtomicLoad(const int SPIR_GLOBAL *, + __spv::Scope::Flag, + __spv::MemorySemanticsMask::Flag); + +extern DEVICE_EXTERNAL void +__spirv_AtomicStore(int SPIR_GLOBAL *, __spv::Scope::Flag, + __spv::MemorySemanticsMask::Flag, int); + +extern DEVICE_EXTERNAL void +__spirv_AtomicStore(int *, __spv::Scope::Flag, __spv::MemorySemanticsMask::Flag, + int); + +#endif // __SPIR__ From 
4614b00c3134bebdd3e8a3f08d3166b03d27b0f3 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 9 Feb 2024 16:25:15 +0800 Subject: [PATCH 02/71] fix clang-fortmat issue Signed-off-by: jinge90 --- libdevice/group_helper.hpp | 2 +- libdevice/spirv_decls.hpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libdevice/group_helper.hpp b/libdevice/group_helper.hpp index 59d3319508472..b95d2e6984b32 100644 --- a/libdevice/group_helper.hpp +++ b/libdevice/group_helper.hpp @@ -6,8 +6,8 @@ // //==------------------------------------------------------------------------==// #pragma once -#include "spirv_vars.h" #include "spirv_decls.hpp" +#include "spirv_vars.h" #include #if defined(__SPIR__) diff --git a/libdevice/spirv_decls.hpp b/libdevice/spirv_decls.hpp index d05d5007b9b9d..597bfcd26a3b5 100644 --- a/libdevice/spirv_decls.hpp +++ b/libdevice/spirv_decls.hpp @@ -57,9 +57,9 @@ struct MemorySemanticsMask { }; } // namespace __spv -extern DEVICE_EXTERNAL void -__spirv_ControlBarrier(__spv::Scope::Flag, __spv::Scope::Flag Memory, - uint32_t Semantics); +extern DEVICE_EXTERNAL void __spirv_ControlBarrier(__spv::Scope::Flag, + __spv::Scope::Flag Memory, + uint32_t Semantics); extern DEVICE_EXTERNAL int __spirv_AtomicCompareExchange(int SPIR_GLOBAL *, __spv::Scope::Flag, From 4d47fcb22641338a46fc6535556f539919a364a6 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 22 Feb 2024 21:38:36 +0800 Subject: [PATCH 03/71] template for group sort Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 81 ++++++++++++++++++++++++++++++++++-- libdevice/sort_helper.hpp | 73 ++++++++++++++++++++++++++------ 2 files changed, 137 insertions(+), 17 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 0b0099ff82339..888e78b695002 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -10,29 +10,102 @@ #include "device.h" #include "sort_helper.hpp" #include +#include #if defined(__SPIR__) DEVICE_EXTERN_C_INLINE 
void __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( int32_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch); + merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( int32_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch); + merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p1i8( int32_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch); + merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p3i8( int32_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch); + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p3i8( + float 
*first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); } #endif diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 21b675e2be32c..08975f764d441 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -12,37 +12,46 @@ #include #if defined(__SPIR__) -void bubble_sort(int32_t *first, const size_t beg, const size_t end) { +template +void bubble_sort(Tp *first, const size_t beg, const size_t end, Compare comp) { if (beg < end) { + Tp temp; for (size_t i = beg; i < end; ++i) for (size_t j = i + 1; j < end; ++j) { - if (first[i] > first[j]) { - first[i] = first[i] ^ first[j]; - first[j] = first[i] ^ first[j]; - first[i] = first[i] ^ first[j]; + if (!comp(first[i], first[j])) { + temp = first[i]; + first[i] = first[j]; + 
first[j] = temp; } } } } -void merge(int32_t *din, int32_t *dout, size_t widx, size_t msize, - size_t chunks, size_t n) { +// widx: work-item id with a work-group +// chunks: number of sorted chunks waiting to be merged +// n: total number of elements waiting to be sorted +// msize: number of elements in a chunk ready to be merged +template +void merge(Tp *din, Tp *dout, size_t widx, size_t msize, size_t chunks, + size_t n, Compare comp) { if (2 * widx >= chunks) - return; + return ; + size_t beg1 = 2 * widx * msize; size_t end1 = beg1 + msize; size_t beg2, end2; - if (end1 >= n) { + if (end1 >= n) end1 = beg2 = end2 = n; - } else { + else { beg2 = end1; end2 = beg2 + msize; if (end2 >= n) end2 = n; } + size_t output_idx = 2 * widx * msize; while ((beg1 != end1) && (beg2 != end2)) { - if (din[beg1] < din[beg2]) + if (comp(din[beg1], din[beg2])) dout[output_idx++] = din[beg1++]; else dout[output_idx++] = din[beg2++]; @@ -54,6 +63,43 @@ void merge(int32_t *din, int32_t *dout, size_t widx, size_t msize, dout[output_idx++] = din[beg2++]; } +template +void merge_sort(Tp *first, uint32_t n, uint8_t *scratch, Compare comp) { + const size_t idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + const size_t chunk_size = (n - 1) / wg_size + 1; + + const size_t bubble_beg = (idx * chunk_size) >= n ? n : idx * chunk_size; + const size_t bubble_end = + ((idx + 1) * chunk_size) > n ? n : (idx + 1) * chunk_size; + bubble_sort(first, bubble_beg, bubble_end, comp); + group_barrier(); + Tp *scratch1 = reinterpret_cast(scratch); + bool data_in_scratch = false; + // We have wg_size chunks here, each chunk has chunk_size elements which + // are sorted. The last chunck's element number may be smaller. + size_t chunks_to_merge = (n - 1) / chunk_size + 1; + size_t merge_size = chunk_size; + while (chunks_to_merge > 1) { + // workitem 0 will merge chunk 0, 1. + // workitem 1 will merge chunk 2, 3. 
+ // workitem idx will merge chunk 2 * idx and 2 * idx + 1 + Tp *data_in = data_in_scratch ? scratch1 : first; + Tp *data_out = data_in_scratch ? first : scratch1; + merge(data_in, data_out, idx, merge_size, chunks_to_merge, n, comp); + group_barrier(); + chunks_to_merge = (chunks_to_merge - 1) / 2 + 1; + merge_size <<= 1; + data_in_scratch = !data_in_scratch; + } + if (data_in_scratch) { + for (size_t i = idx * chunk_size; i < bubble_end; ++i) + first[i] = scratch1[i]; + group_barrier(); + } +} + +/* void merge_sort(int32_t *first, uint32_t n, uint8_t *scratch) { const size_t idx = __get_wg_local_linear_id(); const size_t wg_size = __get_wg_local_range(); @@ -62,7 +108,7 @@ void merge_sort(int32_t *first, uint32_t n, uint8_t *scratch) { const size_t bubble_beg = (idx * chunk_size) >= n ? n : idx * chunk_size; const size_t bubble_end = ((idx + 1) * chunk_size) > n ? n : (idx + 1) * chunk_size; - bubble_sort(first, bubble_beg, bubble_end); + bubble_sort(first, bubble_beg, bubble_end, std::greater{}); group_barrier(); int32_t *scratch1 = reinterpret_cast(scratch); bool data_in_scratch = false; @@ -76,7 +122,7 @@ void merge_sort(int32_t *first, uint32_t n, uint8_t *scratch) { // workitem idx will merge chunk 2 * idx and 2 * idx + 1 int32_t *data_in = data_in_scratch ? scratch1 : first; int32_t *data_out = data_in_scratch ? 
first : scratch1; - merge(data_in, data_out, idx, merge_size, chunks_to_merge, n); + merge(data_in, data_out, idx, merge_size, chunks_to_merge, n, std::less{}); group_barrier(); chunks_to_merge = (chunks_to_merge - 1) / 2 + 1; merge_size <<= 1; @@ -88,6 +134,7 @@ void merge_sort(int32_t *first, uint32_t n, uint8_t *scratch) { group_barrier(); } } +*/ #endif #endif From 33c2df41a277dfd83f2c541b3ff5c04bb21b6996 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 22 Feb 2024 21:53:48 +0800 Subject: [PATCH 04/71] remove commented code Signed-off-by: jinge90 --- libdevice/sort_helper.hpp | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 08975f764d441..8b4a5eaba33eb 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -99,42 +99,6 @@ void merge_sort(Tp *first, uint32_t n, uint8_t *scratch, Compare comp) { } } -/* -void merge_sort(int32_t *first, uint32_t n, uint8_t *scratch) { - const size_t idx = __get_wg_local_linear_id(); - const size_t wg_size = __get_wg_local_range(); - const size_t chunk_size = (n - 1) / wg_size + 1; - - const size_t bubble_beg = (idx * chunk_size) >= n ? n : idx * chunk_size; - const size_t bubble_end = - ((idx + 1) * chunk_size) > n ? n : (idx + 1) * chunk_size; - bubble_sort(first, bubble_beg, bubble_end, std::greater{}); - group_barrier(); - int32_t *scratch1 = reinterpret_cast(scratch); - bool data_in_scratch = false; - // We have wg_size chunks here, each chunk has chunk_size elements which - // are sorted. The last chunck's element number may be smaller. - size_t chunks_to_merge = (n - 1) / chunk_size + 1; - size_t merge_size = chunk_size; - while (chunks_to_merge > 1) { - // workitem 0 will merge chunk 0, 1. - // workitem 1 will merge chunk 2, 3. - // workitem idx will merge chunk 2 * idx and 2 * idx + 1 - int32_t *data_in = data_in_scratch ? scratch1 : first; - int32_t *data_out = data_in_scratch ? 
first : scratch1; - merge(data_in, data_out, idx, merge_size, chunks_to_merge, n, std::less{}); - group_barrier(); - chunks_to_merge = (chunks_to_merge - 1) / 2 + 1; - merge_size <<= 1; - data_in_scratch = !data_in_scratch; - } - if (data_in_scratch) { - for (size_t i = idx * chunk_size; i < bubble_end; ++i) - first[i] = scratch1[i]; - group_barrier(); - } -} -*/ #endif #endif From 90cb9d2d1bd5bbfd42ea172465e4ef15106f187f Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 23 Feb 2024 13:03:40 +0800 Subject: [PATCH 05/71] add default work group joint sort for i8 --- libdevice/fallback-gsort.cpp | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 888e78b695002..afcc15c18a177 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -13,6 +13,54 @@ #include #if defined(__SPIR__) +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + 
int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( int32_t *first, uint32_t n, uint8_t *scratch) { From f0224ea478015ebfce13e13fbba51da31f6a3400 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 23 Feb 2024 15:45:46 +0800 Subject: [PATCH 06/71] add default work group joint sort fo i64 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 95 ++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index afcc15c18a177..34a83599449b6 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -61,6 +61,54 @@ void __devicelib_default_work_group_joint_sort_descending_p3i8_u32_p3i8( merge_sort(first, n, scratch, std::greater{}); } +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p3i8( + int16_t 
*first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( int32_t *first, uint32_t n, uint8_t *scratch) { @@ -109,6 +157,53 @@ void __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( merge_sort(first, n, scratch, std::greater{}); } +void __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void 
__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( float *first, uint32_t n, uint8_t *scratch) { From eed61f46495866d91f23d0643d96b5813e382d9f Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 27 Feb 2024 14:35:16 +0800 Subject: [PATCH 07/71] add default work group joint sort for unsigned integer Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 193 +++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 34a83599449b6..6d7992fc1d1fd 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -13,6 +13,7 @@ #include #if defined(__SPIR__) +//============ default work grop joint sort for signed integer =============== DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p1i8( int8_t *first, uint32_t n, uint8_t *scratch) { @@ -204,6 +205,198 @@ void __devicelib_default_work_group_joint_sort_descending_p3i64_u32_p3i8( merge_sort(first, n, scratch, std::greater{}); } +//=========== default work grop joint sort for unsigned integer ============== +DEVICE_EXTERN_C_INLINE +void 
__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void 
__devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + 
+DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +void __devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + 
+DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} +//=============== default work grop joint sort for fp32 ====================== DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( float *first, uint32_t n, uint8_t *scratch) { From 9840222af70dbed0391c1e33cd8aaf28aefa077e Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 1 Mar 2024 15:21:40 +0800 Subject: [PATCH 08/71] add wg default private sort for i8 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 52 ++++++++++++++++++++++++++++++++++++ libdevice/sort_helper.hpp | 46 +++++++++++++++++++++++++++++-- 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 6d7992fc1d1fd..5dbeb74550fba 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -444,4 +444,56 @@ void __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( float *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } + +//============ default work grop private sort for signed integer ============== +// Since 'first' should point to 'private' memory address space, it can only be +// decorated with 'p1'. 
+DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + #endif diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 8b4a5eaba33eb..740033d9bc80d 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -35,7 +35,7 @@ template void merge(Tp 
*din, Tp *dout, size_t widx, size_t msize, size_t chunks, size_t n, Compare comp) { if (2 * widx >= chunks) - return ; + return; size_t beg1 = 2 * widx * msize; size_t end1 = beg1 + msize; @@ -86,7 +86,8 @@ void merge_sort(Tp *first, uint32_t n, uint8_t *scratch, Compare comp) { // workitem idx will merge chunk 2 * idx and 2 * idx + 1 Tp *data_in = data_in_scratch ? scratch1 : first; Tp *data_out = data_in_scratch ? first : scratch1; - merge(data_in, data_out, idx, merge_size, chunks_to_merge, n, comp); + merge(data_in, data_out, idx, merge_size, chunks_to_merge, n, + comp); group_barrier(); chunks_to_merge = (chunks_to_merge - 1) / 2 + 1; merge_size <<= 1; @@ -99,6 +100,47 @@ void merge_sort(Tp *first, uint32_t n, uint8_t *scratch, Compare comp) { } } +// Each work-item holds some input elements located in private memory and apply +// group sorting to all work-items' input. The sorted data will be copied back +// to each work-item's private memory. +// Assumption about scratch memory size: +// scratch_size >= n * wg_size * sizeof(Tp) * 2 +template +void private_merge_sort_close(Tp *first, uint32_t n, uint8_t *scratch, + Compare comp) { + const size_t idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + Tp *temp_buffer = reinterpret_cast(scratch); + for (size_t i = 0; i < n; ++i) + temp_buffer[idx * n + i] = first[i]; + + group_barrier(); + // do group sorting for whole input data + merge_sort(temp_buffer, n * wg_size, + reinterpret_cast(temp_buffer + n * wg_size), comp); + + for (size_t i = 0; i < n; ++i) + first[i] = temp_buffer[idx * n + i]; +} + +template +void private_merge_sort_spread(Tp *first, uint32_t n, uint8_t *scratch, + Compare comp) { + const size_t idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + Tp *temp_buffer = reinterpret_cast(scratch); + for (size_t i = 0; i < n; ++i) + temp_buffer[idx * n + i] = first[i]; + + group_barrier(); + // do group sorting for whole input data + 
merge_sort(temp_buffer, n * wg_size, + reinterpret_cast(temp_buffer + n * wg_size), comp); + + for (size_t i = 0; i < n; ++i) + first[i] = temp_buffer[i * wg_size + idx]; +} + #endif #endif From fd807094d1e224d8e9c5e0baf167820755b0a5f0 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 4 Mar 2024 11:05:27 +0800 Subject: [PATCH 09/71] add private sorting for i16 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 5dbeb74550fba..a1a318c96ae48 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -496,4 +496,52 @@ void __devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p3i8 private_merge_sort_spread(first, n, scratch, std::greater{}); } +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void 
__devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + #endif From 38251b41217d1f51c34cb5853a46cfd90cb256f4 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 4 Mar 2024 15:35:32 +0800 Subject: [PATCH 10/71] private sort for i64 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index a1a318c96ae48..5a074467111d2 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -544,4 +544,100 @@ void __devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p3i private_merge_sort_spread(first, n, scratch, std::greater{}); } +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void 
__devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + 
+DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + #endif From 92d86f1239adc40a2bd4ae7ef62d4ad716091afd Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 4 Mar 2024 21:08:54 +0800 Subject: [PATCH 11/71] add private sort for uint Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 191 +++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 5a074467111d2..2d1803cbf1af6 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -640,4 +640,195 @@ void __devicelib_default_work_group_private_sort_spread_descending_p1i64_u32_p3i private_merge_sort_spread(first, n, scratch, std::greater{}); } +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} 
+ +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, 
std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + 
private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1u64_u32_p1i8( + uint64_t *first, 
uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} #endif From 2a1124ac74a094025684378d67feeacf639990c8 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 5 Mar 2024 13:41:34 +0800 Subject: [PATCH 12/71] add private sort for fp32 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 53 ++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 2d1803cbf1af6..6bd6a9e57f9ce 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -640,6 +640,8 @@ void __devicelib_default_work_group_private_sort_spread_descending_p1i64_u32_p3i private_merge_sort_spread(first, n, scratch, std::greater{}); } +//=========== default work grop private sort for unsigned integer ============= + DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_private_sort_close_ascending_p1u8_u32_p1i8( uint8_t *first, uint32_t n, uint8_t *scratch) { @@ -831,4 +833,55 @@ void __devicelib_default_work_group_private_sort_spread_descending_p1u64_u32_p3i uint64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } + +//================= default work grop private sort for fp32 ==================== + +DEVICE_EXTERN_C_INLINE +void 
__devicelib_default_work_group_private_sort_close_ascending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + #endif From 9ef76929245541bfbfccaf54d30268bb385b690f Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 6 Mar 2024 17:50:42 +0800 Subject: [PATCH 13/71] add sub group private sort Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 109 
+++++++++++++++++++++++++++++++++++ libdevice/sort_helper.hpp | 17 ++++++ 2 files changed, 126 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 6bd6a9e57f9ce..602a9994e9fc9 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -884,4 +884,113 @@ void __devicelib_default_work_group_private_sort_spread_descending_p1f32_u32_p3i private_merge_sort_spread(first, n, scratch, std::greater{}); } +//============= default sub group private sort for signed integer ============= +DEVICE_EXTERN_C_INLINE +int8_t __devicelib_default_sub_group_private_sort_ascending_i8( + int8_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +int16_t __devicelib_default_sub_group_private_sort_ascending_i16( + int16_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +int32_t __devicelib_default_sub_group_private_sort_ascending_i32( + int32_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +int64_t __devicelib_default_sub_group_private_sort_ascending_i64( + int64_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +uint8_t __devicelib_default_sub_group_private_sort_ascending_u8( + uint8_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +uint16_t __devicelib_default_sub_group_private_sort_ascending_u16( + uint16_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +uint32_t __devicelib_default_sub_group_private_sort_ascending_u32( + uint32_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +uint64_t __devicelib_default_sub_group_private_sort_ascending_u64( + uint64_t 
value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +float __devicelib_default_sub_group_private_sort_ascending_f32( + float value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +int8_t __devicelib_default_sub_group_private_sort_descending_i8( + int8_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +int16_t __devicelib_default_sub_group_private_sort_descending_i16( + int16_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +int32_t __devicelib_default_sub_group_private_sort_descending_i32( + int32_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +int64_t __devicelib_default_sub_group_private_sort_descending_i64( + int64_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +uint8_t __devicelib_default_sub_group_private_sort_descending_u8( + uint8_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +uint16_t __devicelib_default_sub_group_private_sort_descending_u16( + uint16_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +uint32_t __devicelib_default_sub_group_private_sort_descending_u32( + uint32_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +uint64_t __devicelib_default_sub_group_private_sort_descending_u64( + uint64_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +float __devicelib_default_sub_group_private_sort_descending_f32( + float value, uint8_t *scratch) { + return 
sub_group_merge_sort(value, scratch, std::greater{}); +} + #endif diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 740033d9bc80d..887218b64a6c8 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -141,6 +141,23 @@ void private_merge_sort_spread(Tp *first, uint32_t n, uint8_t *scratch, first[i] = temp_buffer[i * wg_size + idx]; } +// sub group sort implementation, each work-item holds an element, the total +// number of input elements is work group size. +// Assumption about scratch memory size: +// scratch_size >= wg_size * sizeof(Tp) * 2 +template +Tp sub_group_merge_sort(Tp value, uint8_t *scratch, Compare comp) { + const size_t idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + Tp *temp_buffer = reinterpret_cast(scratch); + temp_buffer[idx] = value; + + group_barrier(); + merge_sort(temp_buffer, wg_size, + reinterpret_cast(temp_buffer + wg_size), comp); + return temp_buffer[idx]; +} + #endif #endif From 30254306799584a690b121b756851293e7dd894f Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 11 Mar 2024 13:59:58 +0800 Subject: [PATCH 14/71] add utils for fp16 comparison Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 1 + libdevice/sort_helper.hpp | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 602a9994e9fc9..99780d9b8f1a0 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -396,6 +396,7 @@ void __devicelib_default_work_group_joint_sort_descending_p3u64_u32_p3i8( uint64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } + //=============== default work grop joint sort for fp32 ====================== DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 887218b64a6c8..96282a01c6286 100644 
--- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -12,6 +12,42 @@ #include #if defined(__SPIR__) +// A simple compare function for fp16 type emulated with uint16_t +// 1: great than +// 0: equal to +// -1: less than +// -2: can't compare(NAN) +int fp16_comp(uint16_t a, uint16_t b) { + uint16_t a_sig = a >> 15; + uint16_t a_exp = (a & 0x7fff) >> 10; + uint16_t a_fra = a & 0x3ff; + uint16_t b_sig = b >> 15; + uint16_t b_exp = (b & 0x7fff) >> 10; + uint16_t b_fra = b & 0x3ff; + if (((a_exp == 0x1f) && (a_fra != 0x0)) || ((b_exp == 0x1f) && (b_fra != 0x0))) + return -2; + + if ((a_sig == 0) && (b_sig == 1)) + return 1; + + if ((a_sig == 1) && (b_sig == 0)) + return -1; + + if (a_exp > b_exp) + return (a_sig == 0) ? 1 : -1; + + if (a_exp < b_exp) + return (a_sig == 0) ? -1 : 1; + + if (a_fra == b_fra) + return 0; + + if (a_sig == 0) + return (a_fra > b_fra) ? 1 : -1; + else + return (a_fra > b_fra) ? -1 : 1; +} + template void bubble_sort(Tp *first, const size_t beg, const size_t end, Compare comp) { if (beg < end) { From 455d3439b116fa784159d92199f74f9336553e2f Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 11 Mar 2024 15:02:08 +0800 Subject: [PATCH 15/71] add work group joint sort for f16 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 137 +++++++++++++++++++++++++++-------- 1 file changed, 105 insertions(+), 32 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 99780d9b8f1a0..2967d8d513da4 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -446,6 +446,63 @@ void __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( merge_sort(first, n, scratch, std::greater{}); } +//=============== default work grop joint sort for fp16 ====================== +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1f16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, + [](uint16_t a, uint16_t b) { 
return (fp16_comp(a, b) == -1); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1f16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, + [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == -1); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3f16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, + [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == -1); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p3f16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, + [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == -1); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1f16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, + [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == 1); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1f16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, + [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == 1); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3f16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, + [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == 1); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p3f16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, + [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == 1); }); +} + //============ default work grop private sort for signed integer ============== // Since 'first' should point to 'private' memory address space, it can only be // decorated with 'p1'. 
@@ -887,50 +944,58 @@ void __devicelib_default_work_group_private_sort_spread_descending_p1f32_u32_p3i //============= default sub group private sort for signed integer ============= DEVICE_EXTERN_C_INLINE -int8_t __devicelib_default_sub_group_private_sort_ascending_i8( - int8_t value, uint8_t *scratch) { +int8_t +__devicelib_default_sub_group_private_sort_ascending_i8(int8_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -int16_t __devicelib_default_sub_group_private_sort_ascending_i16( - int16_t value, uint8_t *scratch) { +int16_t +__devicelib_default_sub_group_private_sort_ascending_i16(int16_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -int32_t __devicelib_default_sub_group_private_sort_ascending_i32( - int32_t value, uint8_t *scratch) { +int32_t +__devicelib_default_sub_group_private_sort_ascending_i32(int32_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -int64_t __devicelib_default_sub_group_private_sort_ascending_i64( - int64_t value, uint8_t *scratch) { +int64_t +__devicelib_default_sub_group_private_sort_ascending_i64(int64_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint8_t __devicelib_default_sub_group_private_sort_ascending_u8( - uint8_t value, uint8_t *scratch) { +uint8_t +__devicelib_default_sub_group_private_sort_ascending_u8(uint8_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint16_t __devicelib_default_sub_group_private_sort_ascending_u16( - uint16_t value, uint8_t *scratch) { +uint16_t +__devicelib_default_sub_group_private_sort_ascending_u16(uint16_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint32_t 
__devicelib_default_sub_group_private_sort_ascending_u32( - uint32_t value, uint8_t *scratch) { +uint32_t +__devicelib_default_sub_group_private_sort_ascending_u32(uint32_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint64_t __devicelib_default_sub_group_private_sort_ascending_u64( - uint64_t value, uint8_t *scratch) { +uint64_t +__devicelib_default_sub_group_private_sort_ascending_u64(uint64_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } @@ -941,50 +1006,58 @@ float __devicelib_default_sub_group_private_sort_ascending_f32( } DEVICE_EXTERN_C_INLINE -int8_t __devicelib_default_sub_group_private_sort_descending_i8( - int8_t value, uint8_t *scratch) { +int8_t +__devicelib_default_sub_group_private_sort_descending_i8(int8_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -int16_t __devicelib_default_sub_group_private_sort_descending_i16( - int16_t value, uint8_t *scratch) { +int16_t +__devicelib_default_sub_group_private_sort_descending_i16(int16_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -int32_t __devicelib_default_sub_group_private_sort_descending_i32( - int32_t value, uint8_t *scratch) { +int32_t +__devicelib_default_sub_group_private_sort_descending_i32(int32_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -int64_t __devicelib_default_sub_group_private_sort_descending_i64( - int64_t value, uint8_t *scratch) { +int64_t +__devicelib_default_sub_group_private_sort_descending_i64(int64_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint8_t __devicelib_default_sub_group_private_sort_descending_u8( - uint8_t value, uint8_t *scratch) { +uint8_t 
+__devicelib_default_sub_group_private_sort_descending_u8(uint8_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint16_t __devicelib_default_sub_group_private_sort_descending_u16( - uint16_t value, uint8_t *scratch) { +uint16_t +__devicelib_default_sub_group_private_sort_descending_u16(uint16_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint32_t __devicelib_default_sub_group_private_sort_descending_u32( - uint32_t value, uint8_t *scratch) { +uint32_t +__devicelib_default_sub_group_private_sort_descending_u32(uint32_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint64_t __devicelib_default_sub_group_private_sort_descending_u64( - uint64_t value, uint8_t *scratch) { +uint64_t +__devicelib_default_sub_group_private_sort_descending_u64(uint64_t value, + uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } From 15444c22f982d12f431babc04993cb5b5ef3a577 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 11 Mar 2024 17:33:37 +0800 Subject: [PATCH 16/71] use native f16 type Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 119 +++++++++++++++++++++++++++-------- libdevice/sort_helper.hpp | 36 ----------- 2 files changed, 93 insertions(+), 62 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 2967d8d513da4..e7cb36546c13a 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1,6 +1,5 @@ -//==--- fallback_gsort_fp32.cpp - fallback implementation of group sort -//-----==// +//==------ fallback_gsort.cpp - fallback implementation of group sort-------==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -446,61 +445,55 @@ void __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( merge_sort(first, n, scratch, std::greater{}); } +// TODO: split all f16 functions into separate libraries in case some platform +// doesn't support native fp16 //=============== default work grop joint sort for fp16 ====================== DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1f16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch, - [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == -1); }); + _Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1f16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch, - [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == -1); }); + _Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p3f16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch, - [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == -1); }); + _Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p3f16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch, - [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == -1); }); + _Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_descending_p1f16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t 
*scratch) { - merge_sort(first, n, scratch, - [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == 1); }); + _Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_descending_p1f16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch, - [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == 1); }); + _Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_descending_p3f16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch, - [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == 1); }); + _Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_descending_p3f16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { - merge_sort(first, n, scratch, - [](uint16_t a, uint16_t b) { return (fp16_comp(a, b) == 1); }); + _Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } //============ default work grop private sort for signed integer ============== @@ -942,6 +935,64 @@ void __devicelib_default_work_group_private_sort_spread_descending_p1f32_u32_p3i private_merge_sort_spread(first, n, scratch, std::greater{}); } +//================= default work grop private sort for fp16 ==================== + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1f16_u32_p1i8( + _Float16 *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, + [](_Float16 a, _Float16 b) { return (a < b); }); +} + +DEVICE_EXTERN_C_INLINE 
+void __devicelib_default_work_group_private_sort_close_ascending_p1f16_u32_p3i8( + _Float16 *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, + [](_Float16 a, _Float16 b) { return (a < b); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1f16_u32_p1i8( + _Float16 *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, + [](_Float16 a, _Float16 b) { return (a > b); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_descending_p1f16_u32_p3i8( + _Float16 *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, + [](_Float16 a, _Float16 b) { return (a > b); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1f16_u32_p1i8( + _Float16 *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, + [](_Float16 a, _Float16 b) { return (a < b); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_ascending_p1f16_u32_p3i8( + _Float16 *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, + [](_Float16 a, _Float16 b) { return (a < b); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1f16_u32_p1i8( + _Float16 *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, + [](_Float16 a, _Float16 b) { return (a > b); }); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_spread_descending_p1f16_u32_p3i8( + _Float16 *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, + [](_Float16 a, _Float16 b) { return (a > b); }); +} + //============= default sub group private sort for signed integer ============= DEVICE_EXTERN_C_INLINE int8_t @@ -1005,6 +1056,14 @@ float __devicelib_default_sub_group_private_sort_ascending_f32( 
return sub_group_merge_sort(value, scratch, std::less{}); } +DEVICE_EXTERN_C_INLINE +_Float16 +__devicelib_default_sub_group_private_sort_ascending_f16(_Float16 value, + uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, + [](_Float16 a, _Float16 b) { return (a < b); }); +} + DEVICE_EXTERN_C_INLINE int8_t __devicelib_default_sub_group_private_sort_descending_i8(int8_t value, @@ -1067,4 +1126,12 @@ float __devicelib_default_sub_group_private_sort_descending_f32( return sub_group_merge_sort(value, scratch, std::greater{}); } +DEVICE_EXTERN_C_INLINE +_Float16 +__devicelib_default_sub_group_private_sort_descending_f16(_Float16 value, + uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, + [](_Float16 a, _Float16 b) { return (a > b); }); +} + #endif diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 96282a01c6286..887218b64a6c8 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -12,42 +12,6 @@ #include #if defined(__SPIR__) -// A simple compare function for fp16 type emulated with uint16_t -// 1: great than -// 0: equal to -// -1: less than -// -2: can't compare(NAN) -int fp16_comp(uint16_t a, uint16_t b) { - uint16_t a_sig = a >> 15; - uint16_t a_exp = (a & 0x7fff) >> 10; - uint16_t a_fra = a & 0x3ff; - uint16_t b_sig = b >> 15; - uint16_t b_exp = (b & 0x7fff) >> 10; - uint16_t b_fra = b & 0x3ff; - if (((a_exp == 0x1f) && (a_fra != 0x0)) || ((b_exp == 0x1f) && (b_fra != 0x0))) - return -2; - - if ((a_sig == 0) && (b_sig == 1)) - return 1; - - if ((a_sig == 1) && (b_sig == 0)) - return -1; - - if (a_exp > b_exp) - return (a_sig == 0) ? 1 : -1; - - if (a_exp < b_exp) - return (a_sig == 0) ? -1 : 1; - - if (a_fra == b_fra) - return 0; - - if (a_sig == 0) - return (a_fra > b_fra) ? 1 : -1; - else - return (a_fra > b_fra) ? 
-1 : 1; -} - template void bubble_sort(Tp *first, const size_t beg, const size_t end, Compare comp) { if (beg < end) { From ae0d2af987a3acce258cc51a11c96b9bb753a5dd Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 12 Jul 2024 11:15:56 +0800 Subject: [PATCH 17/71] link gsort fallback Signed-off-by: jinge90 --- clang/include/clang/Driver/Options.td | 4 ++-- clang/lib/Driver/ToolChains/SYCL.cpp | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td index 30d2b65204229..0ecb7fd8f2b32 100644 --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -6886,9 +6886,9 @@ def fno_sycl_dead_args_optimization : Flag<["-"], "fno-sycl-dead-args-optimizati "elimination of DPC++ dead kernel arguments">; def fsycl_device_lib_EQ : CommaJoined<["-"], "fsycl-device-lib=">, Group, Flags<[NoXarchOption]>, Visibility<[ClangOption, CLOption, DXCOption]>, - Values<"libc, libm-fp32, libm-fp64, libimf-fp32, libimf-fp64, libimf-bf16, all">, HelpText<"Control inclusion of " + Values<"libc, libm-fp32, libm-fp64, libimf-fp32, libimf-fp64, libimf-bf16, libgsort-fp32, all">, HelpText<"Control inclusion of " "device libraries into device binary linkage. 
Valid arguments " - "are libc, libm-fp32, libm-fp64, libimf-fp32, libimf-fp64, libimf-bf16, all">; + "are libc, libm-fp32, libm-fp64, libimf-fp32, libimf-fp64, libimf-bf16, libgsort-fp32, all">; def fno_sycl_device_lib_EQ : CommaJoined<["-"], "fno-sycl-device-lib=">, Group, Flags<[NoXarchOption]>, Visibility<[ClangOption, CLOption, DXCOption]>, Values<"libc, libm-fp32, libm-fp64, all">, HelpText<"Control exclusion of " diff --git a/clang/lib/Driver/ToolChains/SYCL.cpp b/clang/lib/Driver/ToolChains/SYCL.cpp index 55913aa368b49..14f174cfc29b7 100644 --- a/clang/lib/Driver/ToolChains/SYCL.cpp +++ b/clang/lib/Driver/ToolChains/SYCL.cpp @@ -222,7 +222,7 @@ SYCL::getDeviceLibraries(const Compilation &C, const llvm::Triple &TargetTriple, llvm::StringMap DeviceLibLinkInfo = { {"libc", true}, {"libm-fp32", true}, {"libm-fp64", true}, {"libimf-fp32", true}, {"libimf-fp64", true}, {"libimf-bf16", true}, - {"libm-bfloat16", true}, {"internal", true}}; + {"libm-bfloat16", true}, {"libgsort-fp32", true}, {"internal", true}}; if (Arg *A = Args.getLastArg(options::OPT_fsycl_device_lib_EQ, options::OPT_fno_sycl_device_lib_EQ)) { if (A->getValues().size() == 0) @@ -274,6 +274,7 @@ SYCL::getDeviceLibraries(const Compilation &C, const llvm::Triple &TargetTriple, {"libsycl-fallback-cmath", "libm-fp32"}, {"libsycl-fallback-cmath-fp64", "libm-fp64"}, {"libsycl-fallback-imf", "libimf-fp32"}, + {"libsycl-fallback-gsort", "libgsort-fp32"}, {"libsycl-fallback-imf-fp64", "libimf-fp64"}, {"libsycl-fallback-imf-bf16", "libimf-bf16"}}; const SYCLDeviceLibsList SYCLDeviceBfloat16FallbackLib = { From 045413cda2b7ea3b8f2c2ed6197150db36e32afb Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 12 Jul 2024 11:54:32 +0800 Subject: [PATCH 18/71] update driver test Signed-off-by: jinge90 --- clang/lib/Driver/ToolChains/SYCL.cpp | 3 ++- .../test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.bc | 0 clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.o | 0 clang/test/Driver/sycl-device-lib-win.cpp | 5 
+++++ clang/test/Driver/sycl-device-lib.cpp | 7 +++++++ 5 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.bc create mode 100644 clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.o diff --git a/clang/lib/Driver/ToolChains/SYCL.cpp b/clang/lib/Driver/ToolChains/SYCL.cpp index 14f174cfc29b7..7be8ec39533aa 100644 --- a/clang/lib/Driver/ToolChains/SYCL.cpp +++ b/clang/lib/Driver/ToolChains/SYCL.cpp @@ -273,8 +273,8 @@ SYCL::getDeviceLibraries(const Compilation &C, const llvm::Triple &TargetTriple, {"libsycl-fallback-complex-fp64", "libm-fp64"}, {"libsycl-fallback-cmath", "libm-fp32"}, {"libsycl-fallback-cmath-fp64", "libm-fp64"}, - {"libsycl-fallback-imf", "libimf-fp32"}, {"libsycl-fallback-gsort", "libgsort-fp32"}, + {"libsycl-fallback-imf", "libimf-fp32"}, {"libsycl-fallback-imf-fp64", "libimf-fp64"}, {"libsycl-fallback-imf-bf16", "libimf-bf16"}}; const SYCLDeviceLibsList SYCLDeviceBfloat16FallbackLib = { @@ -412,6 +412,7 @@ static llvm::SmallVector SYCLDeviceLibList{ "fallback-cmath-fp64", "fallback-complex", "fallback-complex-fp64", + "fallback-gsort", "fallback-imf", "fallback-imf-fp64", "fallback-imf-bf16", diff --git a/clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.bc b/clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.bc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.o b/clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.o new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/clang/test/Driver/sycl-device-lib-win.cpp b/clang/test/Driver/sycl-device-lib-win.cpp index 3f7267e017efa..60ed1d9abab7f 100644 --- a/clang/test/Driver/sycl-device-lib-win.cpp +++ b/clang/test/Driver/sycl-device-lib-win.cpp @@ -33,6 +33,7 @@ // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-cmath.bc" // 
SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -70,6 +71,7 @@ // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -93,6 +95,7 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -115,6 +118,7 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ 
-172,6 +176,7 @@ // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" diff --git a/clang/test/Driver/sycl-device-lib.cpp b/clang/test/Driver/sycl-device-lib.cpp index df90b29872208..b667fa733696c 100644 --- a/clang/test/Driver/sycl-device-lib.cpp +++ b/clang/test/Driver/sycl-device-lib.cpp @@ -33,6 +33,7 @@ // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -70,6 +71,7 @@ // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -93,6 +95,7 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-cmath.bc" 
// SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -107,6 +110,7 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-cmath.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-cmath-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-imf.bc" +// SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-imf.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-imf-bf16.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-cassert.bc" @@ -115,6 +119,7 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -175,6 +180,7 @@ // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -223,6 +229,7 @@ // SYCL_DEVICE_LIB_SANITIZER-SAME: 
"{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_SANITIZER-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_SANITIZER-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_SANITIZER-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_SANITIZER-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_SANITIZER-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_SANITIZER-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" From 4ae26deb6f2b0f50396dcbc791e309c73aaf9eb7 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 12 Jul 2024 14:36:20 +0800 Subject: [PATCH 19/71] Enable fallback spv for group sort Signed-off-by: jinge90 --- libdevice/cmake/modules/SYCLLibdevice.cmake | 2 +- .../llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h | 1 + llvm/lib/SYCLLowerIR/SYCLDeviceLibReqMask.cpp | 15 +++++++++++++++ .../detail/program_manager/program_manager.cpp | 4 ++++ .../detail/program_manager/program_manager.hpp | 1 + 5 files changed, 22 insertions(+), 1 deletion(-) diff --git a/libdevice/cmake/modules/SYCLLibdevice.cmake b/libdevice/cmake/modules/SYCLLibdevice.cmake index 9c339e1e8b1e9..2553bf163ae6d 100644 --- a/libdevice/cmake/modules/SYCLLibdevice.cmake +++ b/libdevice/cmake/modules/SYCLLibdevice.cmake @@ -206,7 +206,7 @@ add_devicelib(libsycl-fallback-cmath SRC fallback-cmath.cpp DEP ${cmath_obj_deps add_devicelib(libsycl-fallback-cmath-fp64 SRC fallback-cmath-fp64.cpp DEP ${cmath_obj_deps}) add_devicelib(libsycl-fallback-bfloat16 SRC fallback-bfloat16.cpp DEP ${bfloat16_obj_deps}) add_devicelib(libsycl-native-bfloat16 SRC bfloat16_wrapper.cpp DEP ${bfloat16_obj_deps}) -add_devicelib(libsycl-fallback-gsort SRC fallback-gsort.cpp DEP ${gsort_obj_deps}) +add_devicelib(libsycl-fallback-gsort SRC fallback-gsort.cpp DEP ${gsort_obj_deps} EXTRA_ARGS -fno-sycl-instrument-device-code) file(MAKE_DIRECTORY ${obj_binary_dir}/libdevice) set(imf_fallback_src_dir ${obj_binary_dir}/libdevice) diff --git a/llvm/include/llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h 
b/llvm/include/llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h index c9b737e2d053a..d107661d02e5c 100644 --- a/llvm/include/llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h +++ b/llvm/include/llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h @@ -36,6 +36,7 @@ enum class DeviceLibExt : std::uint32_t { cl_intel_devicelib_imf_fp64, cl_intel_devicelib_imf_bf16, cl_intel_devicelib_bfloat16, + cl_intel_devicelib_gsort, }; uint32_t getSYCLDeviceLibReqMask(const Module &M); diff --git a/llvm/lib/SYCLLowerIR/SYCLDeviceLibReqMask.cpp b/llvm/lib/SYCLLowerIR/SYCLDeviceLibReqMask.cpp index 5f270baecec1d..791ed6f49bab6 100644 --- a/llvm/lib/SYCLLowerIR/SYCLDeviceLibReqMask.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLDeviceLibReqMask.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include "llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h" +#include "llvm/ADT/StringRef.h" #include "llvm/IR/Module.h" #include "llvm/TargetParser/Triple.h" @@ -694,6 +695,15 @@ SYCLDeviceLibFuncMap SDLMap = { DeviceLibExt::cl_intel_devicelib_bfloat16}, }; +// TODO: more robust check for all group sort fallback devicelib functions. +static bool checkGroupSortFallback(const StringRef &FuncName) { + if (FuncName.starts_with("__devicelib_default_work_group_") || + FuncName.starts_with("__devicelib_default_sub_group_")) + return true; + else + return false; +} + // Each fallback device library corresponds to one bit in "require mask" which // is an unsigned int32. getDeviceLibBit checks which fallback device library // is required for FuncName and returns the corresponding bit. 
The corresponding @@ -708,6 +718,7 @@ SYCLDeviceLibFuncMap SDLMap = { // cl_intel_devicelib_imf_fp64: 0x80 // cl_intel_devicelib_imf_bf16: 0x100 // cl_intel_devicelib_bfloat16: 0x200 +// cl_intel_devicelib_gsort: 0x400 uint32_t getDeviceLibBits(const std::string &FuncName) { auto DeviceLibFuncIter = SDLMap.find(FuncName); return ((DeviceLibFuncIter == SDLMap.end()) @@ -732,6 +743,10 @@ uint32_t llvm::getSYCLDeviceLibReqMask(const Module &M) { if (SF.getName().starts_with(DEVICELIB_FUNC_PREFIX) && SF.isDeclaration()) { assert(SF.getCallingConv() == CallingConv::SPIR_FUNC); uint32_t DeviceLibBits = getDeviceLibBits(SF.getName().str()); + if (!DeviceLibBits) { + if (checkGroupSortFallback(SF.getName())) + DeviceLibBits = 0x400; + } ReqMask |= DeviceLibBits; } } diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index 94e3e062a0d83..da1a56409775e 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -974,6 +974,8 @@ static const std::map> {nullptr, "libsycl-fallback-complex-fp64.spv"}}, {DeviceLibExt::cl_intel_devicelib_cstring, {nullptr, "libsycl-fallback-cstring.spv"}}, + {DeviceLibExt::cl_intel_devicelib_gsort, + {nullptr, "libsycl-fallback-gsort.spv"}}, {DeviceLibExt::cl_intel_devicelib_imf, {nullptr, "libsycl-fallback-imf.spv"}}, {DeviceLibExt::cl_intel_devicelib_imf_fp64, @@ -1006,6 +1008,7 @@ static const std::map DeviceLibExtensionStrs = { {DeviceLibExt::cl_intel_devicelib_complex_fp64, "cl_intel_devicelib_complex_fp64"}, {DeviceLibExt::cl_intel_devicelib_cstring, "cl_intel_devicelib_cstring"}, + {DeviceLibExt::cl_intel_devicelib_gsort, "cl_intel_devicelib_gsort"}, {DeviceLibExt::cl_intel_devicelib_imf, "cl_intel_devicelib_imf"}, {DeviceLibExt::cl_intel_devicelib_imf_fp64, "cl_intel_devicelib_imf_fp64"}, {DeviceLibExt::cl_intel_devicelib_imf_bf16, "cl_intel_devicelib_imf_bf16"}, @@ -1257,6 +1260,7 @@ 
getDeviceLibPrograms(const ContextImplPtr Context, {DeviceLibExt::cl_intel_devicelib_complex, false}, {DeviceLibExt::cl_intel_devicelib_complex_fp64, false}, {DeviceLibExt::cl_intel_devicelib_cstring, false}, + {DeviceLibExt::cl_intel_devicelib_gsort, false}, {DeviceLibExt::cl_intel_devicelib_imf, false}, {DeviceLibExt::cl_intel_devicelib_imf_fp64, false}, {DeviceLibExt::cl_intel_devicelib_imf_bf16, false}, diff --git a/sycl/source/detail/program_manager/program_manager.hpp b/sycl/source/detail/program_manager/program_manager.hpp index 3489dba53ffa1..c14f5db6d2850 100644 --- a/sycl/source/detail/program_manager/program_manager.hpp +++ b/sycl/source/detail/program_manager/program_manager.hpp @@ -80,6 +80,7 @@ enum class DeviceLibExt : std::uint32_t { cl_intel_devicelib_imf_fp64, cl_intel_devicelib_imf_bf16, cl_intel_devicelib_bfloat16, + cl_intel_devicelib_gsort, }; // Provides single loading and building OpenCL programs with unique contexts From 8a62ceb1a8b0c6e160ecc661637e83a439feeb9b Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 12 Jul 2024 15:44:16 +0800 Subject: [PATCH 20/71] Fix driver test regression Signed-off-by: jinge90 --- clang/test/Driver/sycl-device-lib.cpp | 1 - clang/test/Driver/sycl-offload-new-driver.c | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/clang/test/Driver/sycl-device-lib.cpp b/clang/test/Driver/sycl-device-lib.cpp index b667fa733696c..1f664b748c8cb 100644 --- a/clang/test/Driver/sycl-device-lib.cpp +++ b/clang/test/Driver/sycl-device-lib.cpp @@ -110,7 +110,6 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-cmath.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-cmath-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-imf.bc" -// SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-imf.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-imf-bf16.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-cassert.bc" 
diff --git a/clang/test/Driver/sycl-offload-new-driver.c b/clang/test/Driver/sycl-offload-new-driver.c index be71376bfacda..9643c0db79009 100644 --- a/clang/test/Driver/sycl-offload-new-driver.c +++ b/clang/test/Driver/sycl-offload-new-driver.c @@ -34,7 +34,7 @@ // RUN: %clangxx --target=x86_64-unknown-linux-gnu -fsycl --offload-new-driver \ // RUN: --sysroot=%S/Inputs/SYCL -### %s 2>&1 \ // RUN: | FileCheck -check-prefix WRAPPER_OPTIONS %s -// WRAPPER_OPTIONS: clang-linker-wrapper{{.*}} "-sycl-device-libraries=libsycl-crt.new.o,libsycl-complex.new.o,libsycl-complex-fp64.new.o,libsycl-cmath.new.o,libsycl-cmath-fp64.new.o,libsycl-imf.new.o,libsycl-imf-fp64.new.o,libsycl-imf-bf16.new.o,libsycl-fallback-cassert.new.o,libsycl-fallback-cstring.new.o,libsycl-fallback-complex.new.o,libsycl-fallback-complex-fp64.new.o,libsycl-fallback-cmath.new.o,libsycl-fallback-cmath-fp64.new.o,libsycl-fallback-imf.new.o,libsycl-fallback-imf-fp64.new.o,libsycl-fallback-imf-bf16.new.o,libsycl-itt-user-wrappers.new.o,libsycl-itt-compiler-wrappers.new.o,libsycl-itt-stubs.new.o" +// WRAPPER_OPTIONS: clang-linker-wrapper{{.*}} "-sycl-device-libraries=libsycl-crt.new.o,libsycl-complex.new.o,libsycl-complex-fp64.new.o,libsycl-cmath.new.o,libsycl-cmath-fp64.new.o,libsycl-imf.new.o,libsycl-imf-fp64.new.o,libsycl-imf-bf16.new.o,libsycl-fallback-cassert.new.o,libsycl-fallback-cstring.new.o,libsycl-fallback-complex.new.o,libsycl-fallback-complex-fp64.new.o,libsycl-fallback-cmath.new.o,libsycl-fallback-cmath-fp64.new.o,libsycl-fallback-gsort.new.o,libsycl-fallback-imf.new.o,libsycl-fallback-imf-fp64.new.o,libsycl-fallback-imf-bf16.new.o,libsycl-itt-user-wrappers.new.o,libsycl-itt-compiler-wrappers.new.o,libsycl-itt-stubs.new.o" // WRAPPER_OPTIONS-SAME: "-sycl-device-library-location={{.*}}/lib" /// Verify phases used to generate SPIR-V instead of LLVM-IR From 7c721351d540f63e5bad8dfaad175dfa2dcea69b Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 12 Jul 2024 17:10:20 +0800 Subject: [PATCH 21/71] 
Add first test case for fallback group sort Signed-off-by: jinge90 --- .../DeviceLib/group_sort/group_sort_i32.cpp | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp new file mode 100644 index 0000000000000..22ba388d23849 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp @@ -0,0 +1,91 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include +#include +#include +#include +#include +using namespace sycl; +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} +#endif + +constexpr static size_t NUM = 19; +int main() { + queue q; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, + 77, 293, 23, 36, 2, 111, 91, 88, -2, 525}; + int32_t b[NUM] = { + 0, + }; + int32_t c[NUM]; + memcpy(c, a, sizeof(a)); + std::sort(&c[0], &c[NUM]); + nd_range<1> num_items(range<1>(8), range<1>(8)); + { + buffer ibuf(a, NUM); + buffer obuf(b, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + accessor 
out_acc{obuf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + int32_t *optr = + out_acc.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + in_acc.template get_multi_ptr().get(), NUM, + by); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + in_acc.template get_multi_ptr().get(), NUM, by); +#endif + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (a[idx] != c[NUM - 1 - idx]) { +#else + if (a[idx] != c[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); + std::cout << "Pass!" << std::endl; + return 0; +} From f8190b6979933dfa67115ed50886ea064d64636a Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 15 Jul 2024 11:34:26 +0800 Subject: [PATCH 22/71] Add test for group sort i8,i16,i64 Signed-off-by: jinge90 --- .../DeviceLib/group_sort/group_sort_i16.cpp | 92 +++++++++++++++++++ .../DeviceLib/group_sort/group_sort_i64.cpp | 91 ++++++++++++++++++ .../DeviceLib/group_sort/group_sort_i8.cpp | 91 ++++++++++++++++++ 3 files changed, 274 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort_i16.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort_i64.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort_i8.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i16.cpp new file mode 100644 index 0000000000000..5ed7d55efe593 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i16.cpp @@ -0,0 +1,92 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include +#include +#include 
+#include +using namespace sycl; +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} +#endif + +constexpr static size_t NUM = 18; +int main() { + queue q; + int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, -111, 91, -88, -2}; + int16_t b[NUM] = { + 0, + }; + int16_t c[NUM]; + memcpy(c, a, sizeof(a)); + std::sort(&c[0], &c[NUM]); + + nd_range<1> num_items(range<1>(8), range<1>(8)); + { + buffer ibuf(a, NUM); + buffer obuf(b, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + accessor out_acc{obuf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + int16_t *optr = + out_acc.template get_multi_ptr().get(); + // int16_t *optr = out_acc.get_pointer(); + uint8_t *by = reinterpret_cast(optr); +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + in_acc.template get_multi_ptr().get(), NUM, + by); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + in_acc.template get_multi_ptr().get(), NUM, by); +#endif + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (a[idx] != c[NUM - 1 - idx]) { +#else + if (a[idx] != c[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); + std::cout << "Pass!" 
<< std::endl; + return 0; +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i64.cpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i64.cpp new file mode 100644 index 0000000000000..65fa4984bbbd1 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i64.cpp @@ -0,0 +1,91 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include +#include +#include +#include +using namespace sycl; +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} +#endif + +constexpr static size_t NUM = 18; +int main() { + queue q; + int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, -111, 91, -88, -2}; + int64_t b[NUM] = { + 0, + }; + int64_t c[NUM]; + memcpy(c, a, sizeof(a)); + std::sort(&c[0], &c[NUM]); + + nd_range<1> num_items(range<1>(8), range<1>(8)); + { + buffer ibuf(a, NUM); + buffer obuf(b, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + accessor out_acc{obuf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + int64_t *optr = + out_acc.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); +#ifdef DES + 
__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + in_acc.template get_multi_ptr().get(), NUM, + by); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + in_acc.template get_multi_ptr().get(), NUM, by); +#endif + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (a[idx] != c[NUM - 1 - idx]) { +#else + if (a[idx] != c[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); + std::cout << "Pass!" << std::endl; + return 0; +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i8.cpp new file mode 100644 index 0000000000000..a7ac5999637b6 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i8.cpp @@ -0,0 +1,91 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include +#include +#include +#include +using namespace sycl; +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} +#endif + +constexpr static size_t NUM = 18; +int main() { + queue q; + int8_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, -111, 91, -88, -2}; + int8_t 
b[NUM] = { + 0, + }; + int8_t c[NUM]; + memcpy(c, a, sizeof(a)); + std::sort(&c[0], &c[NUM]); + + nd_range<1> num_items(range<1>(8), range<1>(8)); + { + buffer ibuf(a, NUM); + buffer obuf(b, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + accessor out_acc{obuf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + int8_t *optr = + out_acc.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + in_acc.template get_multi_ptr().get(), NUM, + by); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + in_acc.template get_multi_ptr().get(), NUM, by); +#endif + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (a[idx] != c[NUM - 1 - idx]) { +#else + if (a[idx] != c[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); + std::cout << "Pass!" << std::endl; + return 0; +} From d242e77f4b0c866f09603753ee6449f00e8f0c89 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 15 Jul 2024 14:20:15 +0800 Subject: [PATCH 23/71] Fix clang-format Signed-off-by: jinge90 --- clang/lib/Driver/ToolChains/SYCL.cpp | 4 ++-- sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/clang/lib/Driver/ToolChains/SYCL.cpp b/clang/lib/Driver/ToolChains/SYCL.cpp index 7be8ec39533aa..82e4ef6770a7b 100644 --- a/clang/lib/Driver/ToolChains/SYCL.cpp +++ b/clang/lib/Driver/ToolChains/SYCL.cpp @@ -220,8 +220,8 @@ SYCL::getDeviceLibraries(const Compilation &C, const llvm::Triple &TargetTriple, // Currently, all SYCL device libraries will be linked by default. Linkage // of "internal" libraries cannot be affected via -fno-sycl-device-lib. 
llvm::StringMap DeviceLibLinkInfo = { - {"libc", true}, {"libm-fp32", true}, {"libm-fp64", true}, - {"libimf-fp32", true}, {"libimf-fp64", true}, {"libimf-bf16", true}, + {"libc", true}, {"libm-fp32", true}, {"libm-fp64", true}, + {"libimf-fp32", true}, {"libimf-fp64", true}, {"libimf-bf16", true}, {"libm-bfloat16", true}, {"libgsort-fp32", true}, {"internal", true}}; if (Arg *A = Args.getLastArg(options::OPT_fsycl_device_lib_EQ, options::OPT_fno_sycl_device_lib_EQ)) { diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp index 22ba388d23849..28c24598125d8 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp @@ -12,9 +12,9 @@ // // UNSUPPORTED: cuda || hip +#include #include #include -#include #include #include using namespace sycl; @@ -43,8 +43,8 @@ __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( constexpr static size_t NUM = 19; int main() { queue q; - int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, - 77, 293, 23, 36, 2, 111, 91, 88, -2, 525}; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, + 293, 23, 36, 2, 111, 91, 88, -2, 525}; int32_t b[NUM] = { 0, }; From bfb2d8b4a73f2c73fb5d068023702ac462ff0741 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 17 Jul 2024 16:14:03 +0800 Subject: [PATCH 24/71] Combine tests Signed-off-by: jinge90 --- .../group_sort/workgroup_joint_sort.cpp | 183 ++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp new file mode 100644 index 0000000000000..77e157617b4b2 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp @@ -0,0 +1,183 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} 
-fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include +#include +#include +#include +#include +using namespace sycl; +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t
*scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif + +template +void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { + Ty scratch[NUM] = { + 0, + }; + Ty result[NUM]; + memcpy(result, input, sizeof(Ty) * NUM); + std::sort(&result[0], &result[NUM]); + const static size_t wg_size = WG_SZ; + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ibuf(input, NUM); + buffer obuf(scratch, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + accessor out_acc{obuf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + Ty *optr = + out_acc.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); + gsh(in_acc.template get_multi_ptr().get(), NUM, + by); + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (input[idx] != result[NUM - 1 - idx]) { +#else + if (input[idx] != result[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 19; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, + 293, 23, 36, 2, 111, 91, 88, -2, 525}; + auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 21; + int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, 293, + 23, 36, 2, 111, 91, 88, -2, 525, -12, 525}; + auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + 
__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i16_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, + 1000, 77, 293, 23, 36, 2, 111, 91, + 88, -2, 525, -12, 525, -99999999, 19928348493}; + auto work_group_sorter = [](int64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i64_u32_p1i8 passes" << std::endl; + } + + return 0; +} From ee09a2a81e984339b46383570f4b9b633f74dbe5 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 17 Jul 2024 16:15:05 +0800 Subject: [PATCH 25/71] Fix joint sort p1i64_u32_p1i8 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 0cafd186cf5f7..fb21ab62a4af8 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -157,6 +157,7 @@ void __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( merge_sort(first, n, scratch, std::greater{}); } +DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( int64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); From 152e09d3864b5b82a7e61f11d7e0de947c1bc363 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 17 Jul 2024 17:05:36 +0800 Subject: [PATCH 26/71] add fp32 test Signed-off-by: jinge90 --- .../DeviceLib/group_sort/group_sort_decls.hpp | 104
++++++++++++++++++ .../DeviceLib/group_sort/group_sort_i16.cpp | 92 ---------------- .../DeviceLib/group_sort/group_sort_i32.cpp | 91 --------------- .../DeviceLib/group_sort/group_sort_i64.cpp | 91 --------------- .../DeviceLib/group_sort/group_sort_i8.cpp | 91 --------------- .../group_sort/workgroup_joint_sort.cpp | 102 +++++++---------- 6 files changed, 144 insertions(+), 427 deletions(-) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp delete mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort_i16.cpp delete mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp delete mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort_i64.cpp delete mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort_i8.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp new file mode 100644 index 0000000000000..8200beaf007b7 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp @@ -0,0 +1,104 @@ +#pragma once +#include + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t 
*scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void 
+__devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i16.cpp deleted file mode 100644 index 5ed7d55efe593..0000000000000 --- a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i16.cpp +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: %{build} -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -DDES -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out -// RUN: %{run} %t.out -// -// UNSUPPORTED: cuda || hip - -#include -#include -#include -#include -using namespace sycl; -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch); - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch); -#else -extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch) { - return; -} -#endif - -constexpr static size_t NUM = 18; -int main() { - queue q; - int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 121, - 77, 125, 23, 36, 2, -111, 91, -88, -2}; - int16_t b[NUM] = { - 0, - }; - int16_t c[NUM]; - memcpy(c, a, sizeof(a)); - std::sort(&c[0], &c[NUM]); - - nd_range<1> num_items(range<1>(8), range<1>(8)); - { - buffer ibuf(a, NUM); - buffer obuf(b, NUM); - q.submit([&](auto &h) { - accessor in_acc{ibuf, h}; - accessor out_acc{obuf, h}; - h.parallel_for(num_items, [=](nd_item<1> i) { - int16_t *optr = - out_acc.template 
get_multi_ptr().get(); - // int16_t *optr = out_acc.get_pointer(); - uint8_t *by = reinterpret_cast(optr); -#ifdef DES - __devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( - in_acc.template get_multi_ptr().get(), NUM, - by); -#else - __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( - in_acc.template get_multi_ptr().get(), NUM, by); -#endif - }); - }).wait(); - } - - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { -#ifdef DES - if (a[idx] != c[NUM - 1 - idx]) { -#else - if (a[idx] != c[idx]) { -#endif - fails = true; - break; - } - } - assert(!fails); - std::cout << "Pass!" << std::endl; - return 0; -} diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp deleted file mode 100644 index 28c24598125d8..0000000000000 --- a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i32.cpp +++ /dev/null @@ -1,91 +0,0 @@ -// RUN: %{build} -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -DDES -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out -// RUN: %{run} %t.out -// -// UNSUPPORTED: cuda || hip - -#include -#include -#include -#include -#include -using namespace sycl; -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch); - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch); -#else -extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { - return; -} -#endif - -constexpr static 
size_t NUM = 19; -int main() { - queue q; - int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, - 293, 23, 36, 2, 111, 91, 88, -2, 525}; - int32_t b[NUM] = { - 0, - }; - int32_t c[NUM]; - memcpy(c, a, sizeof(a)); - std::sort(&c[0], &c[NUM]); - nd_range<1> num_items(range<1>(8), range<1>(8)); - { - buffer ibuf(a, NUM); - buffer obuf(b, NUM); - q.submit([&](auto &h) { - accessor in_acc{ibuf, h}; - accessor out_acc{obuf, h}; - h.parallel_for(num_items, [=](nd_item<1> i) { - int32_t *optr = - out_acc.template get_multi_ptr().get(); - uint8_t *by = reinterpret_cast(optr); -#ifdef DES - __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( - in_acc.template get_multi_ptr().get(), NUM, - by); -#else - __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( - in_acc.template get_multi_ptr().get(), NUM, by); -#endif - }); - }).wait(); - } - - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { -#ifdef DES - if (a[idx] != c[NUM - 1 - idx]) { -#else - if (a[idx] != c[idx]) { -#endif - fails = true; - break; - } - } - assert(!fails); - std::cout << "Pass!" 
<< std::endl; - return 0; -} diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i64.cpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i64.cpp deleted file mode 100644 index 65fa4984bbbd1..0000000000000 --- a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i64.cpp +++ /dev/null @@ -1,91 +0,0 @@ -// RUN: %{build} -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -DDES -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out -// RUN: %{run} %t.out -// -// UNSUPPORTED: cuda || hip - -#include -#include -#include -#include -using namespace sycl; -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch); - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch); -#else -extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) { - return; -} -#endif - -constexpr static size_t NUM = 18; -int main() { - queue q; - int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 121, - 77, 125, 23, 36, 2, -111, 91, -88, -2}; - int64_t b[NUM] = { - 0, - }; - int64_t c[NUM]; - memcpy(c, a, sizeof(a)); - std::sort(&c[0], &c[NUM]); - - nd_range<1> num_items(range<1>(8), range<1>(8)); - { - buffer ibuf(a, NUM); - buffer obuf(b, NUM); - q.submit([&](auto &h) { - accessor in_acc{ibuf, h}; - accessor out_acc{obuf, h}; - h.parallel_for(num_items, [=](nd_item<1> i) { - int64_t *optr = - out_acc.template get_multi_ptr().get(); - uint8_t *by = reinterpret_cast(optr); -#ifdef DES - 
__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( - in_acc.template get_multi_ptr().get(), NUM, - by); -#else - __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( - in_acc.template get_multi_ptr().get(), NUM, by); -#endif - }); - }).wait(); - } - - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { -#ifdef DES - if (a[idx] != c[NUM - 1 - idx]) { -#else - if (a[idx] != c[idx]) { -#endif - fails = true; - break; - } - } - assert(!fails); - std::cout << "Pass!" << std::endl; - return 0; -} diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_i8.cpp deleted file mode 100644 index a7ac5999637b6..0000000000000 --- a/sycl/test-e2e/DeviceLib/group_sort/group_sort_i8.cpp +++ /dev/null @@ -1,91 +0,0 @@ -// RUN: %{build} -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -DDES -o %t.out -// RUN: %{run} %t.out - -// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out -// RUN: %{run} %t.out -// -// UNSUPPORTED: cuda || hip - -#include -#include -#include -#include -using namespace sycl; -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch); - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch); -#else -extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { - return; -} -#endif - -constexpr static size_t NUM = 18; -int main() { - queue q; - int8_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 121, - 77, 125, 23, 36, 2, -111, 91, -88, -2}; - 
int8_t b[NUM] = { - 0, - }; - int8_t c[NUM]; - memcpy(c, a, sizeof(a)); - std::sort(&c[0], &c[NUM]); - - nd_range<1> num_items(range<1>(8), range<1>(8)); - { - buffer ibuf(a, NUM); - buffer obuf(b, NUM); - q.submit([&](auto &h) { - accessor in_acc{ibuf, h}; - accessor out_acc{obuf, h}; - h.parallel_for(num_items, [=](nd_item<1> i) { - int8_t *optr = - out_acc.template get_multi_ptr().get(); - uint8_t *by = reinterpret_cast(optr); -#ifdef DES - __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( - in_acc.template get_multi_ptr().get(), NUM, - by); -#else - __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( - in_acc.template get_multi_ptr().get(), NUM, by); -#endif - }); - }).wait(); - } - - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { -#ifdef DES - if (a[idx] != c[NUM - 1 - idx]) { -#else - if (a[idx] != c[idx]) { -#endif - fails = true; - break; - } - } - assert(!fails); - std::cout << "Pass!" << std::endl; - return 0; -} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp index 77e157617b4b2..7a738b465f780 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp @@ -12,74 +12,13 @@ // // UNSUPPORTED: cuda || hip +#include "group_sort_decls.hpp" #include #include #include #include #include using namespace sycl; -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch); - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch); - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch); - -SYCL_EXTERNAL extern "C" void 
-__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch); - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch); - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch); -#else -extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { - return; -} - -#endif template void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { @@ -124,6 +63,24 @@ void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { int main() { queue q; + { + constexpr static int NUM = 19; + int8_t a[NUM] = {-1, 11, 10, 9, 3, 100, 34, 8, 10, 77, + -93, 23, 36, 2, 111, 91, 88, -2, -25}; + auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p1i8( + first, n, scratch); +#endif + }; 
+ test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i8_u32_p1i8 passes" << std::endl; + } + { constexpr static int NUM = 19; int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, @@ -179,5 +136,26 @@ int main() { std::cout << "work group joint sort p1i64_u32_p1i8 passes" << std::endl; } + { + constexpr static int NUM = 23; + float a[NUM] = {-1.25f, 11.4643f, 1.45f, -9.98f, 13.665f, + 100.0f, 34.625f, 8.125f, 1000.12f, 77.91f, + 293.33f, 23.4f, -36.6f, 2.5f, 111.11f, + 91.889f, 88.88f, -2.98f, 525.25f, -12.11f, + 525.0f, -9999999.9f, 19928348493.123f}; + auto work_group_sorter = [](float *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1f32_u32_p1i8 passes" << std::endl; + } + return 0; } From b400b033eb8fc2e2f9f441b2ebd3d4fff22e5ba9 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 22 Jul 2024 11:31:27 +0800 Subject: [PATCH 27/71] Fix joint_sort_p1u64_u32_p1u8 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 1 + .../DeviceLib/group_sort/group_sort_decls.hpp | 82 ++++++++++++++++ .../group_sort/workgroup_joint_sort.cpp | 95 +++++++++++++++++++ 3 files changed, 178 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index fb21ab62a4af8..d36667fd17c00 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -350,6 +350,7 @@ void __devicelib_default_work_group_joint_sort_descending_p3u32_u32_p3i8( merge_sort(first, n, scratch, std::greater{}); } +DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p1i8( uint64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); diff --git 
a/sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp index 8200beaf007b7..d54bdae571834 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp @@ -41,6 +41,39 @@ __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + #else extern "C" void __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( @@ -101,4 +134,53 @@ __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( float *first, uint32_t n, uint8_t *scratch) { return; } + +extern "C" void 
+__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + #endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp index 7a738b465f780..3f9403964ec8e 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp @@ -157,5 +157,100 @@ int main() { std::cout << "work group joint sort p1f32_u32_p1i8 passes" << std::endl; } + { + constexpr static int NUM = 23; + uint8_t a[NUM] = {234, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 201, 211, 77, 8, 88, 19, 0}; + auto work_group_sorter = [](uint8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_u32_p1i8( + first, n, scratch); +#else + 
__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u8_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint16_t a[NUM] = {11234, 11, 1, 119, 3, 100, 341, 8, + 121, 77, 125, 23, 3226, 2, 111, 911, + 201, 211, 77, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u16_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u16_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint32_t a[NUM] = {11234, 11, 1, 1193332332, 231, 100, 341, 8, + 121, 77, 125, 32, 3226, 2, 111, 911, + 9912201, 211, 711117, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint64_t a[NUM] = {0x112A111111FFEEFF, + 0xAACC11, + 0x1, + 0x1193332332, + 0x231, + 0xAA, + 0xFCCCA341, + 0x8, + 0x121, + 0x987777777, + 0x81, + 0x20, + 0x3226, + 0x2, + 0x8FFFFFFFFF111, + 0x911, + 0xAAAA9912201, + 0x211, + 0x711117, + 0x8, + 0xABABABABCC, + 0x13, + 0}; + auto work_group_sorter = [](uint64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u64_u32_p1i8( + first, n, scratch); +#else + 
__devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u64_u32_p1i8 passes" << std::endl; + } + return 0; } From c603cb1abaf282535b6d7521baed992b7f1f2fa3 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 22 Jul 2024 16:24:38 +0800 Subject: [PATCH 28/71] add tests Signed-off-by: jinge90 --- ...rt_decls.hpp => group_joint_sort_p1p1.hpp} | 0 .../group_sort/group_joint_sort_p1p3.hpp | 186 +++++++++++ .../group_sort/group_joint_sort_p3p1.hpp | 186 +++++++++++ .../group_sort/group_joint_sort_p3p3.hpp | 186 +++++++++++ ...sort.cpp => workgroup_joint_sort_p1p1.cpp} | 2 +- .../group_sort/workgroup_joint_sort_p1p3.cpp | 253 +++++++++++++++ .../group_sort/workgroup_joint_sort_p3p1.cpp | 293 ++++++++++++++++++ .../group_sort/workgroup_joint_sort_p3p3.cpp | 289 +++++++++++++++++ 8 files changed, 1394 insertions(+), 1 deletion(-) rename sycl/test-e2e/DeviceLib/group_sort/{group_sort_decls.hpp => group_joint_sort_p1p1.hpp} (100%) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp rename sycl/test-e2e/DeviceLib/group_sort/{workgroup_joint_sort.cpp => workgroup_joint_sort_p1p1.cpp} (99%) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp similarity index 100% rename from sycl/test-e2e/DeviceLib/group_sort/group_sort_decls.hpp rename to sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp diff --git 
a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp new file mode 100644 index 0000000000000..1177c82254c01 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp @@ -0,0 +1,186 @@ +#pragma once +#include + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + 
+SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void 
+__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp new file mode 100644 index 0000000000000..361e5393ce3fe 
--- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp @@ -0,0 +1,186 @@ +#pragma once +#include + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void 
+__devicelib_default_work_group_joint_sort_descending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp new file mode 100644 index 0000000000000..4b8ea3b737674 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp @@ -0,0 +1,186 @@ +#pragma once +#include + +#ifdef 
__SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL 
extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void 
+__devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp similarity index 99% rename from sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp rename to sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp index 3f9403964ec8e..1a9f42af789ea 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp @@ 
-12,7 +12,7 @@ // // UNSUPPORTED: cuda || hip -#include "group_sort_decls.hpp" +#include "group_joint_sort_p1p1.hpp" #include #include #include diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp new file mode 100644 index 0000000000000..8ac8c0326ab41 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp @@ -0,0 +1,253 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include "group_joint_sort_p1p3.hpp" +#include +#include +#include +#include +#include +using namespace sycl; + +// For __devicelib_default_work_group_xxx_p1*_u32_p3i8, the scratch memory is +// in shared local memory. +template +void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { + Ty result[NUM]; + memcpy(result, input, sizeof(Ty) * NUM); + std::sort(&result[0], &result[NUM]); + const static size_t wg_size = WG_SZ; + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ibuf(input, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + local_accessor slm(NUM, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + Ty *optr = slm.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); + gsh(in_acc.template get_multi_ptr().get(), NUM, + by); + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (input[idx] != result[NUM - 1 - idx]) { +#else + if (input[idx] != result[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 19; + int8_t a[NUM] = {-1, 11, 10, 9, 3, 100, 34, 8, 10, 77, + -93, 23, 36, 2, 111, 91, 88, -2, -25};
+ auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i8_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 19; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, + 293, 23, 36, 2, 111, 91, 88, -2, 525}; + auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 21; + int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, 293, + 23, 36, 2, 111, 91, 88, -2, 525, -12, 525}; + auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i16_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, + 1000, 77, 293, 23, 36, 2, 111, 91, + 88, -2, 525, -12, 525, -99999999, 19928348493}; + auto work_group_sorter = [](int64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + first, n, scratch); +#else + 
__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i64_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + float a[NUM] = {-1.25f, 11.4643f, 1.45f, -9.98f, 13.665f, + 100.0f, 34.625f, 8.125f, 1000.12f, 77.91f, + 293.33f, 23.4f, -36.6f, 2.5f, 111.11f, + 91.889f, 88.88f, -2.98f, 525.25f, -12.11f, + 525.0f, -9999999.9f, 19928348493.123f}; + auto work_group_sorter = [](float *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1f32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint8_t a[NUM] = {234, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 201, 211, 77, 8, 88, 19, 0}; + auto work_group_sorter = [](uint8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u8_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint16_t a[NUM] = {11234, 11, 1, 119, 3, 100, 341, 8, + 121, 77, 125, 23, 3226, 2, 111, 911, + 201, 211, 77, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u16_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p3i8( + first, n, scratch); +#endif + }; + 
test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u16_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint32_t a[NUM] = {11234, 11, 1, 1193332332, 231, 100, 341, 8, + 121, 77, 125, 32, 3226, 2, 111, 911, + 9912201, 211, 711117, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint64_t a[NUM] = {0x112A111111FFEEFF, + 0xAACC11, + 0x1, + 0x1193332332, + 0x231, + 0xAA, + 0xFCCCA341, + 0x8, + 0x121, + 0x987777777, + 0x81, + 0x20, + 0x3226, + 0x2, + 0x8FFFFFFFFF111, + 0x911, + 0xAAAA9912201, + 0x211, + 0x711117, + 0x8, + 0xABABABABCC, + 0x13, + 0}; + auto work_group_sorter = [](uint64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u64_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u64_u32_p3i8 passes" << std::endl; + } + + return 0; +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp new file mode 100644 index 0000000000000..3ca5eaed71017 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp @@ -0,0 +1,293 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: 
%{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include "group_joint_sort_p3p1.hpp" +#include +#include +#include +#include +#include +using namespace sycl; + +template +void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { + Ty scratch[NUM] = { + 0, + }; + Ty result[NUM]; + memcpy(result, input, sizeof(Ty) * NUM); + std::sort(&result[0], &result[NUM]); + const static size_t wg_size = WG_SZ; + constexpr size_t copy_sz_per_item = NUM / WG_SZ; + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ibuf(input, NUM); + buffer obuf(scratch, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + local_accessor slm(NUM, h); + accessor out_acc{obuf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + // Copy input to local shared memory for sort. + if constexpr (copy_sz_per_item == 0) { + if (i.get_global_id() < NUM) + slm[i.get_global_id()] = in_acc[i.get_global_id()]; + } else { + for (size_t idx = 0; idx < copy_sz_per_item; ++idx) { + slm[copy_sz_per_item * i.get_global_id() + idx] = + in_acc[copy_sz_per_item * i.get_global_id() + idx]; + } + + if (i.get_global_id() < (NUM - copy_sz_per_item * WG_SZ)) { + slm[copy_sz_per_item * WG_SZ + i.get_global_id()] = + in_acc[copy_sz_per_item * WG_SZ + i.get_global_id()]; + } + } + + group_barrier(i.get_group()); + Ty *optr = + out_acc.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); + gsh(slm.template get_multi_ptr().get(), NUM, + by); + + // Copy sorted data in shared memory to in_acc + if constexpr (copy_sz_per_item == 0) { + if (i.get_global_id() < NUM) + in_acc[i.get_global_id()] = slm[i.get_global_id()]; + } else { + for (size_t idx = 0; idx < copy_sz_per_item; ++idx) { + in_acc[copy_sz_per_item * i.get_global_id() + idx] = + slm[copy_sz_per_item * i.get_global_id() + idx]; + } + + if (i.get_global_id() < (NUM - copy_sz_per_item * WG_SZ)) { +
in_acc[copy_sz_per_item * WG_SZ + i.get_global_id()] = + slm[copy_sz_per_item * WG_SZ + i.get_global_id()]; + } + } + + group_barrier(i.get_group()); + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (input[idx] != result[NUM - 1 - idx]) { +#else + if (input[idx] != result[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 19; + int8_t a[NUM] = {-1, 11, 10, 9, 3, 100, 34, 8, 10, 77, + -93, 23, 36, 2, 111, 91, 88, -2, -25}; + auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i8_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i8_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 19; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, + 293, 23, 36, 2, 111, 91, 88, -2, 525}; + auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 21; + int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, 293, + 23, 36, 2, 111, 91, 88, -2, 525, -12, 525}; + auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i16_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p1i8( + first, n, scratch); 
+#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i16_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, + 1000, 77, 293, 23, 36, 2, 111, 91, + 88, -2, 525, -12, 525, -99999999, 19928348493}; + auto work_group_sorter = [](int64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i64_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i64_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + float a[NUM] = {-1.25f, 11.4643f, 1.45f, -9.98f, 13.665f, + 100.0f, 34.625f, 8.125f, 1000.12f, 77.91f, + 293.33f, 23.4f, -36.6f, 2.5f, 111.11f, + 91.889f, 88.88f, -2.98f, 525.25f, -12.11f, + 525.0f, -9999999.9f, 19928348493.123f}; + auto work_group_sorter = [](float *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3f32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint8_t a[NUM] = {234, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 201, 211, 77, 8, 88, 19, 0}; + auto work_group_sorter = [](uint8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u8_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint 
sort p3u8_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint16_t a[NUM] = {11234, 11, 1, 119, 3, 100, 341, 8, + 121, 77, 125, 23, 3226, 2, 111, 911, + 201, 211, 77, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u16_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u16_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint32_t a[NUM] = {11234, 11, 1, 1193332332, 231, 100, 341, 8, + 121, 77, 125, 32, 3226, 2, 111, 911, + 9912201, 211, 711117, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint64_t a[NUM] = {0x112A111111FFEEFF, + 0xAACC11, + 0x1, + 0x1193332332, + 0x231, + 0xAA, + 0xFCCCA341, + 0x8, + 0x121, + 0x987777777, + 0x81, + 0x20, + 0x3226, + 0x2, + 0x8FFFFFFFFF111, + 0x911, + 0xAAAA9912201, + 0x211, + 0x711117, + 0x8, + 0xABABABABCC, + 0x13, + 0}; + auto work_group_sorter = [](uint64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u64_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u64_u32_p1i8 passes" << 
std::endl; + } + + return 0; +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp new file mode 100644 index 0000000000000..a280615375fee --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp @@ -0,0 +1,289 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include "group_joint_sort_p3p3.hpp" +#include +#include +#include +#include +#include +using namespace sycl; + +template +void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { + Ty result[NUM]; + memcpy(result, input, sizeof(Ty) * NUM); + std::sort(&result[0], &result[NUM]); + const static size_t wg_size = WG_SZ; + constexpr size_t copy_sz_per_item = NUM / WG_SZ; + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ibuf(input, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + local_accessor slm(NUM, h); + local_accessor scratch(NUM, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + // Copy input to local shared memory for sort.
+ if constexpr (copy_sz_per_item == 0) { + if (i.get_global_id() < NUM) + slm[i.get_global_id()] = in_acc[i.get_global_id()]; + } else { + for (size_t idx = 0; idx < copy_sz_per_item; ++idx) { + slm[copy_sz_per_item * i.get_global_id() + idx] = + in_acc[copy_sz_per_item * i.get_global_id() + idx]; + } + + if (i.get_global_id() < (NUM - copy_sz_per_item * WG_SZ)) { + slm[copy_sz_per_item * WG_SZ + i.get_global_id()] = + in_acc[copy_sz_per_item * WG_SZ + i.get_global_id()]; + } + } + + group_barrier(i.get_group()); + Ty *optr = + scratch.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); + gsh(slm.template get_multi_ptr().get(), NUM, + by); + + // Copy sorted data in shared memory to in_acc + if constexpr (copy_sz_per_item == 0) { + if (i.get_global_id() < NUM) + in_acc[i.get_global_id()] = slm[i.get_global_id()]; + } else { + for (size_t idx = 0; idx < copy_sz_per_item; ++idx) { + in_acc[copy_sz_per_item * i.get_global_id() + idx] = + slm[copy_sz_per_item * i.get_global_id() + idx]; + } + + if (i.get_global_id() < (NUM - copy_sz_per_item * WG_SZ)) { + in_acc[copy_sz_per_item * WG_SZ + i.get_global_id()] = + slm[copy_sz_per_item * WG_SZ + i.get_global_id()]; + } + } + + group_barrier(i.get_group()); + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (input[idx] != result[NUM - 1 - idx]) { +#else + if (input[idx] != result[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 19; + int8_t a[NUM] = {-1, 11, 10, 9, 3, 100, 34, 8, 10, 77, + -93, 23, 36, 2, 111, 91, 88, -2, -25}; + auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i8_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, 
work_group_sorter); + std::cout << "work group joint sort p3i8_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 19; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, + 293, 23, 36, 2, 111, 91, 88, -2, 525}; + auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 21; + int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, 293, + 23, 36, 2, 111, 91, 88, -2, 525, -12, 525}; + auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i16_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i16_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, + 1000, 77, 293, 23, 36, 2, 111, 91, + 88, -2, 525, -12, 525, -99999999, 19928348493}; + auto work_group_sorter = [](int64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i64_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i64_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + float a[NUM] = {-1.25f, 11.4643f, 1.45f, -9.98f, 13.665f, + 100.0f, 34.625f, 8.125f, 1000.12f, 
77.91f, + 293.33f, 23.4f, -36.6f, 2.5f, 111.11f, + 91.889f, 88.88f, -2.98f, 525.25f, -12.11f, + 525.0f, -9999999.9f, 19928348493.123f}; + auto work_group_sorter = [](float *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3f32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint8_t a[NUM] = {234, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 201, 211, 77, 8, 88, 19, 0}; + auto work_group_sorter = [](uint8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u8_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u8_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint16_t a[NUM] = {11234, 11, 1, 119, 3, 100, 341, 8, + 121, 77, 125, 23, 3226, 2, 111, 911, + 201, 211, 77, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u16_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u16_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint32_t a[NUM] = {11234, 11, 1, 1193332332, 231, 100, 341, 8, + 121, 77, 125, 32, 3226, 2, 111, 911, + 9912201, 211, 711117, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint32_t *first, uint32_t n, 
uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint64_t a[NUM] = {0x112A111111FFEEFF, + 0xAACC11, + 0x1, + 0x1193332332, + 0x231, + 0xAA, + 0xFCCCA341, + 0x8, + 0x121, + 0x987777777, + 0x81, + 0x20, + 0x3226, + 0x2, + 0x8FFFFFFFFF111, + 0x911, + 0xAAAA9912201, + 0x211, + 0x711117, + 0x8, + 0xABABABABCC, + 0x13, + 0}; + auto work_group_sorter = [](uint64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u64_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u64_u32_p3i8 passes" << std::endl; + } + + return 0; +} From 82fa9ce06d695a6fc2c4b3991183844d685ab5ec Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 23 Jul 2024 17:13:56 +0800 Subject: [PATCH 29/71] Add tests for private sort Signed-off-by: jinge90 --- .../group_sort/group_private_sort_p1p1.hpp | 124 +++++++++++ .../group_sort/workgroup_joint_sort_p1p1.cpp | 2 +- .../group_sort/workgroup_joint_sort_p1p3.cpp | 2 +- .../group_sort/workgroup_joint_sort_p3p1.cpp | 2 +- .../group_sort/workgroup_joint_sort_p3p3.cpp | 2 +- .../workgroup_private_sort_p1p1.cpp | 195 ++++++++++++++++++ 6 files changed, 323 insertions(+), 4 deletions(-) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp 
b/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp new file mode 100644 index 0000000000000..35b27a9e1d9eb --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp @@ -0,0 +1,124 @@ +#pragma once +#include +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p1i8( + int32_t *first, 
uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp index 1a9f42af789ea..ff2538d1bfb57 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp @@ -7,7 +7,7 @@ // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out -// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out // // UNSUPPORTED: cuda || hip diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp index 8ac8c0326ab41..603ce777e470f 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp @@ -7,7 +7,7 @@ // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out -// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out // // UNSUPPORTED: cuda || hip diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp index 3ca5eaed71017..5204aae9c09fb 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp @@ -7,7 +7,7 @@ // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out -// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out // // UNSUPPORTED: cuda || hip diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp 
b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp index a280615375fee..e83fa5d0d4361 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp @@ -7,7 +7,7 @@ // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out -// RUN: %{build} -DES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out // // UNSUPPORTED: cuda || hip diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp new file mode 100644 index 0000000000000..5fd2444d0c9ae --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp @@ -0,0 +1,195 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DCLOSE -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DCLOSE -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DCLOSE -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DCLOSE -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include "group_private_sort_p1p1.hpp" +#include +#include +#include +#include +#include +using namespace sycl; +template +void test_work_group_private_sort(sycl::queue &q, Ty input[NUM], + SortHelper gsh) { + static_assert(NUM % WG_SZ == 0, + "Input size must be divisible by Work group size!"); + // Scratch memory size >= NUM * sizeof(Ty) * 2 + Ty scratch[NUM * 2] = { + 0, + }; + Ty result[NUM]; + Ty reference[NUM]; + memcpy(reference, input, sizeof(Ty) * NUM); +#ifdef DES + std::sort(&reference[0], &reference[NUM], std::greater()); +#else + std::sort(&reference[0], &reference[NUM]); +#endif + const static size_t wg_size = WG_SZ; + constexpr size_t 
num_per_work_item = NUM / WG_SZ; + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + + { + buffer ibuf(input, NUM); + buffer obuf(scratch, NUM * 2); + buffer rbuf(result, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + accessor out_acc{obuf, h}; + accessor re_acc{rbuf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + Ty private_buf[num_per_work_item]; + for (size_t idx = 0; idx < num_per_work_item; ++idx) + private_buf[idx] = + in_acc[i.get_local_linear_id() * num_per_work_item + idx]; + group_barrier(i.get_group()); + Ty *optr = + out_acc.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); + gsh(private_buf, num_per_work_item, by); + for (size_t idx = 0; idx < num_per_work_item; ++idx) + re_acc[i.get_local_linear_id() * num_per_work_item + idx] = + private_buf[idx]; + }); + }).wait(); + } + + bool fails = false; + +#ifdef CLOSE + for (size_t idx = 0; idx < NUM; ++idx) { + if (result[idx] != reference[idx]) { + fails = true; + break; + } + } +#else + for (size_t idx = 0; idx < NUM; ++idx) { + size_t idx1 = idx % WG_SZ; + size_t idx2 = idx / WG_SZ; + if (reference[idx] != result[idx1 * num_per_work_item + idx2]) { + fails = true; + break; + } + } +#endif + + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 24; + int8_t a[NUM] = {-1, 11, 10, 9, 3, 100, 34, 8, 10, 77, 10, 103, + -12, -93, 23, 36, 2, 111, 91, 88, -2, -25, 98, -111}; + auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef CLOSE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p1i8( + first, n, scratch); +#endif +#else +#ifdef DES + __devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p1i8( + first, n, scratch); +#endif +#endif 
+ }; + test_work_group_private_sort( + q, a, work_group_sorter); + std::cout << "work group private sort p1i8_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 32; + int16_t a[NUM] = {2162, 29891, 14709, -20987, -30051, -26861, 5629, + -11244, 25702, 29438, 22560, -15282, 27812, 28455, + 26871, -22327, 6495, 23519, 19389, 26328, 13253, + -24369, -1616, 3278, 5624, -6317, -3669, 11874, + -46, -4717, -27449, -9790}; + auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef CLOSE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p1i8( + first, n, scratch); +#endif +#else +#ifdef DES + __devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p1i8( + first, n, scratch); +#endif +#endif + }; + test_work_group_private_sort( + q, a, work_group_sorter); + std::cout << "work group private sort p1i16_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 32; + int32_t a[NUM] = {1319329913, -390041276, -2040725419, -217333100, + -900793956, -2138508211, 769705434, 122767310, + -1918605668, -16813517, 1616926513, -2141526068, + 631985359, 541606467, 662050754, 140359040, + 1834119354, 1910851165, 809736505, 451506849, + -1713169862, -1916401837, 1490159094, -2066441094, + -332318833, -1550930943, 1763101596, 500568854, + -1574546569, -596440302, 1522396193, -980468122}; + auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef CLOSE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p1i8( + first, n, scratch); +#endif +#else +#ifdef DES + 
__devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p1i8( + first, n, scratch); +#endif +#endif + }; + test_work_group_private_sort( + q, a, work_group_sorter); + std::cout << "work group private sort p1i32_u32_p1i8 passes" << std::endl; + } +} From 5fa8ec38248094b5ebb818bb28ae462e119254b1 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 29 Jul 2024 15:05:42 +0800 Subject: [PATCH 30/71] draft to add key value Signed-off-by: jinge90 --- libdevice/sort_helper.hpp | 88 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 7e48caf5c6a25..ec6be3133e85d 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -157,4 +157,92 @@ Tp sub_group_merge_sort(Tp value, uint8_t *scratch, Compare comp) { return temp_buffer[idx]; } +static void __get_chunk_size(size_t group_id, size_t group_size, size_t n, + size_t *beg, size_t *end) { + size_t tmp = n % group_size; + size_t chunk_size = n / group_size; + if (tmp) { + if (group_id < tmp) { + *beg = group_id * (chunk_size + 1); + *end = *beg + chunk_size + 1; + } else { + *beg = tmp * (chunk_size + 1) + (group_id - tmp) * chunk_size; + *end = *beg + chunk_size; + } + } else { + *beg = group_id * chunk_size; + *end = *beg + chunk_size; + } +} + +template +void merge_key_value(KeyT *keys_in, KeyT *keys_out, ValT *vals_in, + ValT *vals_out, size_t widx, size_t iter_num, + size_t chunks_to_merge, Compare comp) { + if (2 * widx >= chunks_to_merge) + return; + + // +} + +template +void bubble_sort_key_value(KeyT *keys, ValT *vals, const size_t beg, + const size_t end, Compare comp) { + if (beg < end) { + KeyT temp_key; + ValT temp_val; + for (size_t i = beg; i < end; ++i) + for (size_t j = i + 1; j < end; ++j) { + if (!comp(keys[i], keys[j])) { + temp_key = keys[i]; + keys[i] = keys[j]; + keys[j] = 
temp_key; + temp_val = vals[i]; + vals[i] = vals[j]; + vals[j] = temp_val; + } + } + } +} + +// We have following assumption for scratch memory size for key-value +// group sort: size of scratch > (sizeof(KeyT) + sizeof(ValT)) + +// max(alignof(KeyT), alignof(ValT)). +template +void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, + Compare comp) { + const size_t idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + const size_t bubble_beg, bubble_end; + __get_chunk_size(idx, wg_size, n, &bubble_beg, &bubble_end); + bubble_sort(keys, vals, bubble_beg, bubble_end, comp); + group_barrier(); + bool data_in_scratch = false; + KeyT *scratch_keys = reinterpret_cast(scratch); + uint8_t *val_offset = scratch + sizeof(KeyT) * (n + 1); + val_offset += alignof(ValT) - val_offset % alignof(ValT); + ValT *scratch_vals = reinterpret_cast(val_offset); + // If n > work_group_size, each work item holds sorted elements to be merged. + // Otherwise, only n work items hold 1 element. Chunk size <= work group size. + size_t chunks_to_merge = (n > wg_size) ? wg_size : n; + size_t iter_num = 0; + while (chunks_to_merge > 1) { + // workitem 0 will merge chunk 0, 1. + // workitem 1 will merge chunk 2, 3. + // workitem idx will merge chunk 2 * idx and 2 * idx + 1 + KeyT *keys_in = data_in_scratch ? scratch_keys : keys; + KeyT *keys_out = data_in_scratch ? keys : scratch_keys; + ValT *vals_in = data_in_scratch ? scratch_vals : vals; + ValT *vals_out = data_in_scratch ? 
vals : scratch_vals; + merge_key_value(keys_in, keys_out, vals_in, vals_out, + idx, iter_num, chunks_to_merge, comp); + // merge(data_in, data_out, idx, merge_size, chunks_to_merge, + // n, + // comp); + group_barrier(); + chunks_to_merge = (chunks_to_merge - 1) / 2 + 1; + data_in_scratch = !data_in_scratch; + } +} + #endif // __SPIR__ || __SPIRV__ From b775b859d3389e0d0196ce088380df6afe0b66ee Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 29 Jul 2024 17:02:50 +0800 Subject: [PATCH 31/71] Add (uint32_t, uint32_t) Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 6 ++ libdevice/sort_helper.hpp | 111 ++++++++++++++++++++++------------- 2 files changed, 77 insertions(+), 40 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index d36667fd17c00..eeee9fb492695 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1136,4 +1136,10 @@ __devicelib_default_sub_group_private_sort_descending_f16(_Float16 value, [](_Float16 a, _Float16 b) { return (a > b); }); } +//========= default work grop joint sort for (uint32_t, uint32_t) ============== +void __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less{}); +} + #endif // __SPIR__ || __SPIRV__ diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index ec6be3133e85d..137d5238b7b55 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -157,34 +157,6 @@ Tp sub_group_merge_sort(Tp value, uint8_t *scratch, Compare comp) { return temp_buffer[idx]; } -static void __get_chunk_size(size_t group_id, size_t group_size, size_t n, - size_t *beg, size_t *end) { - size_t tmp = n % group_size; - size_t chunk_size = n / group_size; - if (tmp) { - if (group_id < tmp) { - *beg = group_id * (chunk_size + 1); - *end = *beg + chunk_size + 1; - } else { - *beg = tmp * (chunk_size + 1) + (group_id - tmp) * 
chunk_size; - *end = *beg + chunk_size; - } - } else { - *beg = group_id * chunk_size; - *end = *beg + chunk_size; - } -} - -template -void merge_key_value(KeyT *keys_in, KeyT *keys_out, ValT *vals_in, - ValT *vals_out, size_t widx, size_t iter_num, - size_t chunks_to_merge, Compare comp) { - if (2 * widx >= chunks_to_merge) - return; - - // -} - template void bubble_sort_key_value(KeyT *keys, ValT *vals, const size_t beg, const size_t end, Compare comp) { @@ -205,6 +177,57 @@ void bubble_sort_key_value(KeyT *keys, ValT *vals, const size_t beg, } } +template +void merge_key_value(KeyT *keys_in, KeyT *keys_out, ValT *vals_in, + ValT *vals_out, size_t widx, size_t msize, size_t chunks, + size_t n, Compare comp) { + if (2 * widx >= chunks) + return; + + size_t beg1 = 2 * widx * msize; + size_t end1 = beg1 + msize; + size_t beg2, end2; + if (end1 >= n) + end1 = beg2 = end2 = n; + else { + beg2 = end1; + end2 = beg2 + msize; + if (end2 >= n) + end2 = n; + } + + size_t output_idx = 2 * widx * msize; + while ((beg1 != end1) && (beg2 != end2)) { + KeyT key_temp; + ValT val_temp; + if (comp(keys_in[beg1], keys_in[beg2])) { + key_temp = keys_in[beg1]; + val_temp = vals_in[beg1]; + ++beg1; + } else { + key_temp = keys_in[beg2]; + val_temp = vals_in[beg2]; + ++beg2; + } + keys_out[output_idx] = key_temp; + vals_out[output_idx] = val_temp; + ++output_idx; + } + + while (beg1 != end1) { + keys_out[output_idx] = keys_in[beg1]; + vals_out[output_idx] = vals_in[beg1]; + ++beg1; + ++output_idx; + } + while (beg2 != end2) { + keys_out[output_idx] = keys_in[beg2]; + vals_out[output_idx] = vals_in[beg2]; + ++beg2; + ++output_idx; + } +} + // We have following assumption for scratch memory size for key-value // group sort: size of scratch > (sizeof(KeyT) + sizeof(ValT)) + // max(alignof(KeyT), alignof(ValT)). 
@@ -213,19 +236,20 @@ void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, Compare comp) { const size_t idx = __get_wg_local_linear_id(); const size_t wg_size = __get_wg_local_range(); - const size_t bubble_beg, bubble_end; - __get_chunk_size(idx, wg_size, n, &bubble_beg, &bubble_end); - bubble_sort(keys, vals, bubble_beg, bubble_end, comp); + size_t chunk_size = (n - 1) / wg_size + 1; + size_t bubble_beg, bubble_end; + bubble_beg = (idx * chunk_size) >= n ? n : idx * chunk_size; + bubble_end = ((idx + 1) * chunk_size) > n ? n : (idx + 1) * chunk_size; + bubble_sort_key_value(keys, vals, bubble_beg, bubble_end, comp); group_barrier(); bool data_in_scratch = false; KeyT *scratch_keys = reinterpret_cast(scratch); uint8_t *val_offset = scratch + sizeof(KeyT) * (n + 1); - val_offset += alignof(ValT) - val_offset % alignof(ValT); + val_offset += + alignof(ValT) - reinterpret_cast(val_offset) % alignof(ValT); ValT *scratch_vals = reinterpret_cast(val_offset); - // If n > work_group_size, each work item holds sorted elements to be merged. - // Otherwise, only n work items hold 1 element. Chunk size <= work group size. - size_t chunks_to_merge = (n > wg_size) ? wg_size : n; - size_t iter_num = 0; + size_t chunks_to_merge = (n - 1) / chunk_size + 1; + size_t merge_size = chunk_size; while (chunks_to_merge > 1) { // workitem 0 will merge chunk 0, 1. // workitem 1 will merge chunk 2, 3. @@ -235,14 +259,21 @@ void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, ValT *vals_in = data_in_scratch ? scratch_vals : vals; ValT *vals_out = data_in_scratch ? 
vals : scratch_vals; merge_key_value(keys_in, keys_out, vals_in, vals_out, - idx, iter_num, chunks_to_merge, comp); - // merge(data_in, data_out, idx, merge_size, chunks_to_merge, - // n, - // comp); + idx, merge_size, chunks_to_merge, n, + comp); group_barrier(); chunks_to_merge = (chunks_to_merge - 1) / 2 + 1; + merge_size <<= 1; data_in_scratch = !data_in_scratch; } + + if (data_in_scratch) { + for (size_t i = idx * chunk_size; i < bubble_end; ++i) { + keys[i] = scratch_keys[i]; + vals[i] = scratch_vals[i]; + } + group_barrier(); + } } #endif // __SPIR__ || __SPIRV__ From c0dbd92868eb210d91f22c52aca630a8656ff242 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 30 Jul 2024 16:52:01 +0800 Subject: [PATCH 32/71] Add joint sort KV Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 33 ++++++++++++++++++++++++++++++++- libdevice/sort_helper.hpp | 2 +- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index eeee9fb492695..94252979f8c0c 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1137,9 +1137,40 @@ __devicelib_default_sub_group_private_sort_descending_f16(_Float16 value, } //========= default work grop joint sort for (uint32_t, uint32_t) ============== +DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::less{}); + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { + 
merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_joint_sort_descending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } #endif // __SPIR__ || __SPIRV__ diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 137d5238b7b55..c62f2eaa685f0 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -229,7 +229,7 @@ void merge_key_value(KeyT *keys_in, KeyT *keys_out, ValT *vals_in, } // We have following assumption for scratch memory size for key-value -// group sort: size of scratch > (sizeof(KeyT) + sizeof(ValT)) + +// group sort: size of scratch > n * (sizeof(KeyT) + sizeof(ValT)) + // max(alignof(KeyT), alignof(ValT)). 
template void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, From f61d9ce1c2cc9fe877f7993fc3b9672b48d48c57 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 30 Jul 2024 17:30:40 +0800 Subject: [PATCH 33/71] add test for KV joint sort Signed-off-by: jinge90 --- .../group_joint_KV_sort_p1p1_p1.hpp | 19 ++++ .../workgroup_joint_KV_sort_p1p1_p1.cpp | 100 ++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp new file mode 100644 index 0000000000000..ff3b204e20b57 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp @@ -0,0 +1,19 @@ +#pragma once +#include + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) {} +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) {} +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp new file mode 100644 index 0000000000000..23ef81b303e16 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -0,0 +1,100 @@ +// RUN: %{build} -o 
%t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +#include "group_joint_KV_sort_p1p1_p1.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +template +void test_work_group_KV_joint_sort(sycl::queue &q, KeyT input_keys[NUM], + ValT input_vals[NUM], SortHelper gsh) { + size_t scratch_size = NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + uint8_t *scratch_ptr = + (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); + const static size_t wg_size = WG_SZ; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << std::get<0>(sorted_vec[idx]) << " val: " << + std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + gsh(ikeys_acc.template get_multi_ptr().get(), + ivals_acc.template get_multi_ptr().get(), + NUM, scratch_ptr); + }); + }).wait(); + } + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { + if ((input_keys[idx] != std::get<0>(sorted_vec[idx])) || + (input_vals[idx] != std::get<1>(sorted_vec[idx]))) { + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int 
NUM = 21; + uint32_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, + 23, 36, 2, 111, 91, 88, 2, 51, 95431, 881}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525}; + auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif + }; + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." << std::endl; + } + + return 0; +} From 060842e0bd63ab63a62b242e10a195d6b91f0feb Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 5 Aug 2024 16:54:23 +0800 Subject: [PATCH 34/71] add KV sort private sort Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 14 ++++++++++++ libdevice/sort_helper.hpp | 43 ++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 94252979f8c0c..881ceae4e4370 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1173,4 +1173,18 @@ void __devicelib_default_work_group_joint_sort_descending_p1u32_p1u8_u32_p1i8( merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less{}); +} + #endif // __SPIR__ 
|| __SPIRV__ diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index c62f2eaa685f0..b3c9d4093bf4f 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -276,4 +276,47 @@ void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, } } +// Each work-item holds fixed-size input keys/values located in private memory +// and apply group sorting to all work-items' input. The sorted data will be +// copied back to each work-item's private memory. +// Assumption about scratch memory size: +// scratch_size >= 2 *(n * wg_size * (sizeof(KeyT) + sizeof(ValT))) + \ +// max(alignof(KeyT), alignof(ValT)) +// The scrach memory alignment is max(alignof(KeyT), alignof(ValT)) +template +void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n, + uint8_t *scratch, Compare comp) { + const size_t local_idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + KeyT *temp_key_beg = reinterpret_cast(scratch); + uint64_t temp_val_unaligned = + reinterpret_cast(scratch + 2 * wg_size * n * sizeof(KeyT)); + ValT *temp_val_beg = nullptr; + uint64_t temp1 = temp_val_unaligned % alignof(ValT); + if (temp1) + temp_val_beg = + reinterpret_cast(temp_val_unaligned + alignof(ValT) - temp1); + else + temp_val_beg = reinterpret_cast(temp_val_unaligned); + + uint8_t *internal_scratch = + reinterpret_cast(&temp_key_beg[n * wg_size]); + temp_val_beg = &temp_val_beg[n * wg_size]; + + for (size_t i = 0; i < n; ++i) { + temp_key_beg[local_idx * n + i] = keys[i]; + temp_val_beg[local_idx * n + i] = vals[i]; + } + + group_barrier(); + + merge_sort_key_value(temp_key_beg, temp_val_beg, n * wg_size, + internal_scratch, comp); + + for (size_t i = 0; i < n; ++i) { + keys[i] = temp_key_beg[local_idx * n + i]; + vals[i] = temp_val_beg[local_idx * n + i]; + } +} + #endif // __SPIR__ || __SPIRV__ From e5543056ae19a1c7e99005cc352a7c30753371aa Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 6 Aug 2024 11:58:54 +0800 
Subject: [PATCH 35/71] disable KV sort tests on CPU backend Signed-off-by: jinge90 --- .../DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp index 23ef81b303e16..16f6db4873e39 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -4,6 +4,8 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out +// UNSUPPORTED: cuda || hip || cpu + #include "group_joint_KV_sort_p1p1_p1.hpp" #include #include From 627e0e87e806ba8b4cd1a3f1c3e9f7123ad1a739 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 6 Aug 2024 17:37:16 +0800 Subject: [PATCH 36/71] add private KV sort Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 4 +- libdevice/sort_helper.hpp | 12 +- .../group_private_KV_sort_p1p1_p1.hpp | 20 +++ .../workgroup_joint_KV_sort_p1p1_p1.cpp | 26 ++++ .../workgroup_private_KV_sort_p1p1_p1.cpp | 126 ++++++++++++++++++ 5 files changed, 182 insertions(+), 6 deletions(-) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 881ceae4e4370..a1a2843c9b314 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1177,14 +1177,14 @@ DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { private_merge_sort_key_value_close(keys, vals, n, scratch, - std::less{}); + std::less_equal{}); } DEVICE_EXTERN_C_INLINE void __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( uint8_t *keys, 
uint32_t *vals, uint32_t n, uint8_t *scratch) { private_merge_sort_key_value_close(keys, vals, n, scratch, - std::less{}); + std::less_equal{}); } #endif // __SPIR__ || __SPIRV__ diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index b3c9d4093bf4f..071db026fc02f 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -244,10 +244,14 @@ void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, group_barrier(); bool data_in_scratch = false; KeyT *scratch_keys = reinterpret_cast(scratch); - uint8_t *val_offset = scratch + sizeof(KeyT) * (n + 1); - val_offset += - alignof(ValT) - reinterpret_cast(val_offset) % alignof(ValT); - ValT *scratch_vals = reinterpret_cast(val_offset); + ValT *scratch_vals = nullptr; + uint64_t val_offset = reinterpret_cast(scratch + sizeof(KeyT) * n); + uint64_t temp1 = val_offset % alignof(ValT); + if (temp1) + scratch_vals = reinterpret_cast(val_offset + alignof(ValT) - temp1); + else + scratch_vals = reinterpret_cast(val_offset); + size_t chunks_to_merge = (n - 1) / chunk_size + 1; size_t merge_size = chunk_size; while (chunks_to_merge > 1) { diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp new file mode 100644 index 0000000000000..c4f5893950a87 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -0,0 +1,20 @@ +#pragma once +#include + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, 
uint32_t *vals, uint32_t n, uint8_t *scratch) {} + +extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) {} +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp index 16f6db4873e39..dbf55e3295bb6 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -98,5 +98,31 @@ int main() { std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." << std::endl; } + { + constexpr static int NUM = 32; + uint32_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 36, 2, 111, 91, + 88, 2, 51, 95431, 881, 99183, 31, 142, + 416, 701, 699, 1024, 8912, 0, 7981, 17}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, + 91111, 777, 165, 145, 2456, 88811, 761, 96, + 765, 10000, 6364, 90, 525, 882, 1, 2423, + 9, 4324, 9123, 0, 1232, 777, 555, 314159}; + auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif + }; + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." 
<< std::endl; + } + return 0; } diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp new file mode 100644 index 0000000000000..0267fc250071d --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp @@ -0,0 +1,126 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +template +void test_work_group_KV_private_sort(sycl::queue &q, KeyT input_keys[NUM], + ValT input_vals[NUM], SortHelper gsh) { + static_assert((NUM % WG_SZ == 0), + "Input number must be divisible by work group size!"); + size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + uint8_t *scratch_ptr = + (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); + const static size_t wg_size = WG_SZ; + constexpr size_t num_per_work_item = NUM / WG_SZ; + KeyT output_keys[NUM]; + ValT output_vals[NUM]; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + buffer okeys_buf(output_keys, NUM); + buffer ovals_buf(output_vals, NUM); + q.submit([&](auto &h) { + accessor 
ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + accessor okeys_acc{okeys_buf, h}; + accessor ovals_acc{ovals_buf, h}; + sycl::stream os(1024, 128, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + KeyT pkeys[num_per_work_item]; + ValT pvals[num_per_work_item]; + // copy from global input to fix-size private array. + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + pkeys[idx] = + ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; + pvals[idx] = + ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; + } + + gsh(pkeys, pvals, num_per_work_item, scratch_ptr); + + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pkeys[idx]; + ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pvals[idx]; + } + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { + if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || + (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 32; + uint32_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 36, 2, 111, 91, + 88, 2, 51, 95431, 881, 99183, 31, 142, + 416, 701, 699, 1024, 8912, 0, 7981, 17}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, + 91111, 777, 165, 145, 2456, 88811, 761, 96, + 765, 10000, 6364, 90, 525, 882, 1, 2423, + 9, 4324, 9123, 0, 1232, 777, 555, 314159}; + auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif + }; + test_work_group_KV_private_sort( + q, ikeys, ivals, work_group_sorter); + 
std::cout << "KV private sort p1u32_p1u32_u32_p1i8 pass." << std::endl; + } +} From 07a5df47f425e0ac910dccf6435e2844f5bfc743 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 7 Aug 2024 17:02:32 +0800 Subject: [PATCH 37/71] simplify function name Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 603 ++++++------------ .../workgroup_private_KV_sort_p1p1_p1.cpp | 27 + 2 files changed, 236 insertions(+), 394 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index a1a2843c9b314..ddf68665c8d06 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -12,438 +12,379 @@ #include #if defined(__SPIR__) || defined(__SPIRV__) +#define WG_JS_A(EP) __devicelib_default_work_group_joint_sort_ascending_##EP +#define WG_JS_D(EP) __devicelib_default_work_group_joint_sort_descending_##EP +#define WG_PS_CA(EP) \ + __devicelib_default_work_group_private_sort_close_ascending_##EP +#define WG_PS_CD(EP) \ + __devicelib_default_work_group_private_sort_close_descending_##EP +#define WG_PS_SA(EP) \ + __devicelib_default_work_group_private_sort_spread_ascending_##EP +#define WG_PS_SD(EP) \ + __devicelib_default_work_group_private_sort_spread_descending_##EP +#define SG_PS_A(EP) __devicelib_default_sub_group_private_sort_ascending_##EP +#define SG_PS_D(EP) __devicelib_default_sub_group_private_sort_descending_##EP + //============ default work grop joint sort for signed integer =============== DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p1i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } 
DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p1i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p1i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3i8_u32_p1i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( - int16_t 
*first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1i32_u32_p1i8)(int32_t *first, 
uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } 
DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void 
__devicelib_default_work_group_joint_sort_descending_p3i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } //=========== default work grop joint sort for unsigned integer ============== DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p1i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p3i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p1i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p3i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u8_u32_p1i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u8_u32_p3i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void 
__devicelib_default_work_group_joint_sort_descending_p3u8_u32_p1i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3u8_u32_p3i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u16_u32_p3i8( - uint16_t 
*first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3u16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3u16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p1i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p3i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p1i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p3i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u32_u32_p1i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void 
WG_JS_D(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u32_u32_p3i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3u32_u32_p1i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3u32_u32_p3i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p1i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p3i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p1i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p3i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { 
merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u64_u32_p1i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u64_u32_p3i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3u64_u32_p1i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3u64_u32_p3i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } //=============== default work grop joint sort for fp32 ====================== DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p3i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p1i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { 
merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p3i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p3i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p1i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, std::greater{}); } @@ -451,50 +392,42 @@ void __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( // doesn't support native fp16 //=============== default work grop joint sort for fp16 ====================== DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1f16_u32_p1i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -void 
__devicelib_default_work_group_joint_sort_ascending_p1f16_u32_p3i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3f16_u32_p1i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p3f16_u32_p3i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p3f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1f16_u32_p1i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1f16_u32_p3i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3f16_u32_p1i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p3f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p3f16_u32_p3i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void 
WG_JS_D(p3f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } @@ -502,687 +435,569 @@ void __devicelib_default_work_group_joint_sort_descending_p3f16_u32_p3i8( // Since 'first' should point to 'private' memory address space, it can only be // decorated with 'p1'. DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p1i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p1i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p1i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void 
WG_PS_SA(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p1i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p3i8( - int8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void 
__devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p1i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p3i8( - int16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { 
private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p1i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p3i8( - int32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) 
{ +void WG_PS_CA(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1i64_u32_p1i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1i64_u32_p3i8( - int64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } //=========== default work grop private sort for unsigned integer 
============= DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u8_u32_p1i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u8_u32_p3i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1u8_u32_p1i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1u8_u32_p3i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1u8_u32_p1i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1u8_u32_p3i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1u8_u32_p1i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t 
*scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1u8_u32_p3i8( - uint8_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1u16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1u16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1u16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1u16_u32_p3i8( - uint16_t *first, uint32_t 
n, uint8_t *scratch) { +void WG_PS_SA(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1u16_u32_p1i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1u16_u32_p3i8( - uint16_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u32_u32_p1i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u32_u32_p3i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1u32_u32_p1i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1u32_u32_p3i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void 
__devicelib_default_work_group_private_sort_spread_ascending_p1u32_u32_p1i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1u32_u32_p3i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1u32_u32_p1i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1u32_u32_p3i8( - uint32_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u64_u32_p1i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u64_u32_p3i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1u64_u32_p1i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { 
private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1u64_u32_p3i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1u64_u32_p1i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1u64_u32_p3i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1u64_u32_p1i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1u64_u32_p3i8( - uint64_t *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } //================= default work grop private sort for fp32 ==================== DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1f32_u32_p1i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void 
__devicelib_default_work_group_private_sort_close_ascending_p1f32_u32_p3i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1f32_u32_p1i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1f32_u32_p3i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1f32_u32_p1i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1f32_u32_p3i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1f32_u32_p1i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1f32_u32_p3i8( - float *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, 
scratch, std::greater{}); } //================= default work grop private sort for fp16 ==================== DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1f16_u32_p1i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1f16_u32_p3i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1f16_u32_p1i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_descending_p1f16_u32_p3i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_PS_CD(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { private_merge_sort_close(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1f16_u32_p1i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_ascending_p1f16_u32_p3i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SA(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t 
*scratch) { private_merge_sort_spread(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1f16_u32_p1i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_spread_descending_p1f16_u32_p3i8( - _Float16 *first, uint32_t n, uint8_t *scratch) { +void WG_PS_SD(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { private_merge_sort_spread(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } //============= default sub group private sort for signed integer ============= DEVICE_EXTERN_C_INLINE -int8_t -__devicelib_default_sub_group_private_sort_ascending_i8(int8_t value, - uint8_t *scratch) { +int8_t SG_PS_A(i8)(int8_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -int16_t -__devicelib_default_sub_group_private_sort_ascending_i16(int16_t value, - uint8_t *scratch) { +int16_t SG_PS_A(i16)(int16_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -int32_t -__devicelib_default_sub_group_private_sort_ascending_i32(int32_t value, - uint8_t *scratch) { +int32_t SG_PS_A(i32)(int32_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -int64_t -__devicelib_default_sub_group_private_sort_ascending_i64(int64_t value, - uint8_t *scratch) { +int64_t SG_PS_A(i64)(int64_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint8_t -__devicelib_default_sub_group_private_sort_ascending_u8(uint8_t value, - uint8_t *scratch) { +uint8_t SG_PS_A(u8)(uint8_t value, uint8_t 
*scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint16_t -__devicelib_default_sub_group_private_sort_ascending_u16(uint16_t value, - uint8_t *scratch) { +uint16_t SG_PS_A(u16)(uint16_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint32_t -__devicelib_default_sub_group_private_sort_ascending_u32(uint32_t value, - uint8_t *scratch) { +uint32_t SG_PS_A(u32)(uint32_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint64_t -__devicelib_default_sub_group_private_sort_ascending_u64(uint64_t value, - uint8_t *scratch) { +uint64_t SG_PS_A(u64)(uint64_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -float __devicelib_default_sub_group_private_sort_ascending_f32( - float value, uint8_t *scratch) { +float SG_PS_A(f32)(float value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -_Float16 -__devicelib_default_sub_group_private_sort_ascending_f16(_Float16 value, - uint8_t *scratch) { +_Float16 SG_PS_A(f16)(_Float16 value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -int8_t -__devicelib_default_sub_group_private_sort_descending_i8(int8_t value, - uint8_t *scratch) { +int8_t SG_PS_D(i8)(int8_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -int16_t -__devicelib_default_sub_group_private_sort_descending_i16(int16_t value, - uint8_t *scratch) { +int16_t SG_PS_D(i16)(int16_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -int32_t -__devicelib_default_sub_group_private_sort_descending_i32(int32_t value, - uint8_t *scratch) { +int32_t SG_PS_D(i32)(int32_t value, uint8_t 
*scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -int64_t -__devicelib_default_sub_group_private_sort_descending_i64(int64_t value, - uint8_t *scratch) { +int64_t SG_PS_D(i64)(int64_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint8_t -__devicelib_default_sub_group_private_sort_descending_u8(uint8_t value, - uint8_t *scratch) { +uint8_t SG_PS_D(u8)(uint8_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint16_t -__devicelib_default_sub_group_private_sort_descending_u16(uint16_t value, - uint8_t *scratch) { +uint16_t SG_PS_D(u16)(uint16_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint32_t -__devicelib_default_sub_group_private_sort_descending_u32(uint32_t value, - uint8_t *scratch) { +uint32_t SG_PS_D(u32)(uint32_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint64_t -__devicelib_default_sub_group_private_sort_descending_u64(uint64_t value, - uint8_t *scratch) { +uint64_t SG_PS_D(u64)(uint64_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -float __devicelib_default_sub_group_private_sort_descending_f32( - float value, uint8_t *scratch) { +float SG_PS_D(f32)(float value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -_Float16 -__devicelib_default_sub_group_private_sort_descending_f16(_Float16 value, - uint8_t *scratch) { +_Float16 SG_PS_D(f16)(_Float16 value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } //========= default work grop joint sort for (uint32_t, uint32_t) ============== DEVICE_EXTERN_C_INLINE -void 
__devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( - uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( - uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u32_u32_p1i8( - uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u8_p1u32_u32_p1i8( - uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u8_u32_p1i8( - uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { +void WG_JS_A(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_joint_sort_descending_p1u32_p1u8_u32_p1i8( - uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { +void WG_JS_D(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } 
DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( - uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { private_merge_sort_key_value_close(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( - uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { +void WG_PS_CA(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { private_merge_sort_key_value_close(keys, vals, n, scratch, std::less_equal{}); } diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp index 0267fc250071d..31e91c58a0d2d 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp @@ -123,4 +123,31 @@ int main() { q, ikeys, ivals, work_group_sorter); std::cout << "KV private sort p1u32_p1u32_u32_p1i8 pass." 
<< std::endl; } + + { + constexpr static int NUM = 35; + uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 12, 0, 181, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + auto work_group_sorter = [](uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif + }; + test_work_group_KV_private_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV private sort p1u8_p1u32_u32_p1i8 pass." << std::endl; + } } From f227a4e5f04dc75fc76c37d215ca545ce63d50a9 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 20 Aug 2024 17:02:45 +0800 Subject: [PATCH 38/71] fix clang format Signed-off-by: jinge90 --- .../group_sort/workgroup_private_KV_sort_p1p1_p1.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp index 31e91c58a0d2d..3a03e322618a2 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp @@ -126,10 +126,10 @@ int main() { { constexpr static int NUM = 35; - uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, - 77, 125, 23, 36, 2, 111, 91, 88, 2, - 51, 213, 181, 183, 31, 142, 216, 1, 199, - 124, 12, 0, 181, 17, 15, 101, 44}; + uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 
12, 0, 181, 17, 15, 101, 44}; uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, 91111, 777, 165, 145, 2456, 88811, 761, 96, 765, 10000, 6364, 90, 525, From 724113bfe1abe25dc164acb6aa9d75bb4b2a36f4 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 10 Sep 2024 14:29:26 +0800 Subject: [PATCH 39/71] add more misc APIs Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index ddf68665c8d06..da4e967701fd3 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -988,6 +988,54 @@ void WG_JS_D(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_PS_CA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { From 2670023d2f9134400bab4a3efcac87a9061f1b23 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 10 Sep 2024 21:59:29 +0800 Subject: [PATCH 40/71] fix cmake issue Signed-off-by: jinge90 --- libdevice/cmake/modules/SYCLLibdevice.cmake | 4 + libdevice/fallback-gsort.cpp | 209 +++++++++++++++++--- 2 files changed, 189 insertions(+), 24 deletions(-) diff --git a/libdevice/cmake/modules/SYCLLibdevice.cmake b/libdevice/cmake/modules/SYCLLibdevice.cmake index 8480e4273b16a..3710f3ea466ef 100644 --- a/libdevice/cmake/modules/SYCLLibdevice.cmake +++ b/libdevice/cmake/modules/SYCLLibdevice.cmake @@ -294,6 +294,10 @@ add_devicelibs(libsycl-fallback-bfloat16 add_devicelibs(libsycl-native-bfloat16 SRC bfloat16_wrapper.cpp DEPENDENCIES ${bfloat16_obj_deps}) +add_devicelibs(libsycl-fallback-gsort + SRC fallback-gsort.cpp + DEPENDENCIES ${gsort_obj_deps} + EXTRA_OPTS -fno-sycl-instrument-device-code) # Create dependency and source lists for Intel math function libraries. 
file(MAKE_DIRECTORY ${obj_binary_dir}/libdevice) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index da4e967701fd3..4d37f7dbb349f 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -952,90 +952,251 @@ _Float16 SG_PS_D(f16)(_Float16 value, uint8_t *scratch) { } //========= default work grop joint sort for (uint32_t, uint32_t) ============== + +// uint8_t as key type DEVICE_EXTERN_C_INLINE -void WG_JS_A(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, - uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +void WG_JS_A(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_D(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, - uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +void WG_JS_D(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_A(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, +void WG_JS_A(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_D(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, +void WG_JS_D(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_A(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, +void WG_JS_A(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); + merge_sort_key_value(keys, vals, n, scratch, 
std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_D(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, +void WG_JS_D(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_A(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, +void WG_JS_A(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_D(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, +void WG_JS_D(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_A(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, +void WG_JS_A(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_D(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, +void WG_JS_D(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +// uint16_t as key type DEVICE_EXTERN_C_INLINE -void WG_JS_A(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, +void WG_JS_A(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_D(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, +void WG_JS_D(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { - 
merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_A(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, - uint8_t *scratch) { +void WG_JS_A(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE -void WG_JS_D(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, - uint8_t *scratch) { +void WG_JS_D(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +// uint32_t as key type +DEVICE_EXTERN_C_INLINE +void 
WG_JS_A(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +// uint64_t as key type 
+DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + 
DEVICE_EXTERN_C_INLINE void WG_PS_CA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { From 3139b2864a064122a0b112a083e6e148d2ca57c8 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 12 Sep 2024 11:05:23 +0800 Subject: [PATCH 41/71] add int8_t value for kv gsort Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 4d37f7dbb349f..296cc693919c7 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -966,6 +966,18 @@ void WG_JS_D(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch) { From ffdee6e68fb092015189dca1f3f5a455eb7791d1 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 12 Sep 2024 16:36:43 +0800 Subject: [PATCH 42/71] fix KV buble sort stable issue Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 50 +++++++++++- libdevice/sort_helper.hpp | 24 +++--- .../group_joint_KV_sort_p1p1_p1.hpp | 29 ++++--- .../workgroup_joint_KV_sort_p1p1_p1.cpp | 80 ++++++++++++++++++- .../workgroup_private_KV_sort_p1p1_p1.cpp | 1 + 5 files changed, 157 insertions(+), 27 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 296cc693919c7..d11410bdd3f9e 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -966,16 +966,20 @@ void 
WG_JS_D(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +// For int8_t values, the size and alignment are same as uint8_t, we use same +// implementation as uint8_t values to reduce code size. DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); } DEVICE_EXTERN_C_INLINE void WG_JS_D(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); } DEVICE_EXTERN_C_INLINE @@ -990,6 +994,20 @@ void WG_JS_D(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { @@ -1002,6 +1020,20 @@ void WG_JS_D(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_JS_D(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch) { @@ -1014,6 +1046,20 @@ void WG_JS_D(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch) { diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 071db026fc02f..13ca9b1e1c64f 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -158,22 +158,26 @@ Tp sub_group_merge_sort(Tp value, uint8_t *scratch, Compare comp) { } template -void bubble_sort_key_value(KeyT *keys, ValT *vals, const size_t beg, - const size_t end, Compare comp) { +void bubble_sort_key_value_stable(KeyT *keys, ValT *vals, const size_t beg, + const size_t end, Compare comp) { if (beg < end) { KeyT temp_key; ValT temp_val; - for (size_t i = beg; i < end; ++i) - for (size_t j = i + 1; j < end; ++j) { - if (!comp(keys[i], keys[j])) { + size_t swaps; + do { + swaps = 0; + for (size_t i = beg; i < (end - 1); ++i) { + if (!comp(keys[i], keys[i + 1])) { temp_key = keys[i]; - keys[i] = keys[j]; - keys[j] = temp_key; + keys[i] = keys[i + 1]; + keys[i + 1] = temp_key; temp_val = vals[i]; - vals[i] = vals[j]; - vals[j] = temp_val; + vals[i] = vals[i + 1]; 
+ vals[i + 1] = temp_val; + swaps += 1; } } + } while (swaps != 0); } } @@ -240,7 +244,7 @@ void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, size_t bubble_beg, bubble_end; bubble_beg = (idx * chunk_size) >= n ? n : idx * chunk_size; bubble_end = ((idx + 1) * chunk_size) > n ? n : (idx + 1) * chunk_size; - bubble_sort_key_value(keys, vals, bubble_beg, bubble_end, comp); + bubble_sort_key_value_stable(keys, vals, bubble_beg, bubble_end, comp); group_barrier(); bool data_in_scratch = false; KeyT *scratch_keys = reinterpret_cast(scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp index ff3b204e20b57..15d18d155f455 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp @@ -1,19 +1,26 @@ #pragma once #include -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL extern "C" void +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); -SYCL_EXTERNAL extern "C" void +__DPCPP_SYCL_EXTERNAL extern "C" void 
__devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); -#else -extern "C" void -__devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( - uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) {} -extern "C" void -__devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( - uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) {} -#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp index dbf55e3295bb6..c039303244c0a 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -20,8 +20,13 @@ using namespace sycl; template -void test_work_group_KV_joint_sort(sycl::queue &q, KeyT input_keys[NUM], - ValT input_vals[NUM], SortHelper gsh) { +void test_work_group_KV_joint_sort(sycl::queue &q, KeyT keys[NUM], + ValT vals[NUM], SortHelper gsh) { + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy((void *)input_keys, (void *)keys, NUM * sizeof(KeyT)); + memcpy((void *)input_vals, (void *)vals, NUM * sizeof(ValT)); size_t scratch_size = NUM * (sizeof(KeyT) + sizeof(ValT)) + std::max(alignof(KeyT), alignof(ValT)); uint8_t *scratch_ptr = @@ -47,6 +52,7 @@ void test_work_group_KV_joint_sort(sycl::queue &q, KeyT input_keys[NUM], std::cout << "key: " << std::get<0>(sorted_vec[idx]) << " val: " << std::get<1>(sorted_vec[idx]) << std::endl; }*/ + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); { buffer ikeys_buf(input_keys, NUM); @@ -61,11 +67,19 @@ void test_work_group_KV_joint_sort(sycl::queue &q, KeyT input_keys[NUM], }); }).wait(); } + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (input_keys[idx]) << " val: " << + (input_vals[idx]) << std::endl; + }*/ + + sycl::free(scratch_ptr, q); bool 
fails = false; for (size_t idx = 0; idx < NUM; ++idx) { if ((input_keys[idx] != std::get<0>(sorted_vec[idx])) || (input_vals[idx] != std::get<1>(sorted_vec[idx]))) { fails = true; + std::cout << idx << std::endl; break; } } @@ -75,6 +89,32 @@ void test_work_group_KV_joint_sort(sycl::queue &q, KeyT input_keys[NUM], int main() { queue q; + { + constexpr static int NUM = 23; + uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 88, 2, 51, 91, 81, 122, 22}; + uint8_t ivals[NUM] = {99, 32, 1, 2, 67, 123, 253, 35, + 111, 77, 165, 145, 254, 11, 161, 96, + 165, 100, 64, 90, 255, 147, 135}; + auto work_group_sorter = [](uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." + << std::endl; + } + { constexpr static int NUM = 21; uint32_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, @@ -84,18 +124,21 @@ int main() { 761, 96, 765, 10000, 6364, 90, 525}; auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); #else __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif #endif }; test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." << std::endl; + std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." 
+ << std::endl; } { @@ -110,18 +153,47 @@ int main() { 9, 4324, 9123, 0, 1232, 777, 555, 314159}; auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { + +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); #else __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif #endif }; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + << std::endl; + test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." << std::endl; + std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." 
+ << std::endl; } return 0; diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp index 3a03e322618a2..ec92477842da0 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp @@ -83,6 +83,7 @@ void test_work_group_KV_private_sort(sycl::queue &q, KeyT input_keys[NUM], }).wait(); } + sycl::free(scratch_ptr, q); bool fails = false; for (size_t idx = 0; idx < NUM; ++idx) { if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || From c52baf4e17cf86f65fe16ade64fde09a6423058c Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 12 Sep 2024 17:19:13 +0800 Subject: [PATCH 43/71] add e2e tests for p1u8_i8 Signed-off-by: jinge90 --- .../workgroup_joint_KV_sort_p1p1_p1.cpp | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp index c039303244c0a..a9bfeb305c637 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -96,6 +96,9 @@ int main() { uint8_t ivals[NUM] = {99, 32, 1, 2, 67, 123, 253, 35, 111, 77, 165, 145, 254, 11, 161, 96, 165, 100, 64, 90, 255, 147, 135}; + int8_t ivals2[NUM] = {-1, 23, 0, 123, 99, 44, 8, 11, -67, -54, -113, 7, + 1, 81, -81, 21, 25, -38, 66, 99, -121, 34, 45}; + auto work_group_sorter = [](uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) @@ -108,11 +111,91 @@ int main() { #endif #endif }; + + auto work_group_sorter1 = [](uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + 
__devicelib_default_work_group_joint_sort_descending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." + << std::endl; + test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." 
+ << std::endl; } { From 328b25f23a6ebf99c8090934c9be6de28aa6dc98 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 13 Sep 2024 14:46:34 +0800 Subject: [PATCH 44/71] add e2e tests for key_u8 Signed-off-by: jinge90 --- .../group_joint_KV_sort_p1p1_p1.hpp | 48 ++ .../workgroup_joint_KV_sort_p1p1_p1.cpp | 431 ++++++++++++++++-- 2 files changed, 440 insertions(+), 39 deletions(-) diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp index 15d18d155f455..582ab62d04116 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp @@ -17,6 +17,54 @@ __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_joint_sort_descending_p1u8_p1i8_u32_p1i8( uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + 
+__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp index a9bfeb305c637..c91b78491c9a6 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -48,10 +48,10 @@ void test_work_group_KV_joint_sort(sycl::queue &q, KeyT keys[NUM], #endif std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); - /*for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << std::get<0>(sorted_vec[idx]) << " val: " << - std::get<1>(sorted_vec[idx]) << std::endl; - }*/ + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: 
" << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + } */ nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); { @@ -68,9 +68,9 @@ void test_work_group_KV_joint_sort(sycl::queue &q, KeyT keys[NUM], }).wait(); } - /*for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (input_keys[idx]) << " val: " << - (input_vals[idx]) << std::endl; + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; }*/ sycl::free(scratch_ptr, q); @@ -96,8 +96,44 @@ int main() { uint8_t ivals[NUM] = {99, 32, 1, 2, 67, 123, 253, 35, 111, 77, 165, 145, 254, 11, 161, 96, 165, 100, 64, 90, 255, 147, 135}; - int8_t ivals2[NUM] = {-1, 23, 0, 123, 99, 44, 8, 11, -67, -54, -113, 7, - 1, 81, -81, 21, 25, -38, 66, 99, -121, 34, 45}; + int8_t ivals2[NUM] = {-1, 23, 0, 123, 99, 44, 8, 11, -67, -54, -113, 7, + 1, 81, -81, 21, 25, -38, 66, 99, -121, 34, 45}; + + uint16_t ivals3[NUM] = {36882, 47565, 20664, 59517, 55773, 5858, + 30720, 64786, 42129, 13618, 62202, 16225, + 54751, 38268, 25563, 44332, 45475, 12550, + 5478, 3301, 3779, 25518, 6659}; + + int16_t ivals4[NUM] = {3882, 7565, -20664, 9517, -5773, 5858, + -30720, 86, 429, 13618, 2202, -16225, + 751, -368, 25563, -4332, -5475, 12550, + 5478, 3301, 3779, 25518, 6659}; + + uint32_t ivals5[NUM] = { + 2, 771, 76, 450, 76421894, 273377, + 85040, 831870667, 402825730, 2774821, 10786, 47164, + 1951118976, 75033606, 35755, 93312, 21, 3257266819, + 1065029990, 139884, 11355, 1464548796, 403290}; + + int32_t ivals6[NUM] = { + 2, 771, -76, 450, 76421894, 273377, + 85040, 831870667, -402825730, -2774821, 10786, 47164, + 1951118976, -75033606, 35755, 93312, 21, 57266819, + -1065029990, 139884, -11355, 1464548796, -403290}; + + uint64_t ivals7[NUM] = { + 2, 771, 76, 1112450, 76421894, + 898273377, 66585040, 11831870667, 402825730, 2774821, + 10786, 47164, 1951118976, 75033606, 99935755, + 
9331211, 21, 3257266819, 10650299901112, 1224139884, + 9837411355, 1464548796, 403290}; + + int64_t ivals8[NUM] = { + 2, 771, 76, 1112450, 76421894, + 898273377, 66585040, -11831870667, -402825730, 2774821, + 10786, 47164, 1951118976, 75033606, 10, + 0, 21, -3257266819, -10650299901112, 76, + -9837411355, 1464548796, 403290}; auto work_group_sorter = [](uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { @@ -113,7 +149,7 @@ int main() { }; auto work_group_sorter1 = [](uint8_t *keys, int8_t *vals, uint32_t n, - uint8_t *scratch) { + uint8_t *scratch) { #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1i8_u32_p1i8( @@ -125,77 +161,383 @@ int main() { #endif }; + auto work_group_sorter2 = [](uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u32_u32_p1i8( + 
keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter7 = [](uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." + << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." + << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." 
+ << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." + << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." + std::cout << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u8_p1u8_u32_p1i8 pass." + std::cout << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals2, work_group_sorter1); - std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals2, work_group_sorter1); - std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals2, work_group_sorter1); - std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals2, work_group_sorter1); - std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals2, work_group_sorter1); - std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals2, work_group_sorter1); - std::cout << "KV joint sort p1u8_p1i8_u32_p1i8 pass." - << std::endl; + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." 
+ << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." 
<< std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." 
<< std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." 
<< std::endl; } { @@ -220,7 +562,8 @@ int main() { test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." << std::endl; } @@ -251,31 +594,41 @@ int main() { test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." << std::endl; test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." << std::endl; - test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); - std::cout << "KV joint sort p1u32_p1u32_u32_p1i8 pass." + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." 
<< std::endl; } From 7b8dc3038fcab6ab61f72a8c59c7d0c2dd0e7253 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 13 Sep 2024 15:36:29 +0800 Subject: [PATCH 45/71] use p1u8_p1u32 for p1u8_p1f32 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 6 +- .../group_joint_KV_sort_p1p1_p1.hpp | 8 +++ .../workgroup_joint_KV_sort_p1p1_p1.cpp | 57 +++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index d11410bdd3f9e..14cabfe25d14d 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1063,13 +1063,15 @@ void WG_JS_D(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); } DEVICE_EXTERN_C_INLINE void WG_JS_D(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); } // uint16_t as key type diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp index 582ab62d04116..40f99b861dfef 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp @@ -65,6 +65,14 @@ __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_joint_sort_descending_p1u8_p1i64_u32_p1i8( uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1f32_u32_p1i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + 
+__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1f32_u32_p1i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp index c91b78491c9a6..2e6b36f2cc554 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -135,6 +135,14 @@ int main() { 0, 21, -3257266819, -10650299901112, 76, -9837411355, 1464548796, 403290}; + float ivals9[NUM] = { + 1.628561f, 2.998057f, 0.082604f, 0.0f, 12.330798f, + -1.350443f, 0.437885f, 0.017387f, 0.474454f, -0.718838f, + 98.150388f, 0.732236f, 0.519963f, -0.332644f, 0.648420f, + 0.578913f, -0.853190f, -910.141650f, 110.037210f, 0.434222f, + -0.343777f, 0.346011f, 0.767590f, + }; + auto work_group_sorter = [](uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) @@ -239,6 +247,19 @@ int main() { #endif }; + auto work_group_sorter8 = [](uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1f32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1f32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); @@ -538,6 +559,42 @@ int main() { q, ikeys, ivals8, work_group_sorter7); std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." 
<< std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; } { From 4f70308db15398176d02315385de04663a2f30a3 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 13 Sep 2024 16:26:47 +0800 Subject: [PATCH 46/71] use u32 copy for fp32 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 14cabfe25d14d..a9d18d7aca5ec 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1126,13 +1126,15 @@ void WG_JS_D(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); } DEVICE_EXTERN_C_INLINE void WG_JS_D(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); + 
merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); } // uint32_t as key type @@ -1187,13 +1189,15 @@ void WG_JS_D(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); } DEVICE_EXTERN_C_INLINE void WG_JS_D(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); } // uint64_t as key type @@ -1248,13 +1252,15 @@ void WG_JS_D(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); } DEVICE_EXTERN_C_INLINE void WG_JS_D(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, uint8_t *scratch) { - merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); } DEVICE_EXTERN_C_INLINE From 96f7ba7d3a96e46b96c85441bb29034e1d325462 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Sat, 14 Sep 2024 10:22:19 +0800 Subject: [PATCH 47/71] add p1u16_p1i8/16/32/64 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 56 ++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index a9d18d7aca5ec..908347050860a 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1087,6 +1087,20 @@ void 
WG_JS_D(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch) { @@ -1099,6 +1113,20 @@ void WG_JS_D(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { @@ -1111,6 +1139,20 @@ void WG_JS_D(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, 
scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch) { @@ -1123,6 +1165,20 @@ void WG_JS_D(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch) { From 77d9856c2365a3656c9dccd992ed8b0841246d90 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 19 Sep 2024 11:47:05 +0800 Subject: [PATCH 48/71] add p1 signed integer key --- libdevice/fallback-gsort.cpp | 588 +++++++++++++++++++++++++++++++++++ 1 file changed, 588 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 908347050860a..44544a5163f2c 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1074,6 +1074,125 @@ void WG_JS_D(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, std::greater_equal{}); } +// int8_t as key +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t 
*scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_JS_A(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + // uint16_t as key type DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, @@ -1193,6 +1312,125 @@ void WG_JS_D(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, std::greater_equal{}); } +// int16_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t 
*scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + // uint32_t as key type DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, @@ -1206,6 +1444,20 @@ void WG_JS_D(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + 
DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch) { @@ -1218,6 +1470,20 @@ void WG_JS_D(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { @@ -1230,6 +1496,20 @@ void WG_JS_D(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch) { @@ -1242,6 +1522,20 @@ void WG_JS_D(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE 
+void WG_JS_D(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, uint8_t *scratch) { @@ -1256,6 +1550,125 @@ void WG_JS_D(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, std::greater_equal{}); } +// int32_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t 
n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + // uint64_t as key type DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, @@ -1269,6 +1682,20 @@ void WG_JS_D(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch) { @@ -1281,6 +1708,20 @@ void WG_JS_D(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { @@ -1293,6 +1734,20 @@ void WG_JS_D(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void 
WG_JS_A(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch) { @@ -1305,6 +1760,20 @@ void WG_JS_D(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_JS_A(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, uint8_t *scratch) { @@ -1319,6 +1788,125 @@ void WG_JS_D(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, std::greater_equal{}); } +// int64_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, 
uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_PS_CA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { From cfd749d33be22e37f32653d4df0f50db4b22a841 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 19 Sep 2024 17:22:30 +0800 Subject: [PATCH 49/71] add tests for p1i8_p1* Signed-off-by: jinge90 --- .../group_joint_KV_sort_p1p1_p1.hpp | 48 +++ .../workgroup_joint_KV_sort_p1p1_p1.cpp | 332 ++++++++++++++++++ 2 files changed, 380 insertions(+) diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp index 40f99b861dfef..2929ff2976c9f 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp @@ -73,6 +73,54 @@ 
__DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_joint_sort_descending_p1u8_p1f32_u32_p1i8( uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_joint_sort_ascending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp index 2e6b36f2cc554..4b1d6a50f2c73 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -93,6 +93,8 @@ int main() { constexpr static int NUM = 23; uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, 36, 2, 111, 91, 88, 2, 51, 91, 81, 122, 22}; + int8_t ikeys1[NUM] = {0, 1, -2, 1, 122, -123, 99, -91, 9, 12, 13, 46, + 13, 13, 9, 5, 77, 81, -100, 35, -64, 22, 23}; uint8_t ivals[NUM] = {99, 32, 1, 2, 67, 123, 253, 35, 111, 77, 165, 145, 254, 11, 161, 96, 165, 100, 64, 90, 255, 147, 135}; @@ -260,6 +262,336 @@ int main() { #endif }; + auto work_group_sorter9 = [](int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter10 = [](int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1u16_u32_p1i8( + 
keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter11 = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter12 = [](int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter13 = [](int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter14 = [](int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." 
+ << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout + << "KV joint sort (Key: int8_t, Val: uint16_t) pass." 
+ << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout + << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout + << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout + << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout + << "KV joint sort (Key: int8_t, Val: uint32_t) pass." 
+ << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout + << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout + << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout + << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout + << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout + << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout + << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout + << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout << "KV joint sort (Key: int8_t, Val: uint64_t) pass." 
+ << std::endl; + test_work_group_KV_joint_sort( q, ikeys, ivals, work_group_sorter); From bdd19f2a90d19308666805a39c6cfbaee0b838fe Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 23 Sep 2024 17:28:22 +0800 Subject: [PATCH 50/71] add wg_private_sort for key-value Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 133 ++++++- .../group_joint_KV_sort_p1p1_p1.hpp | 3 +- .../group_private_KV_sort_p1p1_p1.hpp | 75 +++- .../DeviceLib/group_sort/group_sort.hpp | 8 + .../workgroup_joint_KV_sort_p1p1_p1.cpp | 34 +- .../workgroup_private_KV_sort_p1p1_p1.cpp | 346 +++++++++++++++++- 6 files changed, 560 insertions(+), 39 deletions(-) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 44544a5163f2c..a4140133f49ff 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -1907,11 +1907,61 @@ void WG_JS_D(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, std::greater_equal{}); } +// Work group private sorting algorithms. 
DEVICE_EXTERN_C_INLINE -void WG_PS_CA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, - uint8_t *scratch) { +void WG_PS_CA(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { private_merge_sort_key_value_close(keys, vals, n, scratch, - std::less_equal{}); + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); } DEVICE_EXTERN_C_INLINE @@ -1921,4 +1971,81 @@ void WG_PS_CA(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t 
*vals, uint32_t n, std::less_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + #endif // __SPIR__ || __SPIRV__ diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp index 2929ff2976c9f..5bc5dc0e9f2b3 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp @@ -1,5 +1,4 @@ -#pragma once -#include +#include "group_sort.hpp" __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u8_u32_p1i8( diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index c4f5893950a87..c2351b29b774a 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -1,20 +1,69 @@ -#pragma once -#include +#include "group_sort.hpp" -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL extern "C" void +__DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); -SYCL_EXTERNAL extern "C" void +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" 
void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); -#else -extern "C" void -__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( - uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) {} -extern "C" void -__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( - uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) {} -#endif 
+__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1f32_u32_p1i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp new file mode 100644 index 0000000000000..e07ab7ac523b0 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp @@ -0,0 +1,8 @@ +#pragma once +#include + +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#define __DEVICE_CODE 1 +#else +#define __DEVICE_CODE 0 +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp index 4b1d6a50f2c73..8063efa465436 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -147,7 +147,7 @@ int main() { auto work_group_sorter = [](uint8_t *keys, uint8_t *vals, uint32_t 
n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1u8_u32_p1i8( keys, vals, n, scratch); @@ -160,7 +160,7 @@ int main() { auto work_group_sorter1 = [](uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1i8_u32_p1i8( keys, vals, n, scratch); @@ -173,7 +173,7 @@ int main() { auto work_group_sorter2 = [](uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1u16_u32_p1i8( keys, vals, n, scratch); @@ -186,7 +186,7 @@ int main() { auto work_group_sorter3 = [](uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1i16_u32_p1i8( keys, vals, n, scratch); @@ -199,7 +199,7 @@ int main() { auto work_group_sorter4 = [](uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1u32_u32_p1i8( keys, vals, n, scratch); @@ -212,7 +212,7 @@ int main() { auto work_group_sorter5 = [](uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1i32_u32_p1i8( keys, vals, n, scratch); @@ -225,7 +225,7 @@ int main() { auto work_group_sorter6 = [](uint8_t *keys, uint64_t *vals, uint32_t 
n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1u64_u32_p1i8( keys, vals, n, scratch); @@ -238,7 +238,7 @@ int main() { auto work_group_sorter7 = [](uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1i64_u32_p1i8( keys, vals, n, scratch); @@ -251,7 +251,7 @@ int main() { auto work_group_sorter8 = [](uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u8_p1f32_u32_p1i8( keys, vals, n, scratch); @@ -264,7 +264,7 @@ int main() { auto work_group_sorter9 = [](int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1i8_p1u8_u32_p1i8( keys, vals, n, scratch); @@ -277,7 +277,7 @@ int main() { auto work_group_sorter10 = [](int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1i8_p1u16_u32_p1i8( keys, vals, n, scratch); @@ -290,7 +290,7 @@ int main() { auto work_group_sorter11 = [](int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1i8_p1u32_u32_p1i8( keys, vals, n, scratch); @@ -303,7 +303,7 @@ int main() { auto work_group_sorter12 = [](int8_t *keys, uint64_t *vals, uint32_t n, 
uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1i8_p1u64_u32_p1i8( keys, vals, n, scratch); @@ -316,7 +316,7 @@ int main() { auto work_group_sorter13 = [](int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1i8_p1i8_u32_p1i8( keys, vals, n, scratch); @@ -329,7 +329,7 @@ int main() { auto work_group_sorter14 = [](int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1i8_p1i16_u32_p1i8( keys, vals, n, scratch); @@ -938,7 +938,7 @@ int main() { 761, 96, 765, 10000, 6364, 90, 525}; auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); @@ -969,7 +969,7 @@ int main() { auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp index ec92477842da0..f4319d3a91a05 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp @@ -48,6 +48,11 @@ void 
test_work_group_KV_private_sort(sycl::queue &q, KeyT input_keys[NUM], #endif std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); { buffer ikeys_buf(input_keys, NUM); @@ -83,6 +88,11 @@ void test_work_group_KV_private_sort(sycl::queue &q, KeyT input_keys[NUM], }).wait(); } + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + sycl::free(scratch_ptr, q); bool fails = false; for (size_t idx = 0; idx < NUM; ++idx) { @@ -111,12 +121,14 @@ int main() { 9, 4324, 9123, 0, 1232, 777, 555, 314159}; auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { +#if __DEVICE_CODE #ifdef DES __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif #endif }; test_work_group_KV_private_sort( - q, ikeys, ivals, work_group_sorter); - std::cout << "KV private sort p1u8_p1u32_u32_p1i8 pass." 
<< std::endl; + + auto work_group_sorter1 = [](uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter2 = [](uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if 
__DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." 
<< std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; } } From 287498cc37cb75d33755cfb2226b86bb7cd368d9 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 24 Sep 2024 12:30:00 +0800 Subject: [PATCH 51/71] add test for u8_u64 Signed-off-by: jinge90 --- .../group_private_KV_sort_p1p1_p1.hpp | 8 ++ .../workgroup_private_KV_sort_p1p1_p1.cpp | 133 ++++++++++++++++++ 2 files changed, 141 insertions(+) diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index c2351b29b774a..39af6ff52afd0 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -64,6 +64,14 @@ __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p1i8( uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1f32_u32_p1i8( uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); diff --git 
a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp index f4319d3a91a05..eef870cdebc50 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp @@ -183,6 +183,39 @@ int main() { 1098974670, 56900257, 876775101, -1496897817, 1172877939, 1528916082, 559152364, 749878571, 2071902702, -430851798}; + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 
1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + auto work_group_sorter = [](uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { #if __DEVICE_CODE @@ -274,6 +307,32 @@ int main() { #endif }; + auto work_group_sorter7 = [](uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter8 = [](uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + constexpr static int NUM1 = 32; test_work_group_KV_private_sort( @@ -488,5 +547,79 @@ int main() { q, ikeys, ivals6, work_group_sorter6); std::cout << "KV private sort NUM = " << NUM5 << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; } } From 343f56dd3c28dd9d7ca1a1853d0817f1645cbf53 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 26 Sep 2024 11:57:15 +0800 Subject: [PATCH 52/71] add KV sort for i8 key Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 126 ++++ .../group_private_KV_sort_p1p1_p1.hpp | 68 ++ .../workgroup_private_KV_sort_i8.cpp | 597 ++++++++++++++++++ ...1.cpp => workgroup_private_KV_sort_u8.cpp} | 0 4 files changed, 791 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp rename sycl/test-e2e/DeviceLib/group_sort/{workgroup_private_KV_sort_p1p1_p1.cpp => workgroup_private_KV_sort_u8.cpp} (100%) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index a4140133f49ff..093704129bed0 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -2034,6 +2034,132 @@ void WG_PS_CD(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_PS_CA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index 39af6ff52afd0..bad5753d36685 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -75,3 +75,71 @@ __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p1i8 __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1f32_u32_p1i8( uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i32_u32_p1i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i32_u32_p1i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); 
+ +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i64_u32_p1i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p1i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1f32_u32_p1i8( + int8_t *keys, float *vals, uint32_t n, uint8_t *scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp new file mode 100644 index 0000000000000..316cb7eb8d8ff --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp @@ -0,0 +1,597 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +template +void test_work_group_KV_private_sort(sycl::queue &q, KeyT input_keys[NUM], + ValT input_vals[NUM], SortHelper gsh) { + static_assert((NUM % WG_SZ == 0), + "Input number must be divisible by work group size!"); + size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + uint8_t *scratch_ptr = + (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); + const static 
size_t wg_size = WG_SZ; + constexpr size_t num_per_work_item = NUM / WG_SZ; + KeyT output_keys[NUM]; + ValT output_vals[NUM]; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + buffer okeys_buf(output_keys, NUM); + buffer ovals_buf(output_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + accessor okeys_acc{okeys_buf, h}; + accessor ovals_acc{ovals_buf, h}; + sycl::stream os(1024, 128, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + KeyT pkeys[num_per_work_item]; + ValT pvals[num_per_work_item]; + // copy from global input to fix-size private array. 
+ for (size_t idx = 0; idx < num_per_work_item; ++idx) { + pkeys[idx] = + ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; + pvals[idx] = + ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; + } + + gsh(pkeys, pvals, num_per_work_item, scratch_ptr); + + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pkeys[idx]; + ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pvals[idx]; + } + }); + }).wait(); + } + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + + sycl::free(scratch_ptr, q); + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { + if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || + (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 35; + int8_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 57541, 22108, 59535, 31880, 
7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 10265936863337610975ULL, + 
17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter1 = [](int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter2 = [](int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter7 = [](int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif 
+#endif + }; + + auto work_group_sorter8 = [](int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp similarity index 100% rename from sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p1p1_p1.cpp rename to sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp From cff0a6b845880da6d3f0634d1d33cc877121252a Mon Sep 17 00:00:00 2001 From: jinge90 Date: Sun, 29 Sep 2024 16:31:44 +0800 Subject: [PATCH 53/71] add work group private sort for i32/u32/i64/u64 key type Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 754 +++++++++++++++++++++++++++++++++++ 1 file changed, 754 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 093704129bed0..f4de0eb50e465 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -2160,6 +2160,319 @@ void WG_PS_CD(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t 
*vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i32_u32_p1i8)(uint16_t 
*keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); 
+} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, 
vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + DEVICE_EXTERN_C_INLINE void WG_PS_CA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch) { @@ -2174,4 +2487,445 @@ void WG_PS_CD(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t 
*scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_CD(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE 
+void WG_PS_CA(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, 
scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { 
+ private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { 
+ private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + #endif // __SPIR__ || __SPIRV__ From f43572a27d6009efd11d7467028263be42165d2b Mon Sep 
17 00:00:00 2001 From: jinge90 Date: Mon, 30 Sep 2024 15:51:14 +0800 Subject: [PATCH 54/71] add tests for i16 key type Signed-off-by: jinge90 --- .../group_private_KV_sort_p1p1_p1.hpp | 144 +++++ .../workgroup_private_KV_sort_i16.cpp | 604 ++++++++++++++++++ .../workgroup_private_KV_sort_u16.cpp | 597 +++++++++++++++++ 3 files changed, 1345 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index bad5753d36685..eec351cb1a376 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -143,3 +143,147 @@ __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p1i8 __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1f32_u32_p1i8( int8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u8_u32_p1i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u8_u32_p1i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i8_u32_p1i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i8_u32_p1i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u16_u32_p1i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u16_u32_p1i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i16_u32_p1i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i16_u32_p1i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p1i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p1i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i32_u32_p1i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i32_u32_p1i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u64_u32_p1i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u64_u32_p1i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i64_u32_p1i8( + uint16_t *keys, int64_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i64_u32_p1i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1f32_u32_p1i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1f32_u32_p1i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u8_u32_p1i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u8_u32_p1i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i8_u32_p1i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i8_u32_p1i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u16_u32_p1i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u16_u32_p1i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i16_u32_p1i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i16_u32_p1i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p1i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p1i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i32_u32_p1i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i32_u32_p1i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u64_u32_p1i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u64_u32_p1i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i64_u32_p1i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i64_u32_p1i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1f32_u32_p1i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1f32_u32_p1i8( + int16_t *keys, float *vals, uint32_t 
n, uint8_t *scratch);
diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp
new file mode 100644
index 0000000000000..38800a2e8c2a0
--- /dev/null
+++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp
@@ -0,0 +1,604 @@
+// RUN: %{build} -o %t.out
+// RUN: %{run} %t.out
+
+// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out
+// RUN: %{run} %t.out
+
+// UNSUPPORTED: cuda || hip || cpu
+
+#include "group_private_KV_sort_p1p1_p1.hpp"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace sycl;
+
+// Runs the key/value sort helper `gsh` on NUM pairs inside one work-group of
+// WG_SZ work-items and checks the output against a host-side
+// std::stable_sort of the same pairs.
+//
+//   q     - queue used for the scratch allocation and the kernel launch.
+//   keys  - NUM input keys.
+//   vals  - NUM input values; vals[i] travels with keys[i].
+//   gsh   - callable invoked by every work-item as
+//           gsh(KeyT *keys, ValT *vals, uint32_t n, uint8_t *scratch) on its
+//           private slice of NUM / WG_SZ pairs.
+//
+// When DES is defined the reference is ordered by descending key, otherwise
+// by ascending key. The test aborts through assert() on the first mismatch.
+//
+// NOTE(review): the template parameter list and the template arguments of
+// std::vector, std::tuple and buffer were missing from the text this block
+// was taken from; they are reconstructed from how the names are used below.
+// The relative order of WG_SZ and NUM is an assumption - confirm it against
+// the explicit template arguments at the call sites.
+template <typename KeyT, typename ValT, size_t WG_SZ, size_t NUM,
+          typename SortHelper>
+void test_work_group_KV_private_sort(sycl::queue &q, KeyT keys[NUM],
+                                     ValT vals[NUM], SortHelper gsh) {
+  static_assert((NUM % WG_SZ == 0),
+                "Input number must be divisible by work group size!");
+
+  constexpr size_t wg_size = WG_SZ;
+  constexpr size_t num_per_work_item = NUM / WG_SZ;
+
+  // Device scratch handed to the sort helper: twice the key+value payload
+  // plus one unit of the stricter alignment. How the helper lays it out is
+  // not visible from this file.
+  const size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) +
+                              std::max(alignof(KeyT), alignof(ValT));
+  uint8_t *scratch_ptr =
+      (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q);
+  // A failed USM allocation yields nullptr; stop here instead of letting the
+  // kernel dereference it.
+  assert(scratch_ptr != nullptr && "device scratch allocation failed");
+
+  // Host reference. The comparator looks at the key only and the sort is
+  // stable, so pairs with equal keys keep their input order.
+  std::vector<std::tuple<KeyT, ValT>> sorted_vec;
+  sorted_vec.reserve(NUM);
+  for (size_t idx = 0; idx < NUM; ++idx)
+    sorted_vec.push_back(std::make_tuple(keys[idx], vals[idx]));
+#ifdef DES
+  auto kv_tuple_comp = [](const std::tuple<KeyT, ValT> &t1,
+                          const std::tuple<KeyT, ValT> &t2) {
+    return std::get<0>(t1) > std::get<0>(t2);
+  };
+#else
+  auto kv_tuple_comp = [](const std::tuple<KeyT, ValT> &t1,
+                          const std::tuple<KeyT, ValT> &t2) {
+    return std::get<0>(t1) < std::get<0>(t2);
+  };
+#endif
+  std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp);
+
+  KeyT output_keys[NUM];
+  ValT output_vals[NUM];
+  // Global size == local size: exactly one work-group is launched.
+  nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size)));
+  {
+    buffer<KeyT, 1> ikeys_buf(keys, NUM);
+    buffer<ValT, 1> ivals_buf(vals, NUM);
+    buffer<KeyT, 1> okeys_buf(output_keys, NUM);
+    buffer<ValT, 1> ovals_buf(output_vals, NUM);
+    q.submit([&](auto &h) {
+       accessor ikeys_acc{ikeys_buf, h};
+       accessor ivals_acc{ivals_buf, h};
+       accessor okeys_acc{okeys_buf, h};
+       accessor ovals_acc{ovals_buf, h};
+       h.parallel_for(num_items, [=](nd_item<1> i) {
+         // Index of the first pair owned by this work-item.
+         const size_t base = i.get_local_linear_id() * num_per_work_item;
+         KeyT pkeys[num_per_work_item];
+         ValT pvals[num_per_work_item];
+         // Copy this work-item's slice of the global input into fixed-size
+         // private arrays.
+         for (size_t idx = 0; idx < num_per_work_item; ++idx) {
+           pkeys[idx] = ikeys_acc[base + idx];
+           pvals[idx] = ivals_acc[base + idx];
+         }
+
+         gsh(pkeys, pvals, num_per_work_item, scratch_ptr);
+
+         // Write the private arrays back to the same slice of the output.
+         for (size_t idx = 0; idx < num_per_work_item; ++idx) {
+           okeys_acc[base + idx] = pkeys[idx];
+           ovals_acc[base + idx] = pvals[idx];
+         }
+       });
+     }).wait();
+  } // Buffers go out of scope here, so output_keys/output_vals are valid.
+
+  sycl::free(scratch_ptr, q);
+
+  // Keys and values must both match the reference; report the first index
+  // that differs.
+  bool fails = false;
+  for (size_t idx = 0; idx < NUM; ++idx) {
+    if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) ||
+        (output_vals[idx] != std::get<1>(sorted_vec[idx]))) {
+      std::cout << "idx: " << idx << std::endl;
+      fails = true;
+      break;
+    }
+  }
+  assert(!fails);
+}
+
+int main() {
+  queue q;
+  {
+    constexpr static int NUM = 35;
+    int16_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121,
+                          -77, 125, 23, 36, -2, -111, 91, 88, -2,
+                          51, -23, -81, 83, 31, 42, 2, 1, -99,
+                          124, 12, 0, -81, 17, 15, 101, 44};
+
uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 
2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter1 = [](int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter2 = [](int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p1i8( + keys, vals, n, 
scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter7 = [](int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter8 = [](int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; +#if 0 + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; +#endif + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl;
+  }
+}
diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp
new file mode 100644
index 0000000000000..c1f94ebbe29aa
--- /dev/null
+++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp
@@ -0,0 +1,597 @@
+// RUN: %{build} -o %t.out
+// RUN: %{run} %t.out
+
+// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out
+// RUN: %{run} %t.out
+
+// UNSUPPORTED: cuda || hip || cpu
+
+#include "group_private_KV_sort_p1p1_p1.hpp"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace sycl;
+
+// Runs the key/value sort helper `gsh` on NUM pairs inside one work-group of
+// WG_SZ work-items and checks the output against a host-side
+// std::stable_sort of the same pairs.
+//
+//   q           - queue used for the scratch allocation and the kernel launch.
+//   input_keys  - NUM input keys.
+//   input_vals  - NUM input values; input_vals[i] travels with input_keys[i].
+//   gsh         - callable invoked by every work-item as
+//                 gsh(KeyT *keys, ValT *vals, uint32_t n, uint8_t *scratch)
+//                 on its private slice of NUM / WG_SZ pairs.
+//
+// When DES is defined the reference is ordered by descending key, otherwise
+// by ascending key. The test aborts through assert() on the first mismatch.
+//
+// NOTE(review): the template parameter list and the template arguments of
+// std::vector, std::tuple and buffer were missing from the text this block
+// was taken from; they are reconstructed from how the names are used below.
+// The relative order of WG_SZ and NUM is an assumption - confirm it against
+// the explicit template arguments at the call sites.
+template <typename KeyT, typename ValT, size_t WG_SZ, size_t NUM,
+          typename SortHelper>
+void test_work_group_KV_private_sort(sycl::queue &q, KeyT input_keys[NUM],
+                                     ValT input_vals[NUM], SortHelper gsh) {
+  static_assert((NUM % WG_SZ == 0),
+                "Input number must be divisible by work group size!");
+
+  constexpr size_t wg_size = WG_SZ;
+  constexpr size_t num_per_work_item = NUM / WG_SZ;
+
+  // Device scratch handed to the sort helper: twice the key+value payload
+  // plus one unit of the stricter alignment. How the helper lays it out is
+  // not visible from this file.
+  const size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) +
+                              std::max(alignof(KeyT), alignof(ValT));
+  uint8_t *scratch_ptr =
+      (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q);
+  // A failed USM allocation yields nullptr; stop here instead of letting the
+  // kernel dereference it.
+  assert(scratch_ptr != nullptr && "device scratch allocation failed");
+
+  // Host reference. The comparator looks at the key only and the sort is
+  // stable, so pairs with equal keys keep their input order.
+  std::vector<std::tuple<KeyT, ValT>> sorted_vec;
+  sorted_vec.reserve(NUM);
+  for (size_t idx = 0; idx < NUM; ++idx)
+    sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx]));
+#ifdef DES
+  auto kv_tuple_comp = [](const std::tuple<KeyT, ValT> &t1,
+                          const std::tuple<KeyT, ValT> &t2) {
+    return std::get<0>(t1) > std::get<0>(t2);
+  };
+#else
+  auto kv_tuple_comp = [](const std::tuple<KeyT, ValT> &t1,
+                          const std::tuple<KeyT, ValT> &t2) {
+    return std::get<0>(t1) < std::get<0>(t2);
+  };
+#endif
+  std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp);
+
+  KeyT output_keys[NUM];
+  ValT output_vals[NUM];
+  // Global size == local size: exactly one work-group is launched.
+  nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size)));
+  {
+    buffer<KeyT, 1> ikeys_buf(input_keys, NUM);
+    buffer<ValT, 1> ivals_buf(input_vals, NUM);
+    buffer<KeyT, 1> okeys_buf(output_keys, NUM);
+    buffer<ValT, 1> ovals_buf(output_vals, NUM);
+    q.submit([&](auto &h) {
+       accessor ikeys_acc{ikeys_buf, h};
+       accessor ivals_acc{ivals_buf, h};
+       accessor okeys_acc{okeys_buf, h};
+       accessor ovals_acc{ovals_buf, h};
+       h.parallel_for(num_items, [=](nd_item<1> i) {
+         // Index of the first pair owned by this work-item.
+         const size_t base = i.get_local_linear_id() * num_per_work_item;
+         KeyT pkeys[num_per_work_item];
+         ValT pvals[num_per_work_item];
+         // Copy this work-item's slice of the global input into fixed-size
+         // private arrays.
+         for (size_t idx = 0; idx < num_per_work_item; ++idx) {
+           pkeys[idx] = ikeys_acc[base + idx];
+           pvals[idx] = ivals_acc[base + idx];
+         }
+
+         gsh(pkeys, pvals, num_per_work_item, scratch_ptr);
+
+         // Write the private arrays back to the same slice of the output.
+         for (size_t idx = 0; idx < num_per_work_item; ++idx) {
+           okeys_acc[base + idx] = pkeys[idx];
+           ovals_acc[base + idx] = pvals[idx];
+         }
+       });
+     }).wait();
+  } // Buffers go out of scope here, so output_keys/output_vals are valid.
+
+  sycl::free(scratch_ptr, q);
+
+  // Keys and values must both match the reference; report the first index
+  // that differs.
+  bool fails = false;
+  for (size_t idx = 0; idx < NUM; ++idx) {
+    if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) ||
+        (output_vals[idx] != std::get<1>(sorted_vec[idx]))) {
+      std::cout << "idx: " << idx << std::endl;
+      fails = true;
+      break;
+    }
+  }
+  assert(!fails);
+}
+
+int main() {
+  queue q;
+
+  {
+    constexpr static int NUM = 35;
+    uint16_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121,
+                           77, 125, 23, 36, 2, 111, 91, 88, 2,
+                           51, 213, 181, 183, 31, 142, 216, 1, 199,
+                           124, 12, 0, 181, 17, 15, 101, 44};
+    uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453,
+                           435, 91111, 777, 165, 145, 2456, 88811,
+                           761, 96, 765, 10000, 6364, 90, 525,
+                           882, 1, 2423, 9, 4324, 9123, 0,
+                           1232, 777, 555, 314159, 905, 9831, 84341};
+    uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14,
+
24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 
17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter1 = [](uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + 
__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter2 = [](uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i32_u32_p1i8( + keys, vals, n, 
scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter7 = [](uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter8 = [](uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} From 396d310ede65c6b10b1bdf76d98842fdc04cc31a Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 30 Sep 2024 17:26:02 +0800 Subject: [PATCH 55/71] add tests for i16,u16 key type Signed-off-by: jinge90 --- .../group_sort/workgroup_private_KV_sort_i16.cpp | 15 ++++++--------- .../group_sort/workgroup_private_KV_sort_i8.cpp | 9 +++++++-- .../group_sort/workgroup_private_KV_sort_u16.cpp | 9 +++++++-- .../group_sort/workgroup_private_KV_sort_u8.cpp | 9 +++++++-- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp index 38800a2e8c2a0..40ce1fececdb4 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp @@ -20,8 +20,8 @@ using namespace sycl; template -void test_work_group_KV_private_sort(sycl::queue &q, KeyT keys[NUM], - ValT vals[NUM], SortHelper gsh) { +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { static_assert((NUM % WG_SZ == 0), "Input number must be divisible by work group size!"); @@ -29,9 +29,6 @@ void test_work_group_KV_private_sort(sycl::queue &q, KeyT keys[NUM], ValT input_vals[NUM]; memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); - for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << input_keys[idx] << ", val: " << input_vals[idx] << std::endl; - } size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + std::max(alignof(KeyT), alignof(ValT)); uint8_t *scratch_ptr = @@ -56,10 +53,10 @@ void test_work_group_KV_private_sort(sycl::queue &q, KeyT keys[NUM], #endif std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); - for (size_t idx = 0; idx < NUM; ++idx) { + /* for (size_t idx = 0; idx < NUM; ++idx) { std::cout << "key: " << 
std::get<0>(sorted_vec[idx]) << " val: " << (int64_t)std::get<1>(sorted_vec[idx]) << std::endl; - } + }*/ nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); { @@ -311,7 +308,7 @@ int main() { #endif #endif }; -#if 0 + constexpr static int NUM1 = 32; test_work_group_KV_private_sort( @@ -575,7 +572,7 @@ int main() { q, ikeys, ivals7, work_group_sorter7); std::cout << "KV private sort NUM = " << NUM6 << ", WG = 30 pass." << std::endl; -#endif + constexpr static int NUM7 = 21; test_work_group_KV_private_sort( diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp index 316cb7eb8d8ff..f4ab477ad8ce9 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp @@ -20,10 +20,15 @@ using namespace sycl; template -void test_work_group_KV_private_sort(sycl::queue &q, KeyT input_keys[NUM], - ValT input_vals[NUM], SortHelper gsh) { +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { static_assert((NUM % WG_SZ == 0), "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + std::max(alignof(KeyT), alignof(ValT)); uint8_t *scratch_ptr = diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp index c1f94ebbe29aa..cfd16f9c45a95 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp @@ -20,10 +20,15 @@ using namespace sycl; template -void test_work_group_KV_private_sort(sycl::queue &q, KeyT 
input_keys[NUM], - ValT input_vals[NUM], SortHelper gsh) { +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { static_assert((NUM % WG_SZ == 0), "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + std::max(alignof(KeyT), alignof(ValT)); uint8_t *scratch_ptr = diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp index eef870cdebc50..dedae328797f4 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp @@ -20,10 +20,15 @@ using namespace sycl; template -void test_work_group_KV_private_sort(sycl::queue &q, KeyT input_keys[NUM], - ValT input_vals[NUM], SortHelper gsh) { +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { static_assert((NUM % WG_SZ == 0), "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + std::max(alignof(KeyT), alignof(ValT)); uint8_t *scratch_ptr = From 5e4f50925d83a7eecbceb044e96e3349def43817 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 9 Oct 2024 10:46:34 +0800 Subject: [PATCH 56/71] add tests for key type u32 Signed-off-by: jinge90 --- .../group_private_KV_sort_p1p1_p1.hpp | 64 ++ .../workgroup_private_KV_sort_u32.cpp | 688 ++++++++++++++++++ 2 files changed, 752 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp diff 
--git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index eec351cb1a376..72ecf8a3effde 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -287,3 +287,67 @@ __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1f32_u32_p1i8 __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_descending_p1i16_p1f32_u32_p1i8( int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i8_u32_p1i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i8_u32_p1i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u16_u32_p1i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u16_u32_p1i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i16_u32_p1i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i16_u32_p1i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i32_u32_p1i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i32_u32_p1i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u64_u32_p1i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u64_u32_p1i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p1i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p1i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp new file mode 100644 index 0000000000000..40681e088288e --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp @@ -0,0 +1,688 @@ +// RUN: %{build} -o 
%t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +template +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { + static_assert((NUM % WG_SZ == 0), + "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); + size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + uint8_t *scratch_ptr = + (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); + const static size_t wg_size = WG_SZ; + constexpr size_t num_per_work_item = NUM / WG_SZ; + KeyT output_keys[NUM]; + ValT output_vals[NUM]; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + buffer okeys_buf(output_keys, NUM); + buffer ovals_buf(output_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor 
ivals_acc{ivals_buf, h}; + accessor okeys_acc{okeys_buf, h}; + accessor ovals_acc{ovals_buf, h}; + sycl::stream os(1024, 128, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + KeyT pkeys[num_per_work_item]; + ValT pvals[num_per_work_item]; + // copy from global input to fix-size private array. + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + pkeys[idx] = + ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; + pvals[idx] = + ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; + } + + gsh(pkeys, pvals, num_per_work_item, scratch_ptr); + + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pkeys[idx]; + ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pvals[idx]; + } + }); + }).wait(); + } + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + + sycl::free(scratch_ptr, q); + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { + if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || + (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 40; + uint32_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, 2, 6662451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, 7712423, 0, 0, 181, 17, + 15, 101, 44, 103934, 1, 11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t 
ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 
15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter1 = [](uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + 
__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter2 = [](uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i32_u32_p1i8( + keys, vals, n, 
scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter7 = [](uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter8 = [](uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." 
<< std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." 
<< std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} From bef0bcfcfc5151d688cc65f116634beac87b95ab Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 9 Oct 2024 11:14:32 +0800 Subject: [PATCH 57/71] add test for key type i32 Signed-off-by: jinge90 --- .../group_private_KV_sort_p1p1_p1.hpp | 64 ++ .../workgroup_private_KV_sort_i32.cpp | 688 ++++++++++++++++++ 2 files changed, 752 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index 72ecf8a3effde..3d662aa0c152f 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -351,3 +351,67 @@ __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p1i8 __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p1i8( uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u8_u32_p1i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u8_u32_p1i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i8_u32_p1i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i8_u32_p1i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + 
+__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u16_u32_p1i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u16_u32_p1i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i16_u32_p1i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i16_u32_p1i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p1i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p1i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i32_u32_p1i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i32_u32_p1i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u64_u32_p1i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u64_u32_p1i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p1i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p1i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp new file mode 100644 index 0000000000000..ce5cc2373861c --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp @@ -0,0 +1,688 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +template +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { + static_assert((NUM % WG_SZ == 0), + "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); + size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + uint8_t *scratch_ptr = + (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); + const static size_t wg_size = WG_SZ; + constexpr size_t num_per_work_item = NUM / WG_SZ; + KeyT output_keys[NUM]; + ValT output_vals[NUM]; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto 
kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + buffer okeys_buf(output_keys, NUM); + buffer ovals_buf(output_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + accessor okeys_acc{okeys_buf, h}; + accessor ovals_acc{ovals_buf, h}; + sycl::stream os(1024, 128, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + KeyT pkeys[num_per_work_item]; + ValT pvals[num_per_work_item]; + // copy from global input to fix-size private array. + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + pkeys[idx] = + ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; + pvals[idx] = + ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; + } + + gsh(pkeys, pvals, num_per_work_item, scratch_ptr); + + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pkeys[idx]; + ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pvals[idx]; + } + }); + }).wait(); + } + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + + sycl::free(scratch_ptr, q); + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { + if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || + (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + 
constexpr static int NUM = 40; + int32_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, -2, -62451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, -7712423, 0, 0, -181, 17, + 15, -101, 44, 103934, 1, 11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 
1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 
814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter1 = [](int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter2 = [](int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i16_u32_p1i8( + 
keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter7 = [](int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter8 = [](int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} From 35413eafcb239b19f16aacf95067327e3756cd08 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 9 Oct 2024 13:31:45 +0800 Subject: [PATCH 58/71] add tests for key type i64/u64 Signed-off-by: jinge90 --- .../group_private_KV_sort_p1p1_p1.hpp | 128 ++++ .../workgroup_private_KV_sort_i64.cpp | 688 ++++++++++++++++++ .../workgroup_private_KV_sort_u64.cpp | 688 ++++++++++++++++++ 3 files changed, 1504 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index 3d662aa0c152f..db8ac61c4bc2f 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -415,3 +415,131 @@ __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p1i8 __DPCPP_SYCL_EXTERNAL extern "C" void 
__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p1i8( int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u8_u32_p1i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u8_u32_p1i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i8_u32_p1i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i8_u32_p1i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u16_u32_p1i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u16_u32_p1i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i16_u32_p1i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i16_u32_p1i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p1i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p1i8( + uint64_t *keys, uint32_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i32_u32_p1i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i32_u32_p1i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u64_u32_p1i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u64_u32_p1i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i64_u32_p1i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i64_u32_p1i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u8_u32_p1i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u8_u32_p1i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i8_u32_p1i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i8_u32_p1i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u16_u32_p1i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u16_u32_p1i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i16_u32_p1i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i16_u32_p1i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p1i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p1i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i32_u32_p1i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i32_u32_p1i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u64_u32_p1i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u64_u32_p1i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p1i8( + int64_t *keys, int64_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p1i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp new file mode 100644 index 0000000000000..5ef1e653baa16 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp @@ -0,0 +1,688 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +template +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { + static_assert((NUM % WG_SZ == 0), + "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); + size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + uint8_t *scratch_ptr = + (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); + const static size_t wg_size = WG_SZ; + constexpr size_t num_per_work_item = NUM / WG_SZ; + KeyT output_keys[NUM]; + ValT output_vals[NUM]; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif 
+ std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + buffer okeys_buf(output_keys, NUM); + buffer ovals_buf(output_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + accessor okeys_acc{okeys_buf, h}; + accessor ovals_acc{ovals_buf, h}; + sycl::stream os(1024, 128, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + KeyT pkeys[num_per_work_item]; + ValT pvals[num_per_work_item]; + // copy from global input to fix-size private array. + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + pkeys[idx] = + ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; + pvals[idx] = + ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; + } + + gsh(pkeys, pvals, num_per_work_item, scratch_ptr); + + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pkeys[idx]; + ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pvals[idx]; + } + }); + }).wait(); + } + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + + sycl::free(scratch_ptr, q); + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { + if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || + (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 40; + int64_t ikeys[NUM] = { + 1, 11, -1, 9, 3, 100, -34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 
-881112, 2, 6662451, 213, 199181, -3183, 310910, 11242, + 3216, 1, -199, 7712423, 0, 0, 181, 17, + 15, -101, 44, 103934, 1, -11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 
56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 
45, + -1232143254}; + + auto work_group_sorter = [](int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter1 = [](int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter2 = [](int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t 
*scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter7 = [](int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter8 = [](int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp new file mode 100644 index 0000000000000..723a4b037f227 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp @@ -0,0 +1,688 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +template +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { + static_assert((NUM % WG_SZ == 0), + "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); + size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + uint8_t *scratch_ptr = + (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); + const static size_t wg_size = WG_SZ; + constexpr size_t 
num_per_work_item = NUM / WG_SZ; + KeyT output_keys[NUM]; + ValT output_vals[NUM]; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + buffer okeys_buf(output_keys, NUM); + buffer ovals_buf(output_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + accessor okeys_acc{okeys_buf, h}; + accessor ovals_acc{ovals_buf, h}; + sycl::stream os(1024, 128, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + KeyT pkeys[num_per_work_item]; + ValT pvals[num_per_work_item]; + // copy from global input to fix-size private array. 
+ for (size_t idx = 0; idx < num_per_work_item; ++idx) { + pkeys[idx] = + ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; + pvals[idx] = + ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; + } + + gsh(pkeys, pvals, num_per_work_item, scratch_ptr); + + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pkeys[idx]; + ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pvals[idx]; + } + }); + }).wait(); + } + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + + sycl::free(scratch_ptr, q); + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { + if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || + (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 40; + uint64_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, 2, 6662451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, 7712423, 0, 0, 181, 17, + 15, 101, 44, 103934, 1, 11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t 
ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 
17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter1 = [](uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter2 = [](uint64_t *keys, int8_t *vals, 
uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter7 = [](uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + 
__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter8 = [](uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} From 828be3b8ce1cfa4cb3d0b16ba61e1061177042e8 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 9 Oct 2024 17:07:26 +0800 Subject: [PATCH 59/71] apply group broadcast to KV private sort Signed-off-by: jinge90 --- libdevice/group_helper.hpp | 4 ++++ libdevice/sort_helper.hpp | 42 +++++++++++++++++++++----------------- libdevice/spirv_decls.hpp | 2 ++ 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/libdevice/group_helper.hpp b/libdevice/group_helper.hpp index 92fe256dd500f..934514e5326c2 100644 --- a/libdevice/group_helper.hpp +++ b/libdevice/group_helper.hpp @@ -30,4 +30,8 @@ static inline void group_barrier() { __spv::MemorySemanticsMask::WorkgroupMemory | __spv::MemorySemanticsMask::CrossWorkgroupMemory); } + +static inline uint64_t group_broadcast(uint64_t x) { + return __spirv_GroupBroadcast(__spv::Scope::Flag::Workgroup, x, 0); +} #endif // __SPIR__ || __SPIRV__ diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 13ca9b1e1c64f..9f8f04e07c520 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -296,34 +296,38 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, Compare comp) { const size_t local_idx = __get_wg_local_linear_id(); const size_t wg_size = __get_wg_local_range(); - KeyT *temp_key_beg = reinterpret_cast(scratch); - uint64_t temp_val_unaligned = - reinterpret_cast(scratch + 2 * wg_size * n * sizeof(KeyT)); - ValT *temp_val_beg = nullptr; - uint64_t temp1 = temp_val_unaligned % alignof(ValT); - if (temp1) - temp_val_beg = - reinterpret_cast(temp_val_unaligned + alignof(ValT) - temp1); - else - temp_val_beg = reinterpret_cast(temp_val_unaligned); + uint64_t temp_val_beg = 0, temp_key_beg = 0, internal_scratch = 0; + KeyT *keys_ptr = nullptr; + ValT *vals_ptr = nullptr; + uint8_t *scratch_ptr = nullptr; + + if (local_idx == 0) { + uint64_t temp_val_unaligned = + reinterpret_cast(scratch + 2 * wg_size * n * sizeof(KeyT)); + uint64_t 
temp1 = temp_val_unaligned % alignof(ValT); + temp_val_beg = (temp1 != 0) ? (temp_val_unaligned + alignof(ValT) - temp1) + : temp_val_unaligned; + temp_val_beg += sizeof(ValT) * n * wg_size; + temp_key_beg = reinterpret_cast(scratch); + internal_scratch = temp_key_beg + sizeof(KeyT) * n * wg_size; + } - uint8_t *internal_scratch = - reinterpret_cast(&temp_key_beg[n * wg_size]); - temp_val_beg = &temp_val_beg[n * wg_size]; + keys_ptr = reinterpret_cast(group_broadcast(temp_key_beg)); + vals_ptr = reinterpret_cast(group_broadcast(temp_val_beg)); + scratch_ptr = reinterpret_cast(group_broadcast(internal_scratch)); for (size_t i = 0; i < n; ++i) { - temp_key_beg[local_idx * n + i] = keys[i]; - temp_val_beg[local_idx * n + i] = vals[i]; + keys_ptr[local_idx * n + i] = keys[i]; + vals_ptr[local_idx * n + i] = vals[i]; } group_barrier(); - merge_sort_key_value(temp_key_beg, temp_val_beg, n * wg_size, - internal_scratch, comp); + merge_sort_key_value(keys_ptr, vals_ptr, n * wg_size, scratch_ptr, comp); for (size_t i = 0; i < n; ++i) { - keys[i] = temp_key_beg[local_idx * n + i]; - vals[i] = temp_val_beg[local_idx * n + i]; + keys[i] = keys_ptr[local_idx * n + i]; + vals[i] = vals_ptr[local_idx * n + i]; } } diff --git a/libdevice/spirv_decls.hpp b/libdevice/spirv_decls.hpp index e22fdb7f338f8..7e48f4104be0e 100644 --- a/libdevice/spirv_decls.hpp +++ b/libdevice/spirv_decls.hpp @@ -83,4 +83,6 @@ extern DEVICE_EXTERNAL void __spirv_AtomicStore(int *, __spv::Scope::Flag, __spv::MemorySemanticsMask::Flag, int); +extern DEVICE_EXTERNAL uint64_t __spirv_GroupBroadcast(__spv::Scope::Flag, + uint64_t, uint64_t); #endif // __SPIR__ || __SPIRV__ From 8a03ba51fd2607def7d71cf8d080e5af52b80049 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 11 Oct 2024 17:44:02 +0800 Subject: [PATCH 60/71] add private KV spread sort Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 1042 +++++++++++++++++ libdevice/sort_helper.hpp | 43 +- .../group_private_KV_sort_p1p1_p1.hpp | 661 ++++++++++- 
.../workgroup_private_KV_sort_i16.cpp | 211 ++-- .../workgroup_private_KV_sort_i32.cpp | 211 ++-- .../workgroup_private_KV_sort_i64.cpp | 211 ++-- .../workgroup_private_KV_sort_i8.cpp | 211 ++-- .../workgroup_private_KV_sort_u16.cpp | 211 ++-- .../workgroup_private_KV_sort_u32.cpp | 211 ++-- .../workgroup_private_KV_sort_u64.cpp | 211 ++-- .../workgroup_private_KV_sort_u8.cpp | 239 ++-- 11 files changed, 2605 insertions(+), 857 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index f4de0eb50e465..e7f176153456c 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -2928,4 +2928,1046 @@ void WG_PS_CD(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, n, scratch, std::greater_equal{}); } +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t 
*scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t 
*vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t 
*vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SD(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} 
+ +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + 
uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SA(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, 
reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t 
n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SA(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} 
+ +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + 
n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} #endif // __SPIR__ || __SPIRV__ diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index 9f8f04e07c520..f87adc09bf121 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -292,8 +292,10 @@ void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, // max(alignof(KeyT), alignof(ValT)) // The scrach memory alignment is max(alignof(KeyT), alignof(ValT)) template -void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n, - uint8_t *scratch, Compare comp) { +static void private_merge_sort_key_value_helper(KeyT *keys, 
ValT *vals, + size_t n, uint8_t *scratch, + Compare comp, KeyT **keys_back, + ValT **vals_back) { const size_t local_idx = __get_wg_local_linear_id(); const size_t wg_size = __get_wg_local_range(); uint64_t temp_val_beg = 0, temp_key_beg = 0, internal_scratch = 0; @@ -315,6 +317,8 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n, keys_ptr = reinterpret_cast(group_broadcast(temp_key_beg)); vals_ptr = reinterpret_cast(group_broadcast(temp_val_beg)); scratch_ptr = reinterpret_cast(group_broadcast(internal_scratch)); + *keys_back = keys_ptr; + *vals_back = vals_ptr; for (size_t i = 0; i < n; ++i) { keys_ptr[local_idx * n + i] = keys[i]; @@ -324,10 +328,41 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n, group_barrier(); merge_sort_key_value(keys_ptr, vals_ptr, n * wg_size, scratch_ptr, comp); +} + +template +void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n, + uint8_t *scratch, Compare comp) { + + KeyT *keys_back = nullptr; + ValT *vals_back = nullptr; + private_merge_sort_key_value_helper(keys, vals, n, scratch, comp, &keys_back, + &vals_back); + + + const size_t local_idx = __get_wg_local_linear_id(); + for (size_t i = 0; i < n; ++i) { + keys[i] = keys_back[local_idx * n + i]; + vals[i] = vals_back[local_idx * n + i]; + } +} + +template +void private_merge_sort_key_value_spread(KeyT *keys, ValT *vals, size_t n, + uint8_t *scratch, Compare comp) { + + KeyT *keys_back = nullptr; + ValT *vals_back = nullptr; + private_merge_sort_key_value_helper(keys, vals, n, scratch, comp, &keys_back, + &vals_back); + + + const size_t local_idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); for (size_t i = 0; i < n; ++i) { - keys[i] = keys_ptr[local_idx * n + i]; - vals[i] = vals_ptr[local_idx * n + i]; + keys[i] = keys_back[wg_size * i + local_idx]; + vals[i] = vals_back[wg_size * i + local_idx]; } } diff --git 
a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index db8ac61c4bc2f..9a99a53e569d8 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -1,13 +1,12 @@ #include "group_sort.hpp" - -__DPCPP_SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( - uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); - -__DPCPP_SYCL_EXTERNAL extern "C" void -__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( - uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); - +#include +#include +#include +#include +#include +#include +#include +#include __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p1i8( uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); @@ -543,3 +542,647 @@ __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p1i8 __DPCPP_SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p1i8( int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i8_u32_p1i8( + uint8_t 
*keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1f32_u32_p1i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t 
*scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i32_u32_p1i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i32_u32_p1i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i64_u32_p1i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i64_u32_p1i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1f32_u32_p1i8( + int8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u8_u32_p1i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u8_u32_p1i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i8_u32_p1i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i8_u32_p1i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u16_u32_p1i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u16_u32_p1i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i16_u32_p1i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i16_u32_p1i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p1i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p1i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i32_u32_p1i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i32_u32_p1i8( + uint16_t *keys, 
int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u64_u32_p1i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u64_u32_p1i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i64_u32_p1i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i64_u32_p1i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1f32_u32_p1i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1f32_u32_p1i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u8_u32_p1i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u8_u32_p1i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i8_u32_p1i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i8_u32_p1i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u16_u32_p1i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u16_u32_p1i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i16_u32_p1i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i16_u32_p1i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p1i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p1i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i32_u32_p1i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i32_u32_p1i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u64_u32_p1i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u64_u32_p1i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i64_u32_p1i8( + int16_t *keys, int64_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i64_u32_p1i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1f32_u32_p1i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1f32_u32_p1i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i8_u32_p1i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i8_u32_p1i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u16_u32_p1i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u16_u32_p1i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i16_u32_p1i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i16_u32_p1i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i32_u32_p1i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i32_u32_p1i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u64_u32_p1i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u64_u32_p1i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i64_u32_p1i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i64_u32_p1i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u8_u32_p1i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u8_u32_p1i8( + int32_t *keys, 
uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i8_u32_p1i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i8_u32_p1i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u16_u32_p1i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u16_u32_p1i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i16_u32_p1i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i16_u32_p1i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p1i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p1i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i32_u32_p1i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i32_u32_p1i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u64_u32_p1i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u64_u32_p1i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i64_u32_p1i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i64_u32_p1i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u8_u32_p1i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u8_u32_p1i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i8_u32_p1i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i8_u32_p1i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u16_u32_p1i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u16_u32_p1i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i16_u32_p1i8( + uint64_t *keys, int16_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i16_u32_p1i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p1i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p1i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i32_u32_p1i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i32_u32_p1i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u64_u32_p1i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u64_u32_p1i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i64_u32_p1i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i64_u32_p1i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u8_u32_p1i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u8_u32_p1i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i8_u32_p1i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i8_u32_p1i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u16_u32_p1i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u16_u32_p1i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i16_u32_p1i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i16_u32_p1i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p1i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p1i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i32_u32_p1i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i32_u32_p1i8( + int64_t *keys, int32_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u64_u32_p1i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u64_u32_p1i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i64_u32_p1i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i64_u32_p1i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +using namespace sycl; + +template +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { + static_assert((NUM % WG_SZ == 0), + "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); + size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + uint8_t *scratch_ptr = + (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); + const static size_t wg_size = WG_SZ; + constexpr size_t num_per_work_item = NUM / WG_SZ; + KeyT output_keys[NUM]; + ValT output_vals[NUM]; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + std::stable_sort(sorted_vec.begin(), 
sorted_vec.end(), kv_tuple_comp); + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + buffer okeys_buf(output_keys, NUM); + buffer ovals_buf(output_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + accessor okeys_acc{okeys_buf, h}; + accessor ovals_acc{ovals_buf, h}; + sycl::stream os(1024, 128, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + KeyT pkeys[num_per_work_item]; + ValT pvals[num_per_work_item]; + // copy from global input to fix-size private array. + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + pkeys[idx] = + ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; + pvals[idx] = + ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; + } + + gsh(pkeys, pvals, num_per_work_item, scratch_ptr); + + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pkeys[idx]; + ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pvals[idx]; + } + }); + }).wait(); + } + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + + sycl::free(scratch_ptr, q); + bool fails = false; +#ifdef SPREAD + for (size_t idx = 0; idx < NUM; ++idx) { + size_t idx1 = idx / WG_SZ; + size_t idx2 = idx % WG_SZ; + if ((output_keys[idx2 * num_per_work_item + idx1] != std::get<0>(sorted_vec[idx])) || + (output_vals[idx2 * num_per_work_item + idx1] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } +#else + for (size_t idx = 0; idx < NUM; ++idx) { + if ((output_keys[idx] != 
std::get<0>(sorted_vec[idx])) || + (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } +#endif + assert(!fails); +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp index 40ce1fececdb4..ba6f1af2b4261 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp @@ -4,112 +4,29 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out -// UNSUPPORTED: cuda || hip || cpu -#include "group_private_KV_sort_p1p1_p1.hpp" -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace sycl; - -template -void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], - const ValT vals[NUM], SortHelper gsh) { - static_assert((NUM % WG_SZ == 0), - "Input number must be divisible by work group size!"); - - KeyT input_keys[NUM]; - ValT input_vals[NUM]; - memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); - memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); - size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + - std::max(alignof(KeyT), alignof(ValT)); - uint8_t *scratch_ptr = - (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); - const static size_t wg_size = WG_SZ; - constexpr size_t num_per_work_item = NUM / WG_SZ; - KeyT output_keys[NUM]; - ValT output_vals[NUM]; - std::vector> sorted_vec; - for (size_t idx = 0; idx < NUM; ++idx) - sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); -#ifdef DES - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) > std::get<0>(t2); - }; -#else - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) < std::get<0>(t2); - }; -#endif - 
std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out - /* for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << std::get<0>(sorted_vec[idx]) << " val: " << - (int64_t)std::get<1>(sorted_vec[idx]) << std::endl; - }*/ +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out - nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); - { - buffer ikeys_buf(input_keys, NUM); - buffer ivals_buf(input_vals, NUM); - buffer okeys_buf(output_keys, NUM); - buffer ovals_buf(output_vals, NUM); - q.submit([&](auto &h) { - accessor ikeys_acc{ikeys_buf, h}; - accessor ivals_acc{ivals_buf, h}; - accessor okeys_acc{okeys_buf, h}; - accessor ovals_acc{ovals_buf, h}; - sycl::stream os(1024, 128, h); - h.parallel_for(num_items, [=](nd_item<1> i) { - KeyT pkeys[num_per_work_item]; - ValT pvals[num_per_work_item]; - // copy from global input to fix-size private array. - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - pkeys[idx] = - ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; - pvals[idx] = - ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; - } - - gsh(pkeys, pvals, num_per_work_item, scratch_ptr); - - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pkeys[idx]; - ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pvals[idx]; - } - }); - }).wait(); - } - /* for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)(input_keys[idx]) << " val: " << - (int)(input_vals[idx]) << std::endl; - }*/ - - sycl::free(scratch_ptr, q); - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { - if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || - (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { - std::cout << "idx: " << idx << std::endl; - fails = true; - break; - } - } - assert(!fails); -} +// RUN: %{build} -DSPREAD -o 
%t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" int main() { queue q; @@ -196,12 +113,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -209,12 +136,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -222,12 +159,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -235,12 +182,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -248,12 +205,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -261,12 +228,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); #else 
__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -274,12 +251,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -287,12 +274,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -300,12 +297,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; diff --git 
a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp index ce5cc2373861c..af38c56d3be3c 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp @@ -4,112 +4,29 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out -// UNSUPPORTED: cuda || hip || cpu -#include "group_private_KV_sort_p1p1_p1.hpp" -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace sycl; - -template -void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], - const ValT vals[NUM], SortHelper gsh) { - static_assert((NUM % WG_SZ == 0), - "Input number must be divisible by work group size!"); - - KeyT input_keys[NUM]; - ValT input_vals[NUM]; - memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); - memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); - size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + - std::max(alignof(KeyT), alignof(ValT)); - uint8_t *scratch_ptr = - (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); - const static size_t wg_size = WG_SZ; - constexpr size_t num_per_work_item = NUM / WG_SZ; - KeyT output_keys[NUM]; - ValT output_vals[NUM]; - std::vector> sorted_vec; - for (size_t idx = 0; idx < NUM; ++idx) - sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); -#ifdef DES - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) > std::get<0>(t2); - }; -#else - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) < std::get<0>(t2); - }; -#endif - std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out - /*for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " 
val: " << - (int)std::get<1>(sorted_vec[idx]) << std::endl; - }*/ +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out - nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); - { - buffer ikeys_buf(input_keys, NUM); - buffer ivals_buf(input_vals, NUM); - buffer okeys_buf(output_keys, NUM); - buffer ovals_buf(output_vals, NUM); - q.submit([&](auto &h) { - accessor ikeys_acc{ikeys_buf, h}; - accessor ivals_acc{ivals_buf, h}; - accessor okeys_acc{okeys_buf, h}; - accessor ovals_acc{ovals_buf, h}; - sycl::stream os(1024, 128, h); - h.parallel_for(num_items, [=](nd_item<1> i) { - KeyT pkeys[num_per_work_item]; - ValT pvals[num_per_work_item]; - // copy from global input to fix-size private array. - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - pkeys[idx] = - ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; - pvals[idx] = - ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; - } - - gsh(pkeys, pvals, num_per_work_item, scratch_ptr); - - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pkeys[idx]; - ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pvals[idx]; - } - }); - }).wait(); - } - /* for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)(input_keys[idx]) << " val: " << - (int)(input_vals[idx]) << std::endl; - }*/ - - sycl::free(scratch_ptr, q); - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { - if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || - (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { - std::cout << "idx: " << idx << std::endl; - fails = true; - break; - } - } - assert(!fails); -} +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD 
-fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" int main() { queue q; @@ -247,12 +164,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -260,12 +187,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -273,12 +210,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -286,12 +233,22 
@@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -299,12 +256,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -312,12 +279,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -325,12 +302,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i32_u32_p1i8( + keys, 
vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -338,12 +325,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -351,12 +348,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp index 5ef1e653baa16..7d1944150250d 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp @@ -4,112 +4,29 @@ // RUN: 
%{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out -// UNSUPPORTED: cuda || hip || cpu -#include "group_private_KV_sort_p1p1_p1.hpp" -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace sycl; - -template -void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], - const ValT vals[NUM], SortHelper gsh) { - static_assert((NUM % WG_SZ == 0), - "Input number must be divisible by work group size!"); - - KeyT input_keys[NUM]; - ValT input_vals[NUM]; - memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); - memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); - size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + - std::max(alignof(KeyT), alignof(ValT)); - uint8_t *scratch_ptr = - (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); - const static size_t wg_size = WG_SZ; - constexpr size_t num_per_work_item = NUM / WG_SZ; - KeyT output_keys[NUM]; - ValT output_vals[NUM]; - std::vector> sorted_vec; - for (size_t idx = 0; idx < NUM; ++idx) - sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); -#ifdef DES - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) > std::get<0>(t2); - }; -#else - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) < std::get<0>(t2); - }; -#endif - std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out - /*for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << - (int)std::get<1>(sorted_vec[idx]) << std::endl; - }*/ +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out - nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); - { - buffer ikeys_buf(input_keys, NUM); - buffer ivals_buf(input_vals, NUM); - buffer okeys_buf(output_keys, NUM); - buffer 
ovals_buf(output_vals, NUM); - q.submit([&](auto &h) { - accessor ikeys_acc{ikeys_buf, h}; - accessor ivals_acc{ivals_buf, h}; - accessor okeys_acc{okeys_buf, h}; - accessor ovals_acc{ovals_buf, h}; - sycl::stream os(1024, 128, h); - h.parallel_for(num_items, [=](nd_item<1> i) { - KeyT pkeys[num_per_work_item]; - ValT pvals[num_per_work_item]; - // copy from global input to fix-size private array. - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - pkeys[idx] = - ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; - pvals[idx] = - ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; - } - - gsh(pkeys, pvals, num_per_work_item, scratch_ptr); - - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pkeys[idx]; - ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pvals[idx]; - } - }); - }).wait(); - } - /* for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)(input_keys[idx]) << " val: " << - (int)(input_vals[idx]) << std::endl; - }*/ - - sycl::free(scratch_ptr, q); - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { - if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || - (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { - std::cout << "idx: " << idx << std::endl; - fails = true; - break; - } - } - assert(!fails); -} +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" int main() { queue q; @@ -247,12 +164,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p1i8( + keys, 
vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -260,12 +187,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -273,12 +210,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -286,12 +233,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -299,12 +256,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -312,12 +279,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -325,12 +302,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i32_u32_p1i8( + keys, vals, n, scratch); #else 
__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -338,12 +325,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -351,12 +348,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp index f4ab477ad8ce9..ddd91f3902d36 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp @@ -4,112 +4,29 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out -// UNSUPPORTED: cuda || hip || cpu -#include "group_private_KV_sort_p1p1_p1.hpp" -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace sycl; - -template -void 
test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], - const ValT vals[NUM], SortHelper gsh) { - static_assert((NUM % WG_SZ == 0), - "Input number must be divisible by work group size!"); - - KeyT input_keys[NUM]; - ValT input_vals[NUM]; - memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); - memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); - size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + - std::max(alignof(KeyT), alignof(ValT)); - uint8_t *scratch_ptr = - (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); - const static size_t wg_size = WG_SZ; - constexpr size_t num_per_work_item = NUM / WG_SZ; - KeyT output_keys[NUM]; - ValT output_vals[NUM]; - std::vector> sorted_vec; - for (size_t idx = 0; idx < NUM; ++idx) - sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); -#ifdef DES - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) > std::get<0>(t2); - }; -#else - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) < std::get<0>(t2); - }; -#endif - std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out - /*for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << - (int)std::get<1>(sorted_vec[idx]) << std::endl; - }*/ +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out - nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); - { - buffer ikeys_buf(input_keys, NUM); - buffer ivals_buf(input_vals, NUM); - buffer okeys_buf(output_keys, NUM); - buffer ovals_buf(output_vals, NUM); - q.submit([&](auto &h) { - accessor ikeys_acc{ikeys_buf, h}; - accessor ivals_acc{ivals_buf, h}; - accessor okeys_acc{okeys_buf, h}; - accessor ovals_acc{ovals_buf, h}; - sycl::stream os(1024, 128, h); - h.parallel_for(num_items, [=](nd_item<1> i) { - KeyT 
pkeys[num_per_work_item]; - ValT pvals[num_per_work_item]; - // copy from global input to fix-size private array. - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - pkeys[idx] = - ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; - pvals[idx] = - ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; - } - - gsh(pkeys, pvals, num_per_work_item, scratch_ptr); - - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pkeys[idx]; - ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pvals[idx]; - } - }); - }).wait(); - } - /* for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)(input_keys[idx]) << " val: " << - (int)(input_vals[idx]) << std::endl; - }*/ - - sycl::free(scratch_ptr, q); - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { - if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || - (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { - std::cout << "idx: " << idx << std::endl; - fails = true; - break; - } - } - assert(!fails); -} +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" int main() { queue q; @@ -197,12 +114,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); #else 
__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -210,12 +137,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -223,12 +160,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -236,12 +183,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -249,12 +206,22 @@ int main() { uint8_t *scratch) { #if 
__DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -262,12 +229,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -275,12 +252,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -288,12 +275,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else 
__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -301,12 +298,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp index cfd16f9c45a95..54355af4f825d 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp @@ -4,112 +4,29 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out -// UNSUPPORTED: cuda || hip || cpu -#include "group_private_KV_sort_p1p1_p1.hpp" -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace sycl; - -template -void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], - const ValT vals[NUM], SortHelper gsh) { - static_assert((NUM % WG_SZ == 0), - "Input number must be divisible by work group size!"); - - KeyT input_keys[NUM]; - ValT input_vals[NUM]; - memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); - memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); 
- size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + - std::max(alignof(KeyT), alignof(ValT)); - uint8_t *scratch_ptr = - (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); - const static size_t wg_size = WG_SZ; - constexpr size_t num_per_work_item = NUM / WG_SZ; - KeyT output_keys[NUM]; - ValT output_vals[NUM]; - std::vector> sorted_vec; - for (size_t idx = 0; idx < NUM; ++idx) - sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); -#ifdef DES - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) > std::get<0>(t2); - }; -#else - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) < std::get<0>(t2); - }; -#endif - std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out - /*for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << - (int)std::get<1>(sorted_vec[idx]) << std::endl; - }*/ +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out - nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); - { - buffer ikeys_buf(input_keys, NUM); - buffer ivals_buf(input_vals, NUM); - buffer okeys_buf(output_keys, NUM); - buffer ovals_buf(output_vals, NUM); - q.submit([&](auto &h) { - accessor ikeys_acc{ikeys_buf, h}; - accessor ivals_acc{ivals_buf, h}; - accessor okeys_acc{okeys_buf, h}; - accessor ovals_acc{ovals_buf, h}; - sycl::stream os(1024, 128, h); - h.parallel_for(num_items, [=](nd_item<1> i) { - KeyT pkeys[num_per_work_item]; - ValT pvals[num_per_work_item]; - // copy from global input to fix-size private array. 
- for (size_t idx = 0; idx < num_per_work_item; ++idx) { - pkeys[idx] = - ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; - pvals[idx] = - ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; - } - - gsh(pkeys, pvals, num_per_work_item, scratch_ptr); - - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pkeys[idx]; - ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pvals[idx]; - } - }); - }).wait(); - } - /* for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)(input_keys[idx]) << " val: " << - (int)(input_vals[idx]) << std::endl; - }*/ - - sycl::free(scratch_ptr, q); - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { - if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || - (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { - std::cout << "idx: " << idx << std::endl; - fails = true; - break; - } - } - assert(!fails); -} +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" int main() { queue q; @@ -197,12 +114,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p1i8( keys, vals, n, scratch); #endif 
+#endif #endif }; @@ -210,12 +137,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -223,12 +160,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -236,12 +183,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -249,12 +206,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -262,12 +229,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -275,12 +252,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -288,12 +275,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else 
__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -301,12 +298,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp index 40681e088288e..be598938aa196 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp @@ -4,112 +4,29 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out -// UNSUPPORTED: cuda || hip || cpu -#include "group_private_KV_sort_p1p1_p1.hpp" -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace sycl; - -template -void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], - const ValT vals[NUM], SortHelper gsh) { - static_assert((NUM % WG_SZ == 0), - "Input number must be divisible by work group size!"); - - KeyT input_keys[NUM]; - ValT input_vals[NUM]; - memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); - memcpy(&input_vals[0], &vals[0], NUM * 
sizeof(ValT)); - size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + - std::max(alignof(KeyT), alignof(ValT)); - uint8_t *scratch_ptr = - (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); - const static size_t wg_size = WG_SZ; - constexpr size_t num_per_work_item = NUM / WG_SZ; - KeyT output_keys[NUM]; - ValT output_vals[NUM]; - std::vector> sorted_vec; - for (size_t idx = 0; idx < NUM; ++idx) - sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); -#ifdef DES - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) > std::get<0>(t2); - }; -#else - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) < std::get<0>(t2); - }; -#endif - std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out - /*for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << - (int)std::get<1>(sorted_vec[idx]) << std::endl; - }*/ +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out - nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); - { - buffer ikeys_buf(input_keys, NUM); - buffer ivals_buf(input_vals, NUM); - buffer okeys_buf(output_keys, NUM); - buffer ovals_buf(output_vals, NUM); - q.submit([&](auto &h) { - accessor ikeys_acc{ikeys_buf, h}; - accessor ivals_acc{ivals_buf, h}; - accessor okeys_acc{okeys_buf, h}; - accessor ovals_acc{ovals_buf, h}; - sycl::stream os(1024, 128, h); - h.parallel_for(num_items, [=](nd_item<1> i) { - KeyT pkeys[num_per_work_item]; - ValT pvals[num_per_work_item]; - // copy from global input to fix-size private array. 
- for (size_t idx = 0; idx < num_per_work_item; ++idx) { - pkeys[idx] = - ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; - pvals[idx] = - ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; - } - - gsh(pkeys, pvals, num_per_work_item, scratch_ptr); - - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pkeys[idx]; - ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pvals[idx]; - } - }); - }).wait(); - } - /* for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)(input_keys[idx]) << " val: " << - (int)(input_vals[idx]) << std::endl; - }*/ - - sycl::free(scratch_ptr, q); - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { - if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || - (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { - std::cout << "idx: " << idx << std::endl; - fails = true; - break; - } - } - assert(!fails); -} +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" int main() { queue q; @@ -247,12 +164,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); #endif 
+#endif #endif }; @@ -260,12 +187,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -273,12 +210,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -286,12 +233,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -299,12 +256,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -312,12 +279,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -325,12 +302,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -338,12 +325,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else 
__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -351,12 +348,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp index 723a4b037f227..8a60e0831c0b9 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp @@ -4,112 +4,29 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out -// UNSUPPORTED: cuda || hip || cpu -#include "group_private_KV_sort_p1p1_p1.hpp" -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace sycl; - -template -void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], - const ValT vals[NUM], SortHelper gsh) { - static_assert((NUM % WG_SZ == 0), - "Input number must be divisible by work group size!"); - - KeyT input_keys[NUM]; - ValT input_vals[NUM]; - memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); - memcpy(&input_vals[0], &vals[0], NUM * 
sizeof(ValT)); - size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + - std::max(alignof(KeyT), alignof(ValT)); - uint8_t *scratch_ptr = - (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); - const static size_t wg_size = WG_SZ; - constexpr size_t num_per_work_item = NUM / WG_SZ; - KeyT output_keys[NUM]; - ValT output_vals[NUM]; - std::vector> sorted_vec; - for (size_t idx = 0; idx < NUM; ++idx) - sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); -#ifdef DES - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) > std::get<0>(t2); - }; -#else - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) < std::get<0>(t2); - }; -#endif - std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out - /*for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << - (int)std::get<1>(sorted_vec[idx]) << std::endl; - }*/ +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out - nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); - { - buffer ikeys_buf(input_keys, NUM); - buffer ivals_buf(input_vals, NUM); - buffer okeys_buf(output_keys, NUM); - buffer ovals_buf(output_vals, NUM); - q.submit([&](auto &h) { - accessor ikeys_acc{ikeys_buf, h}; - accessor ivals_acc{ivals_buf, h}; - accessor okeys_acc{okeys_buf, h}; - accessor ovals_acc{ovals_buf, h}; - sycl::stream os(1024, 128, h); - h.parallel_for(num_items, [=](nd_item<1> i) { - KeyT pkeys[num_per_work_item]; - ValT pvals[num_per_work_item]; - // copy from global input to fix-size private array. 
- for (size_t idx = 0; idx < num_per_work_item; ++idx) { - pkeys[idx] = - ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; - pvals[idx] = - ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; - } - - gsh(pkeys, pvals, num_per_work_item, scratch_ptr); - - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pkeys[idx]; - ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pvals[idx]; - } - }); - }).wait(); - } - /* for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)(input_keys[idx]) << " val: " << - (int)(input_vals[idx]) << std::endl; - }*/ - - sycl::free(scratch_ptr, q); - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { - if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || - (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { - std::cout << "idx: " << idx << std::endl; - fails = true; - break; - } - } - assert(!fails); -} +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" int main() { queue q; @@ -247,12 +164,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p1i8( keys, vals, n, scratch); #endif 
+#endif #endif }; @@ -260,12 +187,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -273,12 +210,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -286,12 +233,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -299,12 +256,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -312,12 +279,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -325,12 +302,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -338,12 +325,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else 
__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -351,12 +348,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp index dedae328797f4..90080dbbfc161 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp @@ -4,144 +4,33 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out -// UNSUPPORTED: cuda || hip || cpu -#include "group_private_KV_sort_p1p1_p1.hpp" -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace sycl; - -template -void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], - const ValT vals[NUM], SortHelper gsh) { - static_assert((NUM % WG_SZ == 0), - "Input number must be divisible by work group size!"); - - KeyT input_keys[NUM]; - ValT input_vals[NUM]; - memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); - memcpy(&input_vals[0], &vals[0], NUM * 
sizeof(ValT)); - size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + - std::max(alignof(KeyT), alignof(ValT)); - uint8_t *scratch_ptr = - (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); - const static size_t wg_size = WG_SZ; - constexpr size_t num_per_work_item = NUM / WG_SZ; - KeyT output_keys[NUM]; - ValT output_vals[NUM]; - std::vector> sorted_vec; - for (size_t idx = 0; idx < NUM; ++idx) - sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); -#ifdef DES - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) > std::get<0>(t2); - }; -#else - auto kv_tuple_comp = [](const std::tuple &t1, - const std::tuple &t2) { - return std::get<0>(t1) < std::get<0>(t2); - }; -#endif - std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out - /*for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << - (int)std::get<1>(sorted_vec[idx]) << std::endl; - }*/ +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out - nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); - { - buffer ikeys_buf(input_keys, NUM); - buffer ivals_buf(input_vals, NUM); - buffer okeys_buf(output_keys, NUM); - buffer ovals_buf(output_vals, NUM); - q.submit([&](auto &h) { - accessor ikeys_acc{ikeys_buf, h}; - accessor ivals_acc{ivals_buf, h}; - accessor okeys_acc{okeys_buf, h}; - accessor ovals_acc{ovals_buf, h}; - sycl::stream os(1024, 128, h); - h.parallel_for(num_items, [=](nd_item<1> i) { - KeyT pkeys[num_per_work_item]; - ValT pvals[num_per_work_item]; - // copy from global input to fix-size private array. 
- for (size_t idx = 0; idx < num_per_work_item; ++idx) { - pkeys[idx] = - ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; - pvals[idx] = - ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; - } - - gsh(pkeys, pvals, num_per_work_item, scratch_ptr); - - for (size_t idx = 0; idx < num_per_work_item; ++idx) { - okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pkeys[idx]; - ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = - pvals[idx]; - } - }); - }).wait(); - } - /* for (size_t idx = 0; idx < NUM; ++idx) { - std::cout << "key: " << (int)(input_keys[idx]) << " val: " << - (int)(input_vals[idx]) << std::endl; - }*/ - - sycl::free(scratch_ptr, q); - bool fails = false; - for (size_t idx = 0; idx < NUM; ++idx) { - if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || - (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { - std::cout << "idx: " << idx << std::endl; - fails = true; - break; - } - } - assert(!fails); -} +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" int main() { queue q; - { - constexpr static int NUM = 32; - uint32_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, - 121, 77, 125, 23, 36, 2, 111, 91, - 88, 2, 51, 95431, 881, 99183, 31, 142, - 416, 701, 699, 1024, 8912, 0, 7981, 17}; - uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, - 91111, 777, 165, 145, 2456, 88811, 761, 96, - 765, 10000, 6364, 90, 525, 882, 1, 2423, - 9, 4324, 9123, 0, 1232, 777, 555, 314159}; - auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, - uint8_t *scratch) { -#if __DEVICE_CODE -#ifdef DES - 
__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( - keys, vals, n, scratch); -#else - __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( - keys, vals, n, scratch); -#endif -#endif - }; - test_work_group_KV_private_sort( - q, ikeys, ivals, work_group_sorter); - std::cout << "KV private sort p1u32_p1u32_u32_p1i8 pass." << std::endl; - } - { constexpr static int NUM = 35; uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, @@ -225,12 +114,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -238,12 +137,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u8_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -251,12 +160,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i8_u32_p1i8( keys, 
vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i8_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -264,12 +183,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -277,12 +206,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i16_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i16_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -290,12 +229,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#if SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); #else 
__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -303,12 +252,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i32_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i32_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -316,12 +275,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; @@ -329,12 +298,22 @@ int main() { uint8_t *scratch) { #if __DEVICE_CODE #ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p1i8( keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i64_u32_p1i8( + keys, vals, n, scratch); #else __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i64_u32_p1i8( keys, vals, n, scratch); #endif +#endif #endif }; From c1e3aaa0cd320378e713b5068498c40aaeffd906 Mon Sep 17 
00:00:00 2001 From: jinge90 Date: Sat, 12 Oct 2024 15:07:03 +0800 Subject: [PATCH 61/71] add private key value sorting with local shared scratch memory Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 3015 +++++++++++++++++ .../group_private_KV_sort_p1p1_p3.hpp | 1193 +++++++ .../workgroup_private_KV_sort_p3_i16.cpp | 608 ++++ .../workgroup_private_KV_sort_p3_i32.cpp | 695 ++++ .../workgroup_private_KV_sort_p3_i64.cpp | 695 ++++ .../workgroup_private_KV_sort_p3_i8.cpp | 609 ++++ .../workgroup_private_KV_sort_p3_u16.cpp | 609 ++++ .../workgroup_private_KV_sort_p3_u32.cpp | 695 ++++ .../workgroup_private_KV_sort_p3_u64.cpp | 695 ++++ .../workgroup_private_KV_sort_p3_u8.cpp | 609 ++++ 10 files changed, 9423 insertions(+) create mode 100644 sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp create mode 100644 sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index e7f176153456c..3eb02dd02bf62 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -3970,4 +3970,3019 @@ void WG_PS_SD(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, n, scratch, std::greater_equal{}); } + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + 
merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i32_u32_p3i8)(uint8_t *keys, 
int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int8_t as key +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, 
std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t 
*scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// uint16_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + 
uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int16_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + 
std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + 
uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// uint32_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t 
n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int32_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, 
std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, 
uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// uint64_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, 
scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1f32_u32_p3i8)(uint64_t 
*keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int64_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, 
reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_JS_D(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// Work group private sorting algorithms. +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, 
std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, 
reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, 
reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + 
uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_CA(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, 
std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, 
reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + 
uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i64_u32_p3i8)(int32_t *keys, 
int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_CA(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, 
vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, 
uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i64_u32_p3i8)(uint8_t *keys, 
int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SD(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SA(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u16_u32_p3i8)(int16_t *keys, 
uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SD(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} +/* Ascending key-value private sort for uint32_t keys with int8_t values; forwards to private_merge_sort_key_value_spread with std::less_equal. NOTE(review): this entry point and the WG_PS_SD one that follows use the suffix p1u32_p3i8_u32_p3i8, while the int8_t-value variants for the other key types in this file use p1<key>_p1i8_u32_p3i8 (for example p1i16_p1i8_u32_p3i8 and p1i32_p1i8_u32_p3i8). Presumably a typo for p1u32_p1i8_u32_p3i8 - confirm. The test header declares the same p3i8 spelling, so any rename has to change the definitions and the declarations together. */ +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, +
std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, 
uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SA(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, 
reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, 
uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} #endif // __SPIR__ || __SPIRV__ diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp new file mode 100644 index 0000000000000..14f2dd3eea821 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp @@ -0,0 +1,1193 @@ +#include "group_sort.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +__DPCPP_SYCL_EXTERNAL extern 
"C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p3i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u8_u32_p3i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i8_u32_p3i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i8_u32_p3i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u16_u32_p3i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u16_u32_p3i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i16_u32_p3i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i16_u32_p3i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p3i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p3i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i32_u32_p3i8( + uint8_t *keys, int32_t *vals, uint32_t n, 
uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i32_u32_p3i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u64_u32_p3i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p3i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i64_u32_p3i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p3i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1f32_u32_p3i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u8_u32_p3i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u8_u32_p3i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i8_u32_p3i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i8_u32_p3i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u16_u32_p3i8( + 
int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u16_u32_p3i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i16_u32_p3i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i16_u32_p3i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p3i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p3i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i32_u32_p3i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i32_u32_p3i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u64_u32_p3i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u64_u32_p3i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i64_u32_p3i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p3i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1f32_u32_p3i8( + int8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u8_u32_p3i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u8_u32_p3i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i8_u32_p3i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i8_u32_p3i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u16_u32_p3i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u16_u32_p3i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i16_u32_p3i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i16_u32_p3i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p3i8( + uint16_t *keys, uint32_t *vals, uint32_t 
n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p3i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i32_u32_p3i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i32_u32_p3i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u64_u32_p3i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u64_u32_p3i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i64_u32_p3i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i64_u32_p3i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1f32_u32_p3i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1f32_u32_p3i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u8_u32_p3i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u8_u32_p3i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i8_u32_p3i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i8_u32_p3i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u16_u32_p3i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u16_u32_p3i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i16_u32_p3i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i16_u32_p3i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p3i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p3i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i32_u32_p3i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i32_u32_p3i8( + int16_t *keys, int32_t *vals, uint32_t 
n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u64_u32_p3i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u64_u32_p3i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i64_u32_p3i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i64_u32_p3i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1f32_u32_p3i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1f32_u32_p3i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u8_u32_p3i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u8_u32_p3i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); +/* NOTE(review): the next two declarations (ascending and descending, uint32_t keys with int8_t values) end in p1u32_p3i8_u32_p3i8, whereas the other int8_t-value declarations in this header end in p1<key>_p1i8_u32_p3i8 (for example p1i16_p1i8_u32_p3i8). Presumably a typo for p1u32_p1i8_u32_p3i8 - confirm. These names presumably have to match the WG_PS_SA / WG_PS_SD definitions in the device library exactly, so rename both sides together or not at all. */ +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p3i8_u32_p3i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p3i8_u32_p3i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u16_u32_p3i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u16_u32_p3i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i16_u32_p3i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i16_u32_p3i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p3i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p3i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i32_u32_p3i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i32_u32_p3i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u64_u32_p3i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u64_u32_p3i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p3i8( + uint32_t *keys, int64_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p3i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u8_u32_p3i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u8_u32_p3i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i8_u32_p3i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i8_u32_p3i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u16_u32_p3i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u16_u32_p3i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i16_u32_p3i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i16_u32_p3i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p3i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p3i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i32_u32_p3i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i32_u32_p3i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u64_u32_p3i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u64_u32_p3i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p3i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p3i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u8_u32_p3i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u8_u32_p3i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i8_u32_p3i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i8_u32_p3i8( + uint64_t *keys, int8_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u16_u32_p3i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u16_u32_p3i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i16_u32_p3i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i16_u32_p3i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p3i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p3i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i32_u32_p3i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i32_u32_p3i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u64_u32_p3i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u64_u32_p3i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i64_u32_p3i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i64_u32_p3i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u8_u32_p3i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u8_u32_p3i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i8_u32_p3i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i8_u32_p3i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u16_u32_p3i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u16_u32_p3i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i16_u32_p3i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i16_u32_p3i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p3i8( + int64_t *keys, uint32_t *vals, uint32_t 
n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p3i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i32_u32_p3i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i32_u32_p3i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u64_u32_p3i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u64_u32_p3i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p3i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p3i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u8_u32_p3i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u8_u32_p3i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i8_u32_p3i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i8_u32_p3i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u16_u32_p3i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u16_u32_p3i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i16_u32_p3i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i16_u32_p3i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p3i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p3i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i32_u32_p3i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i32_u32_p3i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u64_u32_p3i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u64_u32_p3i8( + uint8_t *keys, uint64_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i64_u32_p3i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i64_u32_p3i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1f32_u32_p3i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); /* NOTE(review): only the ascending uint8_t-key/float-value entry point is declared; no spread_descending_p1u8_p1f32 counterpart follows, unlike the integer value types for this key type. Confirm whether the omission is intentional. */ + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u8_u32_p3i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u8_u32_p3i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i8_u32_p3i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i8_u32_p3i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u16_u32_p3i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u16_u32_p3i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i16_u32_p3i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i16_u32_p3i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p3i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p3i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i32_u32_p3i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i32_u32_p3i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u64_u32_p3i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u64_u32_p3i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i64_u32_p3i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i64_u32_p3i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1f32_u32_p3i8( + int8_t *keys, float *vals, uint32_t n, uint8_t *scratch); /* NOTE(review): only the ascending int8_t-key/float-value entry point is declared; no spread_descending_p1i8_p1f32 counterpart follows, unlike the integer value types for this key type. Confirm whether the omission is intentional. */ + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u8_u32_p3i8( + uint16_t *keys, uint8_t *vals, uint32_t n, 
uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u8_u32_p3i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i8_u32_p3i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i8_u32_p3i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u16_u32_p3i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u16_u32_p3i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i16_u32_p3i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i16_u32_p3i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p3i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p3i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i32_u32_p3i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i32_u32_p3i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u64_u32_p3i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u64_u32_p3i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i64_u32_p3i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i64_u32_p3i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1f32_u32_p3i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1f32_u32_p3i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u8_u32_p3i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u8_u32_p3i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i8_u32_p3i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i8_u32_p3i8( + int16_t *keys, int8_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u16_u32_p3i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u16_u32_p3i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i16_u32_p3i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i16_u32_p3i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p3i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p3i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i32_u32_p3i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i32_u32_p3i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u64_u32_p3i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u64_u32_p3i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i64_u32_p3i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i64_u32_p3i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1f32_u32_p3i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1f32_u32_p3i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u8_u32_p3i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u8_u32_p3i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p3i8_u32_p3i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p3i8_u32_p3i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); /* NOTE(review): the two uint32_t-key/int8_t-value declarations above spell the vals component of the name as p3i8, while the sibling declarations use a p1 prefix for vals (e.g. p1i32_p1i8). Presumably this should read p1u32_p1i8; confirm against the libdevice definitions and any callers before renaming. */ + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u16_u32_p3i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u16_u32_p3i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i16_u32_p3i8( + uint32_t *keys, int16_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i16_u32_p3i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p3i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p3i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i32_u32_p3i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i32_u32_p3i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u64_u32_p3i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u64_u32_p3i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i64_u32_p3i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i64_u32_p3i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u8_u32_p3i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u8_u32_p3i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i8_u32_p3i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i8_u32_p3i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u16_u32_p3i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u16_u32_p3i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i16_u32_p3i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i16_u32_p3i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p3i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p3i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i32_u32_p3i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i32_u32_p3i8( + int32_t *keys, int32_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u64_u32_p3i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u64_u32_p3i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i64_u32_p3i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i64_u32_p3i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u8_u32_p3i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u8_u32_p3i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i8_u32_p3i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i8_u32_p3i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u16_u32_p3i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u16_u32_p3i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i16_u32_p3i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i16_u32_p3i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p3i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p3i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i32_u32_p3i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i32_u32_p3i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u64_u32_p3i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u64_u32_p3i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i64_u32_p3i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i64_u32_p3i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u8_u32_p3i8( + int64_t *keys, 
uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u8_u32_p3i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i8_u32_p3i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i8_u32_p3i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u16_u32_p3i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u16_u32_p3i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i16_u32_p3i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i16_u32_p3i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p3i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p3i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i32_u32_p3i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i32_u32_p3i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u64_u32_p3i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u64_u32_p3i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i64_u32_p3i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i64_u32_p3i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +using namespace sycl; + +template +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { + static_assert((NUM % WG_SZ == 0), + "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); + // Make sure scratch memory is always 8-byte aligned: its byte size is rounded up to a multiple of 8 below.
+ size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + scratch_size = (((scratch_size - 1) >> 3) + 1) << 3; + + const static size_t wg_size = WG_SZ; + constexpr size_t num_per_work_item = NUM / WG_SZ; + KeyT output_keys[NUM]; + ValT output_vals[NUM]; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + buffer okeys_buf(output_keys, NUM); + buffer ovals_buf(output_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + accessor okeys_acc{okeys_buf, h}; + accessor ovals_acc{ovals_buf, h}; + local_accessor scratch_acc(scratch_size >> 3, h); + sycl::stream os(1024, 128, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + KeyT pkeys[num_per_work_item]; + ValT pvals[num_per_work_item]; + // Copy this work-item's num_per_work_item consecutive input elements from the global input into fixed-size private arrays.
+ for (size_t idx = 0; idx < num_per_work_item; ++idx) { + pkeys[idx] = + ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; + pvals[idx] = + ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; + } + + uint8_t *scratch_ptr = reinterpret_cast( + scratch_acc.template get_multi_ptr().get()); + gsh(pkeys, pvals, num_per_work_item, scratch_ptr); + + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pkeys[idx]; + ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pvals[idx]; + } + }); + }).wait(); + } + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + + bool fails = false; +#ifdef SPREAD + for (size_t idx = 0; idx < NUM; ++idx) { + size_t idx1 = idx / WG_SZ; + size_t idx2 = idx % WG_SZ; + if ((output_keys[idx2 * num_per_work_item + idx1] != + std::get<0>(sorted_vec[idx])) || + (output_vals[idx2 * num_per_work_item + idx1] != + std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } +#else + for (size_t idx = 0; idx < NUM; ++idx) { + if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || + (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } +#endif + assert(!fails); +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp new file mode 100644 index 0000000000000..9516999e0d9b1 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp @@ -0,0 +1,608 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link 
-o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + { + constexpr static int NUM = 35; + int16_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 
3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + 
-2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i8_u32_p3i8( + keys, vals, 
n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p3i8( + keys, vals, n, 
scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i64_u32_p3i8( + keys, vals, n, 
scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp new file mode 100644 index 0000000000000..36eb3a46f00cc --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp @@ -0,0 +1,695 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + int32_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, -2, -62451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, -7712423, 0, 0, -181, 17, + 15, -101, 44, 103934, 1, 11, 
193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 
777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if 
__DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE 
+#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE 
+#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + 
std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp new file mode 100644 index 0000000000000..a0f7785170a30 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp @@ -0,0 +1,695 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { /* Drives the key/value work-group private sort checks for int64_t keys: one shared key array is paired in turn with value arrays of each integer type (uint8_t through int64_t) and handed to test_work_group_KV_private_sort together with the matching sorter lambda. */ + queue q; + + { + constexpr static int NUM = 40; /* length of ikeys and of every ivalsN array declared below */ + int64_t ikeys[NUM] = { + 1, 11, -1, 9, 3, 100, -34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + -881112, 2, 6662451, 213, 199181, -3183, 310910, 11242, + 3216, 1, -199, 7712423, 0, 0, 181, 17, + 15, -101, 44, 103934, 1, -11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721,
50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, +
16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { /* Forwards to one __devicelib_default_work_group_private_sort_* entry point for int64_t keys / uint32_t values: defining DES picks the descending variant (ascending otherwise) and defining SPREAD picks the spread variant (close otherwise). The body is empty when __DEVICE_CODE evaluates to 0. NOTE(review): same signature and callees as work_group_sorter5, and no test call in this main() receives it. */ +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { /* work_group_sorter1 through work_group_sorter8 repeat the DES/SPREAD selection above; each differs only in the value pointer type and the matching p1<type> part of the callee name. */ +#if
__DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE
+#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE
+#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; /* NOTE(review): NUM1..NUM7 are visibly used only in the progress messages below; presumably they are also the element counts given to test_work_group_KV_private_sort, but no template arguments are visible in these calls - TODO confirm */ + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass."
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; /* the int8_t-value runs that follow report NUM in their messages rather than a dedicated NUMx constant */ + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass."
<< std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass."
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass."
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass."
<< std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp new file mode 100644 index 0000000000000..e048e82acb3f0 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp @@ -0,0 +1,609 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 35; + int8_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123,
453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = 
{7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." 
<< std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." 
<< std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp new file mode 100644 index 0000000000000..25cd3791ec13e --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp @@ -0,0 +1,609 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 35; + uint16_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 12, 0, 181, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + 
uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 
7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp new file mode 100644 index 0000000000000..bd3b67e3476ae --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp @@ -0,0 +1,695 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + uint32_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, 2, 6662451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, 7712423, 0, 0, 181, 17, + 15, 101, 44, 103934, 1, 11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 
55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 
15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + /* Sorter for uint32_t keys with int8_t values. The symbol suffix follows the same p1u32_p1<value-type>_u32_p3i8 pattern as the sibling sorters in this test; DES selects descending order and SPREAD selects the "spread" variant instead of "close". */ + auto work_group_sorter2 = [](uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef
SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef 
SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." 
<< std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp new file mode 100644 index 0000000000000..746b963414873 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp @@ -0,0 +1,695 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + uint64_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, 2, 6662451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, 7712423, 0, 0, 181, 17, + 15, 101, 44, 103934, 1, 11, 
193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 
777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if 
__DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if 
__DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if 
__DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, 
work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp new file mode 100644 index 0000000000000..a34bbe4004f62 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp @@ -0,0 +1,609 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 35; + uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 12, 0, 181, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 
7586, + 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 
10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u8_u32_p3i8( 
+ keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i16_u32_p3i8( + keys, vals, n, 
scratch);
+#endif
+#else
+#ifdef SPREAD
+      __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i16_u32_p3i8(
+          keys, vals, n, scratch);
+#else
+      __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i16_u32_p3i8(
+          keys, vals, n, scratch);
+#endif
+#endif
+#endif
+    };
+
+    auto work_group_sorter5 = [](uint8_t *keys, uint32_t *vals, uint32_t n,
+                                 uint8_t *scratch) {
+#if __DEVICE_CODE // the lambda body is empty unless __DEVICE_CODE is nonzero
+#ifdef DES // DES selects the *_descending_* entry points
+#ifdef SPREAD // presence test, same form as the DES check above
+      __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p3i8(
+          keys, vals, n, scratch);
+#else
+      __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p3i8(
+          keys, vals, n, scratch);
+#endif
+#else
+#ifdef SPREAD
+      __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p3i8(
+          keys, vals, n, scratch);
+#else
+      __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p3i8(
+          keys, vals, n, scratch);
+#endif
+#endif
+#endif
+    };
+
+    auto work_group_sorter6 = [](uint8_t *keys, int32_t *vals, uint32_t n,
+                                 uint8_t *scratch) {
+#if __DEVICE_CODE
+#ifdef DES
+#ifdef SPREAD
+      __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i32_u32_p3i8(
+          keys, vals, n, scratch);
+#else
+      __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i32_u32_p3i8(
+          keys, vals, n, scratch);
+#endif
+#else
+#ifdef SPREAD
+      __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i32_u32_p3i8(
+          keys, vals, n, scratch);
+#else
+      __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i32_u32_p3i8(
+          keys, vals, n, scratch);
+#endif
+#endif
+#endif
+    };
+
+    auto work_group_sorter7 = [](uint8_t *keys, uint64_t *vals, uint32_t n,
+                                 uint8_t *scratch) {
+#if __DEVICE_CODE
+#ifdef DES
+#ifdef SPREAD
+      __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u64_u32_p3i8(
+          keys, vals, n, scratch);
+#else
+      __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p3i8(
+          keys, vals, n, scratch);
+#endif
+#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} From b079feb6a98b0a9903b3ff5c7a22663bb14f0d1b Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 14 Oct 2024 11:28:21 +0800 Subject: [PATCH 62/71] add subgroup key-value sort for 1-element for u8/u16 key Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 247 +++++++++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 3eb02dd02bf62..fd0a5eb9018d4 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -6985,4 +6985,251 @@ void WG_PS_SD(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, n, scratch, std::greater_equal{}); } + +// 1-element version of subgroup private sort. 
+DEVICE_EXTERN_C_INLINE // NOTE(review): name suffix carries u32 but no count is taken; n is fixed to 1 -- confirm ABI
+void SG_PS_A(p1u8_p1u8_u32_p1i8)(uint8_t *key, uint8_t *val, uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u8_p1u8_u32_p1i8)(uint8_t *key, uint8_t *val, uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE // signed values go through a same-width unsigned pointer; presumably values are only moved with their key -- confirm in sort_helper.hpp
+void SG_PS_A(p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint8_t *>(val), 1,
+                                     scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint8_t *>(val), 1,
+                                     scratch, std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u8_p1u16_u32_p1i8)(uint8_t *key, uint16_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u8_p1u16_u32_p1i8)(uint8_t *key, uint16_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u8_p1i16_u32_p1i8)(uint8_t *key, int16_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint16_t *>(val), 1,
+                                     scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u8_p1i16_u32_p1i8)(uint8_t *key, int16_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint16_t *>(val), 1,
+                                     scratch, std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u8_p1u32_u32_p1i8)(uint8_t *key, uint32_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u8_p1u32_u32_p1i8)(uint8_t *key, uint32_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u8_p1i32_u32_p1i8)(uint8_t *key, int32_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
+                                     scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u8_p1i32_u32_p1i8)(uint8_t *key, int32_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
+                                     scratch, std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u8_p1u64_u32_p1i8)(uint8_t *key, uint64_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u8_p1u64_u32_p1i8)(uint8_t *key, uint64_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u8_p1i64_u32_p1i8)(uint8_t *key, int64_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint64_t *>(val), 1,
+                                     scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u8_p1i64_u32_p1i8)(uint8_t *key, int64_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint64_t *>(val), 1,
+                                     scratch, std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE // float values are handed over through a uint32_t pointer (same width) -- NOTE(review): pointer type-punning, confirm it is intended
+void SG_PS_A(p1u8_p1f32_u32_p1i8)(uint8_t *key, float *val, uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
+                                     scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u8_p1f32_u32_p1i8)(uint8_t *key, float *val, uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
+                                     scratch, std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE // uint16_t-key variants from here on; the comparator is instantiated on the key type
+void SG_PS_A(p1u16_p1u8_u32_p1i8)(uint16_t *key, uint8_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u16_p1u8_u32_p1i8)(uint16_t *key, uint8_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint8_t *>(val), 1,
+                                     scratch, std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint8_t *>(val), 1,
+                                     scratch, std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u16_p1u16_u32_p1i8)(uint16_t *key, uint16_t *val,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u16_p1u16_u32_p1i8)(uint16_t *key, uint16_t *val,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u16_p1i16_u32_p1i8)(uint16_t *key, int16_t *val,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint16_t *>(val), 1,
+                                     scratch, std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u16_p1i16_u32_p1i8)(uint16_t *key, int16_t *val,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint16_t *>(val), 1,
+                                     scratch, std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u16_p1u32_u32_p1i8)(uint16_t *key, uint32_t *val,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u16_p1u32_u32_p1i8)(uint16_t *key, uint32_t *val,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, val, 1, scratch,
+                                     std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_A(p1u16_p1i32_u32_p1i8)(uint16_t *key, int32_t *val,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(key, reinterpret_cast<uint32_t *>(val), 1,
+                                     scratch, std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void SG_PS_D(p1u16_p1i32_u32_p1i8)(uint16_t *key, int32_t *val, + uint8_t *scratch) { + private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void SG_PS_A(p1u16_p1u64_u32_p1i8)(uint16_t *key, uint64_t *val, + uint8_t *scratch) { + private_merge_sort_key_value_close(key, val, 1, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void SG_PS_D(p1u16_p1u64_u32_p1i8)(uint16_t *key, uint64_t *val, + uint8_t *scratch) { + private_merge_sort_key_value_close(key, val, 1, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void SG_PS_A(p1u16_p1i64_u32_p1i8)(uint16_t *key, int64_t *val, + uint8_t *scratch) { + private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void SG_PS_D(p1u16_p1i64_u32_p1i8)(uint16_t *key, int64_t *val, + uint8_t *scratch) { + private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void SG_PS_A(p1u16_p1f32_u32_p1i8)(uint16_t *key, float *val, + uint8_t *scratch) { + private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void SG_PS_D(p1u16_p1f32_u32_p1i8)(uint16_t *key, float *val, + uint8_t *scratch) { + private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, + scratch, std::greater_equal{}); +} #endif // __SPIR__ || __SPIRV__ From 5436084a096ea4aed59c0a515d78e0175c7233a0 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 14 Oct 2024 15:13:18 +0800 Subject: [PATCH 63/71] Fix incorrect type in sg_private_sort_p1u16_p1i8 Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index fd0a5eb9018d4..14d20175c9bc3 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -7001,13 
+7001,13 @@ void SG_PS_D(p1u8_p1u8_u32_p1i8)(uint8_t *key, uint8_t *val, uint8_t *scratch) { DEVICE_EXTERN_C_INLINE void SG_PS_A(p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, + private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE void SG_PS_D(p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, + private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, scratch, std::greater_equal{}); } @@ -7124,14 +7124,14 @@ void SG_PS_D(p1u16_p1u8_u32_p1i8)(uint16_t *key, uint8_t *val, DEVICE_EXTERN_C_INLINE void SG_PS_A(p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, + private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, scratch, std::less_equal{}); } DEVICE_EXTERN_C_INLINE void SG_PS_D(p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, + private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, scratch, std::greater_equal{}); } From 961a6d80790e5a64bd850b6efad52b4db391068d Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 14 Oct 2024 20:24:43 +0800 Subject: [PATCH 64/71] fix e2e test header file issue Signed-off-by: jinge90 --- sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp | 2 +- sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp | 2 +- sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp | 2 +- sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp | 2 +- .../DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp | 2 +- .../DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp | 2 +- sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp | 2 +- sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp | 
2 +- .../DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp | 2 +- .../test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp | 2 +- .../test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp | 2 +- .../test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp | 2 +- .../test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp | 2 +- .../DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp | 2 +- 14 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp index d54bdae571834..2b84ce68da2b3 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp @@ -1,5 +1,5 @@ #pragma once -#include +#include #ifdef __SYCL_DEVICE_ONLY__ SYCL_EXTERNAL extern "C" void diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp index 1177c82254c01..a43f191c2b8b4 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp @@ -1,5 +1,5 @@ #pragma once -#include +#include #ifdef __SYCL_DEVICE_ONLY__ SYCL_EXTERNAL extern "C" void diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp index 361e5393ce3fe..180ddeecffcfb 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp @@ -1,5 +1,5 @@ #pragma once -#include +#include #ifdef __SYCL_DEVICE_ONLY__ SYCL_EXTERNAL extern "C" void diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp index 4b8ea3b737674..ff80e75a8fa49 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp +++ 
b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp @@ -1,5 +1,5 @@ #pragma once -#include +#include #ifdef __SYCL_DEVICE_ONLY__ SYCL_EXTERNAL extern "C" void diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index 9a99a53e569d8..aac6ef44f3f1f 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include __DPCPP_SYCL_EXTERNAL extern "C" void diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp index 14f2dd3eea821..3b47a4c59aefb 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include __DPCPP_SYCL_EXTERNAL extern "C" void diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp index 35b27a9e1d9eb..33618b3e1682a 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp @@ -1,5 +1,5 @@ #pragma once -#include +#include #ifdef __SYCL_DEVICE_ONLY__ SYCL_EXTERNAL extern "C" void __devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p1i8( diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp index e07ab7ac523b0..b1fbf85551c97 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp @@ -1,5 +1,5 @@ #pragma once -#include +#include #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) 
#define __DEVICE_CODE 1 diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp index 8063efa465436..c2774191593ea 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -12,7 +12,7 @@ #include #include #include -#include +#include #include #include diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp index ff2538d1bfb57..936a00603cf58 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp @@ -17,7 +17,7 @@ #include #include #include -#include +#include using namespace sycl; template diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp index 603ce777e470f..f3fa3bed5f572 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp @@ -17,7 +17,7 @@ #include #include #include -#include +#include using namespace sycl; // For __devicelib_default_work_group_xxx_p1*_u32_p3u8, the scratch memory is diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp index 5204aae9c09fb..d8bf629ed7fc1 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp @@ -17,7 +17,7 @@ #include #include #include -#include +#include using namespace sycl; template diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp index e83fa5d0d4361..f7462362eba13 100644 --- 
a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp @@ -17,7 +17,7 @@ #include #include #include -#include +#include using namespace sycl; template diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp index 5fd2444d0c9ae..e1c7eb6b8dc83 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp @@ -26,7 +26,7 @@ #include #include #include -#include +#include using namespace sycl; template void test_work_group_private_sort(sycl::queue &q, Ty input[NUM], From 3ff204a1196b26ebf5ad7e56b68af509db88db9f Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 18 Oct 2024 15:22:59 +0800 Subject: [PATCH 65/71] remove incorrect subgroup private func names Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 246 ------------------ .../group_private_KV_sort_p1p1_p1.hpp | 1 - .../group_private_KV_sort_p1p1_p3.hpp | 1 - 3 files changed, 248 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index 14d20175c9bc3..c5f9a78097ce3 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -6986,250 +6986,4 @@ void WG_PS_SD(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, std::greater_equal{}); } -// 1-element version of subgroup private sort. 
-DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u8_p1u8_u32_p1i8)(uint8_t *key, uint8_t *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u8_p1u8_u32_p1i8)(uint8_t *key, uint8_t *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u8_p1i8_u32_p1i8)(uint8_t *key, int8_t *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u8_p1u16_u32_p1i8)(uint8_t *key, uint16_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u8_p1u16_u32_p1i8)(uint8_t *key, uint16_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u8_p1i16_u32_p1i8)(uint8_t *key, int16_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u8_p1i16_u32_p1i8)(uint8_t *key, int16_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u8_p1u32_u32_p1i8)(uint8_t *key, uint32_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u8_p1u32_u32_p1i8)(uint8_t *key, uint32_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - 
std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u8_p1i32_u32_p1i8)(uint8_t *key, int32_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u8_p1i32_u32_p1i8)(uint8_t *key, int32_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u8_p1u64_u32_p1i8)(uint8_t *key, uint64_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u8_p1u64_u32_p1i8)(uint8_t *key, uint64_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u8_p1i64_u32_p1i8)(uint8_t *key, int64_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u8_p1i64_u32_p1i8)(uint8_t *key, int64_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u8_p1f32_u32_p1i8)(uint8_t *key, float *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u8_p1f32_u32_p1i8)(uint8_t *key, float *val, uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u16_p1u8_u32_p1i8)(uint16_t *key, uint8_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u16_p1u8_u32_p1i8)(uint16_t *key, uint8_t *val, - uint8_t 
*scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u16_p1i8_u32_p1i8)(uint16_t *key, int8_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u16_p1u16_u32_p1i8)(uint16_t *key, uint16_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u16_p1u16_u32_p1i8)(uint16_t *key, uint16_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u16_p1i16_u32_p1i8)(uint16_t *key, int16_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u16_p1i16_u32_p1i8)(uint16_t *key, int16_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u16_p1u32_u32_p1i8)(uint16_t *key, uint32_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u16_p1u32_u32_p1i8)(uint16_t *key, uint32_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u16_p1i32_u32_p1i8)(uint16_t *key, int32_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE 
-void SG_PS_D(p1u16_p1i32_u32_p1i8)(uint16_t *key, int32_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u16_p1u64_u32_p1i8)(uint16_t *key, uint64_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u16_p1u64_u32_p1i8)(uint16_t *key, uint64_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, val, 1, scratch, - std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u16_p1i64_u32_p1i8)(uint16_t *key, int64_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u16_p1i64_u32_p1i8)(uint16_t *key, int64_t *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_A(p1u16_p1f32_u32_p1i8)(uint16_t *key, float *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::less_equal{}); -} - -DEVICE_EXTERN_C_INLINE -void SG_PS_D(p1u16_p1f32_u32_p1i8)(uint16_t *key, float *val, - uint8_t *scratch) { - private_merge_sort_key_value_close(key, reinterpret_cast(val), 1, - scratch, std::greater_equal{}); -} #endif // __SPIR__ || __SPIRV__ diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index aac6ef44f3f1f..d7086ec343f97 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -1132,7 +1132,6 @@ void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], accessor ivals_acc{ivals_buf, h}; accessor okeys_acc{okeys_buf, h}; 
accessor ovals_acc{ovals_buf, h}; - sycl::stream os(1024, 128, h); h.parallel_for(num_items, [=](nd_item<1> i) { KeyT pkeys[num_per_work_item]; ValT pvals[num_per_work_item]; diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp index 3b47a4c59aefb..fbed182a3926e 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp @@ -1134,7 +1134,6 @@ void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], accessor okeys_acc{okeys_buf, h}; accessor ovals_acc{ovals_buf, h}; local_accessor scratch_acc(scratch_size >> 3, h); - sycl::stream os(1024, 128, h); h.parallel_for(num_items, [=](nd_item<1> i) { KeyT pkeys[num_per_work_item]; ValT pvals[num_per_work_item]; From a4e857292630a5c135d067b0fa3881c5b5b87837 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 21 Oct 2024 11:25:07 +0800 Subject: [PATCH 66/71] correct function name for subgroup private sorting Signed-off-by: jinge90 --- libdevice/fallback-gsort.cpp | 141 ++++++++++++++++++++++++++++++----- 1 file changed, 121 insertions(+), 20 deletions(-) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp index c5f9a78097ce3..f32a92fab0f66 100644 --- a/libdevice/fallback-gsort.cpp +++ b/libdevice/fallback-gsort.cpp @@ -850,103 +850,204 @@ void WG_PS_SD(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { //============= default sub group private sort for signed integer ============= DEVICE_EXTERN_C_INLINE -int8_t SG_PS_A(i8)(int8_t value, uint8_t *scratch) { +int8_t SG_PS_A(i8_p1i8)(int8_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +int8_t SG_PS_A(i8_p3i8)(int8_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -int16_t SG_PS_A(i16)(int16_t value, uint8_t *scratch) { +int16_t 
SG_PS_A(i16_p1i8)(int16_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -int32_t SG_PS_A(i32)(int32_t value, uint8_t *scratch) { +int16_t SG_PS_A(i16_p3i8)(int16_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +int32_t SG_PS_A(i32_p1i8)(int32_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +int32_t SG_PS_A(i32_p3i8)(int32_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -int64_t SG_PS_A(i64)(int64_t value, uint8_t *scratch) { +int64_t SG_PS_A(i64_p1i8)(int64_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint8_t SG_PS_A(u8)(uint8_t value, uint8_t *scratch) { +int64_t SG_PS_A(i64_p3i8)(int64_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +uint8_t SG_PS_A(u8_p1i8)(uint8_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +uint8_t SG_PS_A(u8_p3i8)(uint8_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint16_t SG_PS_A(u16)(uint16_t value, uint8_t *scratch) { +uint16_t SG_PS_A(u16_p1i8)(uint16_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint32_t SG_PS_A(u32)(uint32_t value, uint8_t *scratch) { +uint16_t SG_PS_A(u16_p3i8)(uint16_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +uint32_t SG_PS_A(u32_p1i8)(uint32_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +uint32_t SG_PS_A(u32_p3i8)(uint32_t value, uint8_t *scratch) { return 
sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -uint64_t SG_PS_A(u64)(uint64_t value, uint8_t *scratch) { +uint64_t SG_PS_A(u64_p1i8)(uint64_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +uint64_t SG_PS_A(u64_p3i8)(uint64_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -float SG_PS_A(f32)(float value, uint8_t *scratch) { +float SG_PS_A(f32_p1i8)(float value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::less{}); } DEVICE_EXTERN_C_INLINE -_Float16 SG_PS_A(f16)(_Float16 value, uint8_t *scratch) { +float SG_PS_A(f32_p3i8)(float value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +_Float16 SG_PS_A(f16_p1i8)(_Float16 value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, + [](_Float16 a, _Float16 b) { return (a < b); }); +} + +DEVICE_EXTERN_C_INLINE +_Float16 SG_PS_A(f16_p3i8)(_Float16 value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); } DEVICE_EXTERN_C_INLINE -int8_t SG_PS_D(i8)(int8_t value, uint8_t *scratch) { +int8_t SG_PS_D(i8_p1i8)(int8_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -int16_t SG_PS_D(i16)(int16_t value, uint8_t *scratch) { +int8_t SG_PS_D(i8_p3i8)(int8_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +int16_t SG_PS_D(i16_p1i8)(int16_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +int16_t SG_PS_D(i16_p3i8)(int16_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -int32_t SG_PS_D(i32)(int32_t value, uint8_t *scratch) { +int32_t 
SG_PS_D(i32_p1i8)(int32_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +int32_t SG_PS_D(i32_p3i8)(int32_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -int64_t SG_PS_D(i64)(int64_t value, uint8_t *scratch) { +int64_t SG_PS_D(i64_p1i8)(int64_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +int64_t SG_PS_D(i64_p3i8)(int64_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint8_t SG_PS_D(u8)(uint8_t value, uint8_t *scratch) { +uint8_t SG_PS_D(u8_p1i8)(uint8_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint16_t SG_PS_D(u16)(uint16_t value, uint8_t *scratch) { +uint8_t SG_PS_D(u8_p3i8)(uint8_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +uint16_t SG_PS_D(u16_p1i8)(uint16_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +uint16_t SG_PS_D(u16_p3i8)(uint16_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint32_t SG_PS_D(u32)(uint32_t value, uint8_t *scratch) { +uint32_t SG_PS_D(u32_p1i8)(uint32_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +uint32_t SG_PS_D(u32_p3i8)(uint32_t value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -uint64_t SG_PS_D(u64)(uint64_t value, uint8_t *scratch) { +uint64_t SG_PS_D(u64_p1i8)(uint64_t value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +uint64_t SG_PS_D(u64_p3i8)(uint64_t value, 
uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -float SG_PS_D(f32)(float value, uint8_t *scratch) { +float SG_PS_D(f32_p1i8)(float value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +float SG_PS_D(f32_p3i8)(float value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, std::greater{}); } DEVICE_EXTERN_C_INLINE -_Float16 SG_PS_D(f16)(_Float16 value, uint8_t *scratch) { +_Float16 SG_PS_D(f16_p1i8)(_Float16 value, uint8_t *scratch) { + return sub_group_merge_sort(value, scratch, + [](_Float16 a, _Float16 b) { return (a > b); }); +} + +DEVICE_EXTERN_C_INLINE +_Float16 SG_PS_D(f16_p3i8)(_Float16 value, uint8_t *scratch) { return sub_group_merge_sort(value, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); } From 877af070e3a86e624886d1fca5fc083c7705fad0 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 11 Nov 2024 14:53:43 +0800 Subject: [PATCH 67/71] fix clang format issue Signed-off-by: jinge90 --- libdevice/sort_helper.hpp | 4 +--- .../group_sort/group_private_KV_sort_p1p1_p1.hpp | 6 ++++-- .../group_sort/workgroup_private_KV_sort_i16.cpp | 10 ++++------ .../group_sort/workgroup_private_KV_sort_i32.cpp | 12 +++++------- .../group_sort/workgroup_private_KV_sort_i64.cpp | 12 +++++------- .../group_sort/workgroup_private_KV_sort_i8.cpp | 10 ++++------ .../group_sort/workgroup_private_KV_sort_p3_i32.cpp | 12 +++++------- .../group_sort/workgroup_private_KV_sort_p3_i64.cpp | 12 +++++------- .../group_sort/workgroup_private_KV_sort_p3_i8.cpp | 10 ++++------ .../group_sort/workgroup_private_KV_sort_p3_u64.cpp | 2 -- .../group_sort/workgroup_private_KV_sort_p3_u8.cpp | 2 -- .../group_sort/workgroup_private_KV_sort_u16.cpp | 8 +++----- .../group_sort/workgroup_private_KV_sort_u32.cpp | 2 -- .../group_sort/workgroup_private_KV_sort_u64.cpp | 2 -- .../group_sort/workgroup_private_KV_sort_u8.cpp | 2 -- 15 files changed, 40 
insertions(+), 66 deletions(-) diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp index f87adc09bf121..f54a4965d3086 100644 --- a/libdevice/sort_helper.hpp +++ b/libdevice/sort_helper.hpp @@ -339,7 +339,6 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n, private_merge_sort_key_value_helper(keys, vals, n, scratch, comp, &keys_back, &vals_back); - const size_t local_idx = __get_wg_local_linear_id(); for (size_t i = 0; i < n; ++i) { keys[i] = keys_back[local_idx * n + i]; @@ -349,14 +348,13 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n, template void private_merge_sort_key_value_spread(KeyT *keys, ValT *vals, size_t n, - uint8_t *scratch, Compare comp) { + uint8_t *scratch, Compare comp) { KeyT *keys_back = nullptr; ValT *vals_back = nullptr; private_merge_sort_key_value_helper(keys, vals, n, scratch, comp, &keys_back, &vals_back); - const size_t local_idx = __get_wg_local_linear_id(); const size_t wg_size = __get_wg_local_range(); diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp index d7086ec343f97..cf40a5d23578a 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -1166,8 +1166,10 @@ void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], for (size_t idx = 0; idx < NUM; ++idx) { size_t idx1 = idx / WG_SZ; size_t idx2 = idx % WG_SZ; - if ((output_keys[idx2 * num_per_work_item + idx1] != std::get<0>(sorted_vec[idx])) || - (output_vals[idx2 * num_per_work_item + idx1] != std::get<1>(sorted_vec[idx]))) { + if ((output_keys[idx2 * num_per_work_item + idx1] != + std::get<0>(sorted_vec[idx])) || + (output_vals[idx2 * num_per_work_item + idx1] != + std::get<1>(sorted_vec[idx]))) { std::cout << "idx: " << idx << std::endl; fails = true; break; diff --git 
a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp index ba6f1af2b4261..6c9a65b4e6bc7 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -32,10 +30,10 @@ int main() { queue q; { constexpr static int NUM = 35; - int16_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, - -77, 125, 23, 36, -2, -111, 91, 88, -2, - 51, -23, -81, 83, 31, 42, 2, 1, -99, - 124, 12, 0, -81, 17, 15, 101, 44}; + int16_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, 91111, 777, 165, 145, 2456, 88811, 761, 96, 765, 10000, 6364, 90, 525, diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp index af38c56d3be3c..7294e6b9345b0 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -34,11 +32,11 @@ int main() { { constexpr static int NUM = 40; int32_t ikeys[NUM] = { - 1, 11, 1, 9, 3, 100, 34, 8, - 121, 77, 125, 23, 222336, 2, 111, 91, - 881112, -2, 
-62451, 213, 199181, 3183, 310910, 11242, - 3216, 1, 199, -7712423, 0, 0, -181, 17, - 15, -101, 44, 103934, 1, 11, 193, 213}; + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, -2, -62451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, -7712423, 0, 0, -181, 17, + 15, -101, 44, 103934, 1, 11, 193, 213}; uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp index 7d1944150250d..de39f4527c0e5 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -34,11 +32,11 @@ int main() { { constexpr static int NUM = 40; int64_t ikeys[NUM] = { - 1, 11, -1, 9, 3, 100, -34, 8, - 121, 77, 125, 23, 222336, 2, 111, 91, - -881112, 2, 6662451, 213, 199181, -3183, 310910, 11242, - 3216, 1, -199, 7712423, 0, 0, 181, 17, - 15, -101, 44, 103934, 1, -11, 193, 213}; + 1, 11, -1, 9, 3, 100, -34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + -881112, 2, 6662451, 213, 199181, -3183, 310910, 11242, + 3216, 1, -199, 7712423, 0, 0, 181, 17, + 15, -101, 44, 103934, 1, -11, 193, 213}; uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp index ddd91f3902d36..96712fbb80317 100644 --- 
a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -33,10 +31,10 @@ int main() { { constexpr static int NUM = 35; - int8_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, - -77, 125, 23, 36, -2, -111, 91, 88, -2, - 51, -23, -81, 83, 31, 42, 2, 1, -99, - 124, 12, 0, -81, 17, 15, 101, 44}; + int8_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, 91111, 777, 165, 145, 2456, 88811, 761, 96, 765, 10000, 6364, 90, 525, diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp index 36eb3a46f00cc..dd460a3e5075e 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -34,11 +32,11 @@ int main() { { constexpr static int NUM = 40; int32_t ikeys[NUM] = { - 1, 11, 1, 9, 3, 100, 34, 8, - 121, 77, 125, 23, 222336, 2, 111, 91, - 881112, -2, -62451, 213, 199181, 3183, 310910, 11242, - 3216, 1, 199, -7712423, 0, 0, -181, 17, - 15, -101, 44, 103934, 1, 11, 193, 213}; + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 
91, + 881112, -2, -62451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, -7712423, 0, 0, -181, 17, + 15, -101, 44, 103934, 1, 11, 193, 213}; uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp index a0f7785170a30..a21e93caff47a 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -34,11 +32,11 @@ int main() { { constexpr static int NUM = 40; int64_t ikeys[NUM] = { - 1, 11, -1, 9, 3, 100, -34, 8, - 121, 77, 125, 23, 222336, 2, 111, 91, - -881112, 2, 6662451, 213, 199181, -3183, 310910, 11242, - 3216, 1, -199, 7712423, 0, 0, 181, 17, - 15, -101, 44, 103934, 1, -11, 193, 213}; + 1, 11, -1, 9, 3, 100, -34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + -881112, 2, 6662451, 213, 199181, -3183, 310910, 11242, + 3216, 1, -199, 7712423, 0, 0, 181, 17, + 15, -101, 44, 103934, 1, -11, 193, 213}; uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp index e048e82acb3f0..41f00cf73c99f 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp @@ -4,14 +4,12 @@ // RUN: %{build} 
-fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -33,10 +31,10 @@ int main() { { constexpr static int NUM = 35; - int8_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, - -77, 125, 23, 36, -2, -111, 91, 88, -2, - 51, -23, -81, 83, 31, 42, 2, 1, -99, - 124, 12, 0, -81, 17, 15, 101, 44}; + int8_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, 91111, 777, 165, 145, 2456, 88811, 761, 96, 765, 10000, 6364, 90, 525, diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp index 746b963414873..f99b2f44215fd 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp index a34bbe4004f62..b42ea4acc0736 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES 
-fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp index 54355af4f825d..6bd24951d2fc7 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -34,9 +32,9 @@ int main() { { constexpr static int NUM = 35; uint16_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, - 77, 125, 23, 36, 2, 111, 91, 88, 2, - 51, 213, 181, 183, 31, 142, 216, 1, 199, - 124, 12, 0, 181, 17, 15, 101, 44}; + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 12, 0, 181, 17, 15, 101, 44}; uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, 91111, 777, 165, 145, 2456, 88811, 761, 96, 765, 10000, 6364, 90, 525, diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp index be598938aa196..b515a5a32ca91 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp 
b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp index 8a60e0831c0b9..39ca18929bd01 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp index 90080dbbfc161..62e752934dfbd 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out From 7d87a893bd96eb6b41600f81d12adbffff3d3d96 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 11 Nov 2024 15:19:00 +0800 Subject: [PATCH 68/71] fix clang format Signed-off-by: jinge90 --- .../group_sort/workgroup_private_KV_sort_p3_i16.cpp | 10 ++++------ .../group_sort/workgroup_private_KV_sort_p3_u16.cpp | 8 +++----- .../group_sort/workgroup_private_KV_sort_p3_u32.cpp | 2 -- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp index 9516999e0d9b1..19f8335c2c731 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp @@ -4,14 +4,12 @@ // 
RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -32,10 +30,10 @@ int main() { queue q; { constexpr static int NUM = 35; - int16_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, - -77, 125, 23, 36, -2, -111, 91, 88, -2, - 51, -23, -81, 83, 31, 42, 2, 1, -99, - 124, 12, 0, -81, 17, 15, 101, 44}; + int16_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, 91111, 777, 165, 145, 2456, 88811, 761, 96, 765, 10000, 6364, 90, 525, diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp index 25cd3791ec13e..5927c08582621 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out @@ -34,9 +32,9 @@ int main() { { constexpr static int NUM = 35; uint16_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, - 77, 125, 23, 36, 2, 111, 91, 88, 2, - 51, 213, 181, 183, 31, 142, 216, 1, 199, - 124, 12, 0, 181, 17, 15, 101, 44}; + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 12, 0, 181, 17, 15, 101, 44}; uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, 91111, 777, 165, 145, 2456, 88811, 761, 96, 765, 10000, 6364, 90, 525, diff --git 
a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp index bd3b67e3476ae..4787de440b25e 100644 --- a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp @@ -4,14 +4,12 @@ // RUN: %{build} -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DDES -o %t.out // RUN: %{run} %t.out // RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out // RUN: %{run} %t.out - // RUN: %{build} -DSPREAD -o %t.out // RUN: %{run} %t.out From 78235d0b21a7a5a4dffec6a08faf77a6139d8af0 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 11 Nov 2024 15:41:22 +0800 Subject: [PATCH 69/71] Disable AMD target for fallback gsort Signed-off-by: jinge90 --- libdevice/cmake/modules/SYCLLibdevice.cmake | 26 ++++++++++++--------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/libdevice/cmake/modules/SYCLLibdevice.cmake b/libdevice/cmake/modules/SYCLLibdevice.cmake index 37671b620a368..59ef8fa419f02 100644 --- a/libdevice/cmake/modules/SYCLLibdevice.cmake +++ b/libdevice/cmake/modules/SYCLLibdevice.cmake @@ -185,7 +185,7 @@ function(add_devicelibs filename) cmake_parse_arguments(ARG "" "" - "SRC;EXTRA_OPTS;DEPENDENCIES" + "SRC;EXTRA_OPTS;DEPENDENCIES;SKIP_ARCHS" ${ARGN}) foreach(filetype IN LISTS filetypes) @@ -197,15 +197,18 @@ function(add_devicelibs filename) endforeach() foreach(arch IN LISTS devicelib_arch) - compile_lib(${filename}-${arch} - FILETYPE bc - SRC ${ARG_SRC} - DEPENDENCIES ${ARG_DEPENDENCIES} - EXTRA_OPTS ${ARG_EXTRA_OPTS} ${bc_device_compile_opts} - ${compile_opts_${arch}}) - - append_to_property(${bc_binary_dir}/${filename}-${arch}.bc - PROPERTY_NAME BC_DEVICE_LIBS_${arch}) + list(FIND ${ARG_SKIP_ARCHS} ${arch} skip_idx) + if (skip_idx EQUAL -1) + compile_lib(${filename}-${arch} + FILETYPE bc + SRC ${ARG_SRC} + DEPENDENCIES ${ARG_DEPENDENCIES} + 
EXTRA_OPTS ${ARG_EXTRA_OPTS} ${bc_device_compile_opts} + ${compile_opts_${arch}}) + + append_to_property(${bc_binary_dir}/${filename}-${arch}.bc + PROPERTY_NAME BC_DEVICE_LIBS_${arch}) + endif() endforeach() endfunction() @@ -321,7 +324,8 @@ add_devicelibs(libsycl-native-bfloat16 add_devicelibs(libsycl-fallback-gsort SRC fallback-gsort.cpp DEPENDENCIES ${gsort_obj_deps} - EXTRA_OPTS -fno-sycl-instrument-device-code) + EXTRA_OPTS -fno-sycl-instrument-device-code + SKIP_ARCHS amdgcn-amd-amdhsa) # Create dependency and source lists for Intel math function libraries. file(MAKE_DIRECTORY ${obj_binary_dir}/libdevice) From 3bdc448127fc14283da5059a14838cec307382b0 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 11 Nov 2024 15:54:12 +0800 Subject: [PATCH 70/71] fix cmake Signed-off-by: jinge90 --- libdevice/cmake/modules/SYCLLibdevice.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libdevice/cmake/modules/SYCLLibdevice.cmake b/libdevice/cmake/modules/SYCLLibdevice.cmake index 59ef8fa419f02..c9abe26bda4da 100644 --- a/libdevice/cmake/modules/SYCLLibdevice.cmake +++ b/libdevice/cmake/modules/SYCLLibdevice.cmake @@ -197,8 +197,8 @@ function(add_devicelibs filename) endforeach() foreach(arch IN LISTS devicelib_arch) - list(FIND ${ARG_SKIP_ARCHS} ${arch} skip_idx) - if (skip_idx EQUAL -1) + list(FIND ${ARG_SKIP_ARCHS} "${arch}" skip_idx) + if (${skip_idx} EQUAL -1) compile_lib(${filename}-${arch} FILETYPE bc SRC ${ARG_SRC} From 4e486b41e196e3f7421c376af638f68ffad214f4 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 11 Nov 2024 16:45:19 +0800 Subject: [PATCH 71/71] Fix cmake issue Signed-off-by: jinge90 --- libdevice/cmake/modules/SYCLLibdevice.cmake | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/libdevice/cmake/modules/SYCLLibdevice.cmake b/libdevice/cmake/modules/SYCLLibdevice.cmake index c9abe26bda4da..48b2bd551c5a9 100644 --- a/libdevice/cmake/modules/SYCLLibdevice.cmake +++ 
b/libdevice/cmake/modules/SYCLLibdevice.cmake @@ -197,8 +197,12 @@ function(add_devicelibs filename) endforeach() foreach(arch IN LISTS devicelib_arch) - list(FIND ${ARG_SKIP_ARCHS} "${arch}" skip_idx) - if (${skip_idx} EQUAL -1) + set(skip_idx -1) + set(skip_arch_list ${ARG_SKIP_ARCHS}) + if (skip_arch_list) + list(FIND skip_arch_list ${arch} skip_idx) + endif() + if (skip_idx EQUAL -1) compile_lib(${filename}-${arch} FILETYPE bc SRC ${ARG_SRC}