diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td index 4c946b9a98fcf..96580cd42a1b2 100644 --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -6900,14 +6900,14 @@ def fno_sycl_dead_args_optimization : Flag<["-"], "fno-sycl-dead-args-optimizati HelpText<"Disables elimination of DPC++ dead kernel arguments">; def fsycl_device_lib_EQ : CommaJoined<["-"], "fsycl-device-lib=">, Flags<[NoXarchOption]>, - Values<"libc,libm-fp32,libm-fp64,libimf-fp32,libimf-fp64,libimf-bf16,all">, + Values<"libc,libm-fp32,libm-fp64,libimf-fp32,libimf-fp64,libimf-bf16,libgsort-fp32,all">, HelpText<"Control inclusion of device libraries into device binary linkage. " "Valid arguments are libc, libm-fp32, libm-fp64, libimf-fp32, " - "libimf-fp64, libimf-bf16, all">; + "libimf-fp64, libimf-bf16, libgsort-fp32, all">; def fno_sycl_device_lib_EQ : CommaJoined<["-"], "fno-sycl-device-lib=">, - Flags<[NoXarchOption]>, Values<"libc, libm-fp32, libm-fp64, all">, + Flags<[NoXarchOption]>, Values<"libc, libm-fp32, libm-fp64, libgsort-fp32, all">, HelpText<"Control exclusion of device libraries from device binary linkage. " - "Valid arguments are libc, libm-fp32, libm-fp64, all">; + "Valid arguments are libc, libm-fp32, libm-fp64, libgsort-fp32, all">; def fsycl_device_lib_jit_link : Flag<["-"], "fsycl-device-lib-jit-link">, HelpText<"Enables sycl device library jit link (experimental)">; def fno_sycl_device_lib_jit_link : Flag<["-"], "fno-sycl-device-lib-jit-link">, diff --git a/clang/lib/Driver/ToolChains/SYCL.cpp b/clang/lib/Driver/ToolChains/SYCL.cpp index ed254e01bd41e..47b9da47b45fd 100644 --- a/clang/lib/Driver/ToolChains/SYCL.cpp +++ b/clang/lib/Driver/ToolChains/SYCL.cpp @@ -452,9 +452,9 @@ SYCL::getDeviceLibraries(const Compilation &C, const llvm::Triple &TargetTriple, // Currently, all SYCL device libraries will be linked by default. 
llvm::StringMap DeviceLibLinkInfo = { - {"libc", true}, {"libm-fp32", true}, {"libm-fp64", true}, - {"libimf-fp32", true}, {"libimf-fp64", true}, {"libimf-bf16", true}, - {"libm-bfloat16", true}, {"internal", true}}; + {"libc", true}, {"libm-fp32", true}, {"libm-fp64", true}, + {"libimf-fp32", true}, {"libimf-fp64", true}, {"libimf-bf16", true}, + {"libm-bfloat16", true}, {"libgsort-fp32", true}, {"internal", true}}; // If -fno-sycl-device-lib is specified, its values will be used to exclude // linkage of libraries specified by DeviceLibLinkInfo. Linkage of "internal" @@ -536,6 +536,7 @@ SYCL::getDeviceLibraries(const Compilation &C, const llvm::Triple &TargetTriple, {"libsycl-fallback-complex-fp64", "libm-fp64"}, {"libsycl-fallback-cmath", "libm-fp32"}, {"libsycl-fallback-cmath-fp64", "libm-fp64"}, + {"libsycl-fallback-gsort", "libgsort-fp32"}, {"libsycl-fallback-imf", "libimf-fp32"}, {"libsycl-fallback-imf-fp64", "libimf-fp64"}, {"libsycl-fallback-imf-bf16", "libimf-bf16"}}; @@ -851,6 +852,7 @@ static llvm::SmallVector SYCLDeviceLibList{ "fallback-cmath-fp64", "fallback-complex", "fallback-complex-fp64", + "fallback-gsort", "fallback-imf", "fallback-imf-fp64", "fallback-imf-bf16", diff --git a/clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.bc b/clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.bc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.o b/clang/test/Driver/Inputs/SYCL/lib/libsycl-fallback-gsort.o new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/clang/test/Driver/sycl-device-lib-win.cpp b/clang/test/Driver/sycl-device-lib-win.cpp index 3f7267e017efa..60ed1d9abab7f 100644 --- a/clang/test/Driver/sycl-device-lib-win.cpp +++ b/clang/test/Driver/sycl-device-lib-win.cpp @@ -33,6 +33,7 @@ // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-cmath.bc" // 
SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -70,6 +71,7 @@ // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -93,6 +95,7 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ -115,6 +118,7 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" @@ 
-172,6 +176,7 @@ // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-complex-fp64.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-cmath.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-cmath-fp64.bc" +// SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-gsort.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-imf.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-imf-fp64.bc" // SYCL_LLVM_LINK_DEVICE_LIB-SAME: "{{.*}}libsycl-fallback-imf-bf16.bc" diff --git a/clang/test/Driver/sycl-device-lib.cpp b/clang/test/Driver/sycl-device-lib.cpp index 9e07edf2287fa..f99b649307787 100644 --- a/clang/test/Driver/sycl-device-lib.cpp +++ b/clang/test/Driver/sycl-device-lib.cpp @@ -33,6 +33,7 @@ // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: {{.*}}libsycl-fallback-complex-fp64.new.o // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: {{.*}}libsycl-fallback-cmath.new.o // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: {{.*}}libsycl-fallback-cmath-fp64.new.o +// SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: {{.*}}libsycl-fallback-gsort.new.o // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: {{.*}}libsycl-fallback-imf.new.o // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: {{.*}}libsycl-fallback-imf-fp64.new.o // SYCL_DEVICE_LIB_LINK_DEFAULT-SAME: {{.*}}libsycl-fallback-imf-bf16.new.o @@ -70,11 +71,12 @@ // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: {{.*}}libsycl-fallback-complex-fp64.new.o // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: {{.*}}libsycl-fallback-cmath.new.o // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: {{.*}}libsycl-fallback-cmath-fp64.new.o +// SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: {{.*}}libsycl-fallback-gsort.new.o // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: {{.*}}libsycl-fallback-imf.new.o // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: {{.*}}libsycl-fallback-imf-fp64.new.o // SYCL_DEVICE_LIB_LINK_WITH_FP64-SAME: {{.*}}libsycl-fallback-imf-bf16.new.o -/// ########################################################################### +/// 
########################################################################### /// test behavior of -fno-sycl-device-lib=libc // RUN: %clangxx -fsycl --offload-new-driver %s -fno-sycl-device-lib=libc --sysroot=%S/Inputs/SYCL -### 2>&1 \ // RUN: | FileCheck %s -check-prefix=SYCL_DEVICE_LIB_LINK_NO_LIBC @@ -93,11 +95,12 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: {{.*}}libsycl-fallback-complex-fp64.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: {{.*}}libsycl-fallback-cmath.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: {{.*}}libsycl-fallback-cmath-fp64.new.o +// SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: {{.*}}libsycl-fallback-gsort.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: {{.*}}libsycl-fallback-imf.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: {{.*}}libsycl-fallback-imf-fp64.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBC-SAME: {{.*}}libsycl-fallback-imf-bf16.new.o -/// ########################################################################### +/// ########################################################################### /// test behavior of -fno-sycl-device-lib=libm-fp32,libm-fp64 // RUN: %clangxx -fsycl --offload-new-driver %s -fno-sycl-device-lib=libm-fp32,libm-fp64 --sysroot=%S/Inputs/SYCL -### 2>&1 \ // RUN: | FileCheck %s -check-prefix=SYCL_DEVICE_LIB_LINK_NO_LIBM @@ -116,11 +119,12 @@ // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: {{.*}}libsycl-fallback-complex-fp64.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: {{.*}}libsycl-fallback-cmath.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBM-NOT: {{.*}}libsycl-fallback-cmath-fp64.new.o +// SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: {{.*}}libsycl-fallback-gsort.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: {{.*}}libsycl-fallback-imf.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: {{.*}}libsycl-fallback-imf-fp64.new.o // SYCL_DEVICE_LIB_LINK_NO_LIBM-SAME: {{.*}}libsycl-fallback-imf-bf16.new.o -/// ########################################################################### +/// ########################################################################### /// test behavior of 
disabling all device libraries // RUN: %clangxx -fsycl --offload-new-driver %s -fno-sycl-device-lib=libc,libm-fp32 --sysroot=%S/Inputs/SYCL -### 2>&1 \ // RUN: | FileCheck %s -check-prefix=SYCL_DEVICE_LIB_LINK_NO_DEVICE_LIB @@ -152,6 +156,7 @@ // SYCL_DEVICE_LIB_INVALID_VALUE: error: unsupported argument '[[Val]]' to option '-fsycl-device-lib=' // SYCL_NO_DEVICE_LIB_INVALID_VALUE: error: unsupported argument '[[Val]]' to option '-fno-sycl-device-lib=' + /// ########################################################################### /// test behavior of libsycl-asan.o linking when -fsanitize=address is available // RUN: %clangxx -fsycl --offload-new-driver %s --sysroot=%S/Inputs/SYCL -fsanitize=address -### 2>&1 \ @@ -168,8 +173,8 @@ // RUN: | FileCheck %s -check-prefix=SYCL_DEVICE_ASAN_MACRO // SYCL_DEVICE_LIB_ASAN: clang-linker-wrapper{{.*}} "-sycl-device-libraries // SYCL_DEVICE_LIB_ASAN: {{.*}}libsycl-crt.new.o -// SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-complex. -// SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-complex-fp64. 
+// SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-complex.new.o +// SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-complex-fp64.new.o // SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-cmath.new.o // SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-cmath-fp64.new.o // SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-imf.new.o @@ -181,6 +186,7 @@ // SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-fallback-complex-fp64.new.o // SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-fallback-cmath.new.o // SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-fallback-cmath-fp64.new.o +// SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-fallback-gsort.new.o // SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-fallback-imf.new.o // SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-fallback-imf-fp64.new.o // SYCL_DEVICE_LIB_ASAN-SAME: {{.*}}libsycl-fallback-imf-bf16.new.o diff --git a/clang/test/Driver/sycl-offload-new-driver.c b/clang/test/Driver/sycl-offload-new-driver.c index dd656192b80f3..ca241b5bfca03 100644 --- a/clang/test/Driver/sycl-offload-new-driver.c +++ b/clang/test/Driver/sycl-offload-new-driver.c @@ -34,7 +34,7 @@ // RUN: %clangxx --target=x86_64-unknown-linux-gnu -fsycl --offload-new-driver \ // RUN: --sysroot=%S/Inputs/SYCL -### %s 2>&1 \ // RUN: | FileCheck -check-prefix WRAPPER_OPTIONS %s -// WRAPPER_OPTIONS: clang-linker-wrapper{{.*}} "-sycl-device-libraries=libsycl-crt.new.o,libsycl-complex.new.o,libsycl-complex-fp64.new.o,libsycl-cmath.new.o,libsycl-cmath-fp64.new.o,libsycl-imf.new.o,libsycl-imf-fp64.new.o,libsycl-imf-bf16.new.o,libsycl-fallback-cassert.new.o,libsycl-fallback-cstring.new.o,libsycl-fallback-complex.new.o,libsycl-fallback-complex-fp64.new.o,libsycl-fallback-cmath.new.o,libsycl-fallback-cmath-fp64.new.o,libsycl-fallback-imf.new.o,libsycl-fallback-imf-fp64.new.o,libsycl-fallback-imf-bf16.new.o,libsycl-itt-user-wrappers.new.o,libsycl-itt-compiler-wrappers.new.o,libsycl-itt-stubs.new.o" +// WRAPPER_OPTIONS: clang-linker-wrapper{{.*}} 
"-sycl-device-libraries=libsycl-crt.new.o,libsycl-complex.new.o,libsycl-complex-fp64.new.o,libsycl-cmath.new.o,libsycl-cmath-fp64.new.o,libsycl-imf.new.o,libsycl-imf-fp64.new.o,libsycl-imf-bf16.new.o,libsycl-fallback-cassert.new.o,libsycl-fallback-cstring.new.o,libsycl-fallback-complex.new.o,libsycl-fallback-complex-fp64.new.o,libsycl-fallback-cmath.new.o,libsycl-fallback-cmath-fp64.new.o,libsycl-fallback-gsort.new.o,libsycl-fallback-imf.new.o,libsycl-fallback-imf-fp64.new.o,libsycl-fallback-imf-bf16.new.o,libsycl-itt-user-wrappers.new.o,libsycl-itt-compiler-wrappers.new.o,libsycl-itt-stubs.new.o" // WRAPPER_OPTIONS-SAME: "-sycl-device-library-location={{.*}}/lib" /// Verify phases used to generate SPIR-V instead of LLVM-IR diff --git a/libdevice/atomic.hpp b/libdevice/atomic.hpp index ca35fa8767cd0..ab95e139ca47c 100644 --- a/libdevice/atomic.hpp +++ b/libdevice/atomic.hpp @@ -7,78 +7,9 @@ //===----------------------------------------------------------------------===// #pragma once -#include - -#include "device.h" +#include "spirv_decls.hpp" #if defined(__SPIR__) || defined(__SPIRV__) - -#define SPIR_GLOBAL __attribute__((opencl_global)) - -namespace __spv { -struct Scope { - - enum Flag : uint32_t { - CrossDevice = 0, - Device = 1, - Workgroup = 2, - Subgroup = 3, - Invocation = 4, - }; - - constexpr Scope(Flag flag) : flag_value(flag) {} - - constexpr operator uint32_t() const { return flag_value; } - - Flag flag_value; -}; - -struct MemorySemanticsMask { - - enum Flag : uint32_t { - None = 0x0, - Acquire = 0x2, - Release = 0x4, - AcquireRelease = 0x8, - SequentiallyConsistent = 0x10, - UniformMemory = 0x40, - SubgroupMemory = 0x80, - WorkgroupMemory = 0x100, - CrossWorkgroupMemory = 0x200, - AtomicCounterMemory = 0x400, - ImageMemory = 0x800, - }; - - constexpr MemorySemanticsMask(Flag flag) : flag_value(flag) {} - - constexpr operator uint32_t() const { return flag_value; } - - Flag flag_value; -}; -} // namespace __spv - -extern DEVICE_EXTERNAL int 
-__spirv_AtomicCompareExchange(int SPIR_GLOBAL *, __spv::Scope::Flag, - __spv::MemorySemanticsMask::Flag, - __spv::MemorySemanticsMask::Flag, int, int); - -extern DEVICE_EXTERNAL int -__spirv_AtomicCompareExchange(int *, __spv::Scope::Flag, - __spv::MemorySemanticsMask::Flag, - __spv::MemorySemanticsMask::Flag, int, int); - -extern DEVICE_EXTERNAL int __spirv_AtomicLoad(const int SPIR_GLOBAL *, - __spv::Scope::Flag, - __spv::MemorySemanticsMask::Flag); - -extern DEVICE_EXTERNAL void -__spirv_AtomicStore(int SPIR_GLOBAL *, __spv::Scope::Flag, - __spv::MemorySemanticsMask::Flag, int); - -extern DEVICE_EXTERNAL void -__spirv_AtomicStore(int *, __spv::Scope::Flag, __spv::MemorySemanticsMask::Flag, - int); - /// Atomically set the value in *Ptr with Desired if and only if it is Expected /// Return the value which already was in *Ptr static inline int atomicCompareAndSet(SPIR_GLOBAL int *Ptr, int Desired, diff --git a/libdevice/cmake/modules/SYCLLibdevice.cmake b/libdevice/cmake/modules/SYCLLibdevice.cmake index 0b2c1780a4756..f9f5ed6d9809e 100644 --- a/libdevice/cmake/modules/SYCLLibdevice.cmake +++ b/libdevice/cmake/modules/SYCLLibdevice.cmake @@ -218,9 +218,8 @@ function(add_devicelibs filename) DEPENDENCIES ${ARG_DEPENDENCIES} EXTRA_OPTS ${ARG_EXTRA_OPTS} ${bc_device_compile_opts} ${compile_opts_${arch}}) - - append_to_property(${bc_binary_dir}/${filename}-${arch}.bc - PROPERTY_NAME BC_DEVICE_LIBS_${arch}) + append_to_property(${bc_binary_dir}/${filename}-${arch}.bc + PROPERTY_NAME BC_DEVICE_LIBS_${arch}) endforeach() endfunction() @@ -233,7 +232,7 @@ set(itt_obj_deps device_itt.h spirv_vars.h device.h sycl-compiler) set(bfloat16_obj_deps sycl-headers sycl-compiler) if (NOT MSVC AND UR_SANITIZER_INCLUDE_DIR) set(asan_obj_deps - device.h atomic.hpp spirv_vars.h + device.h atomic.hpp spirv_vars.h spirv_decls.hpp ${UR_SANITIZER_INCLUDE_DIR}/asan/asan_libdevice.hpp include/asan_rtl.hpp include/spir_global_var.hpp @@ -296,6 +295,7 @@ if (NOT MSVC AND 
UR_SANITIZER_INCLUDE_DIR) include/spir_global_var.hpp sycl-compiler) endif() +set(gsort_obj_deps device.h spirv_decls.hpp spirv_vars.h group_helper.hpp sort_helper.hpp sycl-compiler) if("native_cpu" IN_LIST SYCL_ENABLE_BACKENDS) if (NOT DEFINED NATIVE_CPU_DIR) @@ -416,6 +416,11 @@ add_devicelibs(libsycl-fallback-bfloat16 add_devicelibs(libsycl-native-bfloat16 SRC bfloat16_wrapper.cpp DEPENDENCIES ${bfloat16_obj_deps}) +add_devicelibs(libsycl-fallback-gsort + SRC fallback-gsort.cpp + DEPENDENCIES ${gsort_obj_deps} + EXTRA_OPTS -fno-sycl-instrument-device-code + SKIP_ARCHS amdgcn-amd-amdhsa) # Create dependency and source lists for Intel math function libraries. file(MAKE_DIRECTORY ${obj_binary_dir}/libdevice) diff --git a/libdevice/fallback-gsort.cpp b/libdevice/fallback-gsort.cpp new file mode 100644 index 0000000000000..f32a92fab0f66 --- /dev/null +++ b/libdevice/fallback-gsort.cpp @@ -0,0 +1,7090 @@ + +//==------ fallback-gsort.cpp - fallback implementation of group sort-------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "device.h" +#include "sort_helper.hpp" +#include +#include +#if defined(__SPIR__) || defined(__SPIRV__) + +#define WG_JS_A(EP) __devicelib_default_work_group_joint_sort_ascending_##EP +#define WG_JS_D(EP) __devicelib_default_work_group_joint_sort_descending_##EP +#define WG_PS_CA(EP) \ + __devicelib_default_work_group_private_sort_close_ascending_##EP +#define WG_PS_CD(EP) \ + __devicelib_default_work_group_private_sort_close_descending_##EP +#define WG_PS_SA(EP) \ + __devicelib_default_work_group_private_sort_spread_ascending_##EP +#define WG_PS_SD(EP) \ + __devicelib_default_work_group_private_sort_spread_descending_##EP +#define SG_PS_A(EP) __devicelib_default_sub_group_private_sort_ascending_##EP +#define SG_PS_D(EP) __devicelib_default_sub_group_private_sort_descending_##EP + +//============ default work group joint sort for signed integer =============== +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3i8_u32_p1i8)(int8_t 
*first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_JS_A(p3i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +//=========== 
default work group joint sort for unsigned integer ============== +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t 
*scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_u32_p3i8)(uint64_t 
*first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +//=============== default work group joint sort for fp32 ====================== +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, 
std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, std::greater{}); +} + +// TODO: split all f16 functions into separate libraries in case some platform +// doesn't support native fp16 +//=============== default work group joint sort for fp16 ====================== +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p3f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a < b); }); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, _Float16 b) { return (a > b); }); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p3f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) { + merge_sort(first, n, scratch, [](_Float16 a, 
_Float16 b) { return (a > b); }); +} + +//============ default work group private sort for signed integer ============== +// Since 'first' should point to 'private' memory address space, it can only be +// decorated with 'p1'. +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_u32_p1i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_u32_p3i8)(int8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t 
*scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_u32_p1i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_u32_p3i8)(int16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SD(p1i32_u32_p1i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_u32_p3i8)(int32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_u32_p1i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_u32_p3i8)(int64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +//=========== default work grop private sort for unsigned integer ============= + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_u32_p3i8)(uint8_t *first, 
uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_u32_p1i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_u32_p3i8)(uint8_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SA(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_u32_p1i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_u32_p3i8)(uint16_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_u32_p1i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_u32_p3i8)(uint32_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, 
scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_u32_p1i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_u32_p3i8)(uint64_t *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_spread(first, n, scratch, std::greater{}); +} + +//================= default work grop private sort for fp32 ==================== + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::less{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) { + private_merge_sort_close(first, n, scratch, std::greater{}); +} + 
+// fp32 entry points that forward to private_merge_sort_spread. The WG_PS_SA
+// wrappers pass std::less and the WG_PS_SD wrappers pass std::greater.
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_spread(first, n, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_spread(first, n, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1f32_u32_p1i8)(float *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_spread(first, n, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1f32_u32_p3i8)(float *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_spread(first, n, scratch, std::greater{});
+}
+
+//================ default work group private sort for fp16 ===================
+// The _Float16 entry points hand the sort helpers a comparison lambda: the
+// *A wrappers compare with '<' and the *D wrappers compare with '>'.
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_close(
+      first, n, scratch,
+      [](_Float16 lhs, _Float16 rhs) -> bool { return lhs < rhs; });
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_close(
+      first, n, scratch,
+      [](_Float16 lhs, _Float16 rhs) -> bool { return lhs < rhs; });
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_close(
+      first, n, scratch,
+      [](_Float16 lhs, _Float16 rhs) -> bool { return lhs > rhs; });
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_close(
+      first, n, scratch,
+      [](_Float16 lhs, _Float16 rhs) -> bool { return lhs > rhs; });
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_spread(
+      first, n, scratch,
+      [](_Float16 lhs, _Float16 rhs) -> bool { return lhs < rhs; });
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_spread(
+      first, n, scratch,
+      [](_Float16 lhs, _Float16 rhs) -> bool { return lhs < rhs; });
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1f16_u32_p1i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_spread(first, n, scratch,
+                            [](_Float16 a, _Float16 b) { return (a > b); });
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1f16_u32_p3i8)(_Float16 *first, uint32_t n, uint8_t *scratch) {
+  private_merge_sort_spread(first, n, scratch,
+                            [](_Float16 a, _Float16 b) { return (a > b); });
+}
+
+//=================== default sub group private sort ==========================
+// Each wrapper forwards `value` and `scratch` to sub_group_merge_sort and
+// returns its result. SG_PS_A entry points pass a less-than comparator and
+// SG_PS_D entry points pass a greater-than comparator; the _Float16 variants
+// spell the comparison as a lambda. Every entry point carries
+// DEVICE_EXTERN_C_INLINE so that all variants are declared the same way.
+// NOTE(review): the p1i8/p3i8 suffixes presumably encode the address-space
+// decoration of `scratch` -- confirm against the macro definitions.
+
+// SG_PS_A: signed integer values.
+DEVICE_EXTERN_C_INLINE
+int8_t SG_PS_A(i8_p1i8)(int8_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int8_t SG_PS_A(i8_p3i8)(int8_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int16_t SG_PS_A(i16_p1i8)(int16_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int16_t SG_PS_A(i16_p3i8)(int16_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int32_t SG_PS_A(i32_p1i8)(int32_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int32_t SG_PS_A(i32_p3i8)(int32_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int64_t SG_PS_A(i64_p1i8)(int64_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int64_t SG_PS_A(i64_p3i8)(int64_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+// SG_PS_A: unsigned integer values.
+DEVICE_EXTERN_C_INLINE
+uint8_t SG_PS_A(u8_p1i8)(uint8_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint8_t SG_PS_A(u8_p3i8)(uint8_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint16_t SG_PS_A(u16_p1i8)(uint16_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint16_t SG_PS_A(u16_p3i8)(uint16_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint32_t SG_PS_A(u32_p1i8)(uint32_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint32_t SG_PS_A(u32_p3i8)(uint32_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint64_t SG_PS_A(u64_p1i8)(uint64_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint64_t SG_PS_A(u64_p3i8)(uint64_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+// SG_PS_A: floating-point values.
+DEVICE_EXTERN_C_INLINE
+float SG_PS_A(f32_p1i8)(float value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+float SG_PS_A(f32_p3i8)(float value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::less{});
+}
+
+DEVICE_EXTERN_C_INLINE
+_Float16 SG_PS_A(f16_p1i8)(_Float16 value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch,
+                              [](_Float16 a, _Float16 b) { return (a < b); });
+}
+
+DEVICE_EXTERN_C_INLINE
+_Float16 SG_PS_A(f16_p3i8)(_Float16 value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch,
+                              [](_Float16 a, _Float16 b) { return (a < b); });
+}
+
+// SG_PS_D: signed integer values.
+DEVICE_EXTERN_C_INLINE
+int8_t SG_PS_D(i8_p1i8)(int8_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int8_t SG_PS_D(i8_p3i8)(int8_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int16_t SG_PS_D(i16_p1i8)(int16_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int16_t SG_PS_D(i16_p3i8)(int16_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int32_t SG_PS_D(i32_p1i8)(int32_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int32_t SG_PS_D(i32_p3i8)(int32_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int64_t SG_PS_D(i64_p1i8)(int64_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+int64_t SG_PS_D(i64_p3i8)(int64_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+// SG_PS_D: unsigned integer values.
+DEVICE_EXTERN_C_INLINE
+uint8_t SG_PS_D(u8_p1i8)(uint8_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint8_t SG_PS_D(u8_p3i8)(uint8_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint16_t SG_PS_D(u16_p1i8)(uint16_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint16_t SG_PS_D(u16_p3i8)(uint16_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint32_t SG_PS_D(u32_p1i8)(uint32_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint32_t SG_PS_D(u32_p3i8)(uint32_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint64_t SG_PS_D(u64_p1i8)(uint64_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+uint64_t SG_PS_D(u64_p3i8)(uint64_t value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+// SG_PS_D: floating-point values.
+DEVICE_EXTERN_C_INLINE
+float SG_PS_D(f32_p1i8)(float value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+float SG_PS_D(f32_p3i8)(float value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch, std::greater{});
+}
+
+DEVICE_EXTERN_C_INLINE
+_Float16 SG_PS_D(f16_p1i8)(_Float16 value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch,
+                              [](_Float16 a, _Float16 b) { return (a > b); });
+}
+
+DEVICE_EXTERN_C_INLINE
+_Float16 SG_PS_D(f16_p3i8)(_Float16 value, uint8_t *scratch) {
+  return sub_group_merge_sort(value, scratch,
+                              [](_Float16 a, _Float16 b) { return (a > b); });
+}
+
+//============ default work group joint sort for key-value pairs ==============
+// Each wrapper forwards `keys`, `vals`, `n` and `scratch` to
+// merge_sort_key_value. WG_JS_A entry points pass std::less_equal and WG_JS_D
+// entry points pass std::greater_equal.
+
+// uint8_t as key type
+DEVICE_EXTERN_C_INLINE
+void WG_JS_A(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n,
+                                 uint8_t *scratch) {
+  merge_sort_key_value(keys, vals, n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_JS_D(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n,
+                                 uint8_t *scratch) {
+  merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{});
+}
+
+// For int8_t values, the size and alignment are the same as for uint8_t, so
+// the uint8_t implementation is reused to reduce code size.
+DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + 
merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int8_t as key +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_JS_D(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, 
scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// uint16_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_JS_A(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, 
vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int16_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_JS_D(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, 
reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// uint32_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_JS_A(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + 
merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int32_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t 
*scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// uint64_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + 
std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float 
*vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int64_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// Work group private sorting algorithms. 
+DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, 
scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, 
reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, 
uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u32_u32_p1i8)(int16_t 
*keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_CD(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + 
+// Key/value private-sort entry points ("close" variant) for uint32_t, int32_t,
+// uint64_t and int64_t keys. Each function only forwards its arguments to
+// private_merge_sort_key_value_close:
+//   * WG_PS_CA passes std::less_equal<KeyT> and WG_PS_CD passes
+//     std::greater_equal<KeyT> (presumably ascending / descending order).
+//   * The macro argument mirrors the parameter list: p1u32 <-> uint32_t *,
+//     p1i8 <-> int8_t *, p1f32 <-> float *, and the trailing u32_p1i8 matches
+//     the (uint32_t n, uint8_t *scratch) pair.
+//   * Signed-integer and float value buffers are forwarded as pointers to the
+//     unsigned integer type of the same width.
+//     NOTE(review): this assumes the helper only moves values and never
+//     interprets them -- confirm against private_merge_sort_key_value_close.
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch,
+                                     std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch,
+                                     std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch,
+                                     std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::less_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::greater_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::less_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::greater_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::less_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::greater_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal<int32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::less_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::greater_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::less_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch,
+                                     std::greater_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch,
+                                     std::greater_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::less_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch,
+                                     std::greater_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch,
+                                     std::greater_equal<uint64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::less_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::greater_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::less_equal<int64_t>{});
+}
+
+// Key/value private-sort entry points: the remaining "close" wrappers for
+// int64_t keys, followed by the "spread" wrappers for uint8_t, int8_t,
+// uint16_t, int16_t and uint32_t keys. Each function only forwards its
+// arguments to private_merge_sort_key_value_close (WG_PS_CA / WG_PS_CD) or
+// private_merge_sort_key_value_spread (WG_PS_SA / WG_PS_SD):
+//   * The *A macros pass std::less_equal<KeyT> and the *D macros pass
+//     std::greater_equal<KeyT> (presumably ascending / descending order).
+//   * The macro argument mirrors the parameter list: p1u32 <-> uint32_t *,
+//     p1i8 <-> int8_t *, p1f32 <-> float *, and the trailing u32_p1i8 matches
+//     the (uint32_t n, uint8_t *scratch) pair.
+//   * Signed-integer and float value buffers are forwarded as pointers to the
+//     unsigned integer type of the same width.
+//     NOTE(review): this assumes the helpers only move values and never
+//     interpret them -- confirm against the private_merge_sort_key_value_*
+//     implementations.
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::greater_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::less_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::greater_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal<int64_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u8_p1u8_u32_p1i8)(uint8_t *keys, uint8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u8_p1i8_u32_p1i8)(uint8_t *keys, int8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u8_p1u16_u32_p1i8)(uint8_t *keys, uint16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u8_p1i16_u32_p1i8)(uint8_t *keys, int16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u8_p1u32_u32_p1i8)(uint8_t *keys, uint32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u8_p1i32_u32_p1i8)(uint8_t *keys, int32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u8_p1u64_u32_p1i8)(uint8_t *keys, uint64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint64_t *>(vals),
+                                      n, scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u8_p1i64_u32_p1i8)(uint8_t *keys, int64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint64_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::less_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u8_p1f32_u32_p1i8)(uint8_t *keys, float *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i8_p1u8_u32_p1i8)(int8_t *keys, uint8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch, std::less_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i8_p1i8_u32_p1i8)(int8_t *keys, int8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch, std::greater_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i8_p1u16_u32_p1i8)(int8_t *keys, uint16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch, std::less_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i8_p1i16_u32_p1i8)(int8_t *keys, int16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch, std::greater_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i8_p1u32_u32_p1i8)(int8_t *keys, uint32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::less_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i8_p1i32_u32_p1i8)(int8_t *keys, int32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::greater_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i8_p1u64_u32_p1i8)(int8_t *keys, uint64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint64_t *>(vals),
+                                      n, scratch, std::less_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i8_p1i64_u32_p1i8)(int8_t *keys, int64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint64_t *>(vals),
+                                      n, scratch, std::greater_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::less_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i8_p1f32_u32_p1i8)(int8_t *keys, float *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::greater_equal<int8_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u16_p1u8_u32_p1i8)(uint16_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch, std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u16_p1i8_u32_p1i8)(uint16_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u16_p1u16_u32_p1i8)(uint16_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch, std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u16_p1i16_u32_p1i8)(uint16_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u16_p1u32_u32_p1i8)(uint16_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u16_p1i32_u32_p1i8)(uint16_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u16_p1u64_u32_p1i8)(uint16_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint64_t *>(vals),
+                                      n, scratch, std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u16_p1i64_u32_p1i8)(uint16_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint64_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::less_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u16_p1f32_u32_p1i8)(uint16_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i16_p1u8_u32_p1i8)(int16_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch, std::less_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i16_p1i8_u32_p1i8)(int16_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i16_p1u16_u32_p1i8)(int16_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch, std::less_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i16_p1i16_u32_p1i8)(int16_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i16_p1u32_u32_p1i8)(int16_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::less_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i16_p1i32_u32_p1i8)(int16_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i16_p1u64_u32_p1i8)(int16_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint64_t *>(vals),
+                                      n, scratch, std::less_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i16_p1i64_u32_p1i8)(int16_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint64_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::less_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i16_p1f32_u32_p1i8)(int16_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<int16_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u32_p1u8_u32_p1i8)(uint32_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch, std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u32_p1i8_u32_p1i8)(uint32_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint8_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u32_p1u16_u32_p1i8)(uint32_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch, std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u32_p1i16_u32_p1i8)(uint32_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint16_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u32_p1u32_u32_p1i8)(uint32_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch, std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u32_p1i32_u32_p1i8)(uint32_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast<uint32_t *>(vals),
+                                      n, scratch,
+                                      std::greater_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal<uint32_t>{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u32_p1u64_u32_p1i8)(uint32_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal<uint32_t>{});
+}
+ /* WG_PS_SA / WG_PS_SD wrappers (continued). Every entry forwards to private_merge_sort_key_value_spread, with std::less_equal{} for SA and std::greater_equal{} for SD. Unsigned value pointers are passed as-is; signed and float value pointers go through reinterpret_cast. NOTE(review): the cast target type is not visible in this copy of the patch; confirm against the upstream file. */ +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i64_u32_p1i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1f32_u32_p1i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + /* int32_t as key type */ +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u8_u32_p1i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i8_u32_p1i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u16_u32_p1i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i16_u32_p1i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u32_u32_p1i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i32_u32_p1i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u64_u32_p1i8)(int32_t *keys, uint64_t *vals, uint32_t n, + 
uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i64_u32_p1i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1f32_u32_p1i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + /* uint64_t as key type */ +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u8_u32_p1i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i8_u32_p1i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SA(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u16_u32_p1i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i16_u32_p1i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u32_u32_p1i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i32_u32_p1i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u64_u32_p1i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + /* Remaining WG_PS_SA / WG_PS_SD wrappers: each forwards to private_merge_sort_key_value_spread, with std::less_equal{} for SA and std::greater_equal{} for SD. Unsigned value pointers are passed as-is; signed and float value pointers go through reinterpret_cast. NOTE(review): the cast target type is not visible in this copy of the patch; confirm against the upstream file. */ +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i64_u32_p1i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1f32_u32_p1i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + /* int64_t as key type */ +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u8_u32_p1i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i8_u32_p1i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, 
reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u16_u32_p1i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i16_u32_p1i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u32_u32_p1i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i32_u32_p1i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1u64_u32_p1i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1i64_u32_p1i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i64_p1f32_u32_p1i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + /* WG_JS_A / WG_JS_D wrappers: each forwards to merge_sort_key_value, with std::less_equal{} for A and std::greater_equal{} for D. Their scratch tag is p3i8 where the WG_PS_* wrappers above use p1i8 (presumably an address-space tag; confirm against the macro definitions). uint8_t as key type. */ +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, 
reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1u64_u32_p3i8)(uint8_t *keys, 
uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int8_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + 
merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// uint16_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + /* WG_JS_A / WG_JS_D wrappers (continued): each forwards to merge_sort_key_value, with std::less_equal{} for A and std::greater_equal{} for D. Unsigned value pointers are passed as-is; signed and float value pointers go through reinterpret_cast. NOTE(review): the cast target type is not visible in this copy of the patch; confirm against the upstream file. */ +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + 
uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int16_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, 
uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// uint32_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + /* NOTE(review): the next two wrappers (A and D) use the suffix p1u32_p3i8_u32_p3i8, while every other wrapper taking int8_t values tags the value pointer as p1i8 (for example p1u16_p1i8_u32_p3i8). This looks like a typo for p1u32_p1i8_u32_p3i8; confirm against the declarations callers use before renaming, since the suffix is part of the generated symbol name. */ +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, 
scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u32_p1f32_u32_p3i8)(uint32_t 
*keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int32_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, 
scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// uint64_t as key type +DEVICE_EXTERN_C_INLINE +void 
WG_JS_A(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, 
scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// int64_t as key type +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_JS_D(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, 
reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, vals, n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_A(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_JS_D(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + merge_sort_key_value(keys, reinterpret_cast(vals), n, scratch, + std::greater_equal{}); +} + +// Work group private sorting algorithms. 
+// uint8_t as key type
+// Every wrapper below forwards its arguments to
+// private_merge_sort_key_value_close: WG_PS_CA variants pass std::less_equal,
+// WG_PS_CD variants pass std::greater_equal. The macro argument spells out the
+// parameter list (key pointer, value pointer, count, scratch pointer).
+// Signed value arrays are passed as the unsigned integer type of the same
+// width.
+// NOTE(review): the reinterpret_cast target types are restored on the
+// assumption that the helper takes same-width unsigned values -- confirm
+// against the declaration of private_merge_sort_key_value_close.
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+// Remaining uint8_t-key overloads, followed by the int8_t, uint16_t and
+// int16_t key families and the first uint32_t-key overloads. Every wrapper
+// forwards its arguments to private_merge_sort_key_value_close: WG_PS_CA
+// variants pass std::less_equal, WG_PS_CD variants pass std::greater_equal.
+// The macro argument spells out the parameter list (key pointer, value
+// pointer, count, scratch pointer). Value arrays that are signed or float are
+// passed as the unsigned integer type of the same width.
+// NOTE(review): the reinterpret_cast target types are restored on the
+// assumption that the helper takes same-width unsigned values -- confirm
+// against the declaration of private_merge_sort_key_value_close.
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+// int8_t as key type
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n,
+                                  uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+// uint16_t as key type
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+// int16_t as key type
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint64_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint32_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+// uint32_t as key type
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+// The int8_t value pointer is spelled p1i8 here, matching every other
+// int8_t-value overload; this pair was previously mangled as p3i8.
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u32_p1i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u32_p1i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint8_t *>(vals), n,
+                                     scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, vals, n, scratch,
+                                     std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CA(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_CD(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_close(keys, reinterpret_cast<uint16_t *>(vals),
+                                     n, scratch, std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, 
scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, 
reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + 
uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_CA(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), n, + scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + 
+DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, 
std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CA(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_CD(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_close(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE /* WG_PS_SA / WG_PS_SD wrappers forward to private_merge_sort_key_value_spread; the WG_PS_CA / WG_PS_CD wrappers forward to private_merge_sort_key_value_close. The *A variants pass std::less_equal and the *D variants pass std::greater_equal. The name suffix lists the parameter types in order: keys, vals, n, scratch. */ +void WG_PS_SA(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u8_u32_p3i8)(uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE /* Signed and float vals are handed to the helper through a reinterpret_cast of the vals pointer; unsigned vals are passed as-is. */ +void WG_PS_SA(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i8_u32_p3i8)(uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u16_u32_p3i8)(uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, 
reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i16_u32_p3i8)(uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u32_u32_p3i8)(uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i32_u32_p3i8)(uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1u64_u32_p3i8)(uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1i64_u32_p3i8)(uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u8_p1f32_u32_p3i8)(uint8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u8_u32_p3i8)(int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i8_u32_p3i8)(int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u16_u32_p3i8)(int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t 
*scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i16_u32_p3i8)(int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u32_u32_p3i8)(int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i32_u32_p3i8)(int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1u64_u32_p3i8)(int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1i64_u32_p3i8)(int8_t *keys, int64_t *vals, 
uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i8_p1f32_u32_p3i8)(int8_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u8_u32_p3i8)(uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i8_u32_p3i8)(uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u16_u32_p3i8)(uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SA(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i16_u32_p3i8)(uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u32_u32_p3i8)(uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i32_u32_p3i8)(uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1u64_u32_p3i8)(uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, 
std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1i64_u32_p3i8)(uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u16_p1f32_u32_p3i8)(uint16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u8_u32_p3i8)(int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i8_u32_p3i8)(int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u16_u32_p3i8)(int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i16_u32_p3i8)(int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u32_u32_p3i8)(int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i32_u32_p3i8)(int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1u64_u32_p3i8)(int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + 
uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1i64_u32_p3i8)(int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i16_p1f32_u32_p3i8)(int16_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u8_u32_p3i8)(uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p3i8_u32_p3i8)(uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SD(p1u32_p1u16_u32_p3i8)(uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i16_u32_p3i8)(uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u32_u32_p3i8)(uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i32_u32_p3i8)(uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1u64_u32_p3i8)(uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} 
+ +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1i64_u32_p3i8)(uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u32_p1f32_u32_p3i8)(uint32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u8_u32_p3i8)(int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i8_u32_p3i8)(int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + 
private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u16_u32_p3i8)(int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i16_u32_p3i8)(int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u32_u32_p3i8)(int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i32_u32_p3i8)(int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1u64_u32_p3i8)(int32_t *keys, uint64_t *vals, uint32_t n, + 
uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1i64_u32_p3i8)(int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1i32_p1f32_u32_p3i8)(int32_t *keys, float *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u8_u32_p3i8)(uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i8_u32_p3i8)(uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void 
WG_PS_SA(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u16_u32_p3i8)(uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i16_u32_p3i8)(uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1u32_u32_p3i8)(uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, std::less_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SD(p1u64_p1i32_u32_p3i8)(uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, reinterpret_cast(vals), + n, scratch, + std::greater_equal{}); +} + +DEVICE_EXTERN_C_INLINE +void WG_PS_SA(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { + private_merge_sort_key_value_spread(keys, vals, n, scratch, + std::less_equal{}); +} + 
+// The wrappers below all follow one pattern: they forward keys, vals, n and
+// scratch unchanged to private_merge_sort_key_value_spread. WG_PS_SA variants
+// pass std::less_equal and WG_PS_SD variants pass std::greater_equal.
+// Value pointers of signed-integer or float type go through reinterpret_cast.
+// NOTE(review): the cast target types and the comparators' template arguments
+// are missing from this text; presumably the casts go to the same-width
+// unsigned type and the comparators are instantiated on the key type - confirm
+// against the original source file.
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u64_p1u64_u32_p3i8)(uint64_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u64_p1i64_u32_p3i8)(uint64_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1u64_p1f32_u32_p3i8)(uint64_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i64_p1u8_u32_p3i8)(int64_t *keys, uint8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i64_p1i8_u32_p3i8)(int64_t *keys, int8_t *vals, uint32_t n,
+                                   uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i64_p1u16_u32_p3i8)(int64_t *keys, uint16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i64_p1i16_u32_p3i8)(int64_t *keys, int16_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i64_p1u32_u32_p3i8)(int64_t *keys, uint32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i64_p1i32_u32_p3i8)(int64_t *keys, int32_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i64_p1u64_u32_p3i8)(int64_t *keys, uint64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, vals, n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i64_p1i64_u32_p3i8)(int64_t *keys, int64_t *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch,
+                                      std::greater_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SA(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch, std::less_equal{});
+}
+
+DEVICE_EXTERN_C_INLINE
+void WG_PS_SD(p1i64_p1f32_u32_p3i8)(int64_t *keys, float *vals, uint32_t n,
+                                    uint8_t *scratch) {
+  private_merge_sort_key_value_spread(keys, reinterpret_cast(vals),
+                                      n, scratch,
+                                      std::greater_equal{});
+}
+
+#endif // __SPIR__ || __SPIRV__
diff --git a/libdevice/group_helper.hpp b/libdevice/group_helper.hpp
new file mode 100644
index 0000000000000..934514e5326c2
--- /dev/null
+++ b/libdevice/group_helper.hpp
@@ -0,0 +1,37 @@
+//==------- group_helper.hpp - utils related to work-group operations-------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//==------------------------------------------------------------------------==//
+#pragma once
+#include "spirv_decls.hpp"
+#include "spirv_vars.h"
+// NOTE(review): the header name of the next #include is missing from this
+// text; presumably <cstdint> (uint64_t is used below) - confirm.
+#include
+#if defined(__SPIR__) || defined(__SPIRV__)
+
+// Returns the product of the x, y and z components of
+// __spirv_BuiltInWorkgroupSize, i.e. the number of work-items in the
+// work-group.
+static inline size_t __get_wg_local_range() {
+  return __spirv_BuiltInWorkgroupSize.x * __spirv_BuiltInWorkgroupSize.y *
+         __spirv_BuiltInWorkgroupSize.z;
+}
+
+// Flattens the 3-component __spirv_BuiltInLocalInvocationId into one index:
+// id.x * size.y * size.z + id.y * size.z + id.z (z varies fastest).
+static inline size_t __get_wg_local_linear_id() {
+  return (__spirv_BuiltInLocalInvocationId.x * __spirv_BuiltInWorkgroupSize.y *
+          __spirv_BuiltInWorkgroupSize.z) +
+         (__spirv_BuiltInLocalInvocationId.y * __spirv_BuiltInWorkgroupSize.z) +
+         __spirv_BuiltInLocalInvocationId.z;
+}
+
+// Issues __spirv_ControlBarrier with Workgroup passed for both scope
+// arguments. The semantics mask requests SequentiallyConsistent ordering over
+// the SubgroupMemory, WorkgroupMemory and CrossWorkgroupMemory classes.
+static inline void group_barrier() {
+  __spirv_ControlBarrier(__spv::Scope::Workgroup, __spv::Scope::Workgroup,
+                         __spv::MemorySemanticsMask::SequentiallyConsistent |
+                             __spv::MemorySemanticsMask::SubgroupMemory |
+                             __spv::MemorySemanticsMask::WorkgroupMemory |
+                             __spv::MemorySemanticsMask::CrossWorkgroupMemory);
+}
+
+// Forwards x to __spirv_GroupBroadcast with Workgroup scope and 0 as the last
+// argument. NOTE(review): presumably the 0 selects the work-item with local id
+// 0 as the source, so every caller receives that work-item's x - confirm
+// against the SPIR-V OpGroupBroadcast definition.
+static inline uint64_t group_broadcast(uint64_t x) {
+  return __spirv_GroupBroadcast(__spv::Scope::Flag::Workgroup, x, 0);
+}
+#endif // __SPIR__ || __SPIRV__
diff --git a/libdevice/sort_helper.hpp b/libdevice/sort_helper.hpp
new file mode 100644
index 0000000000000..f54a4965d3086
--- /dev/null
+++ b/libdevice/sort_helper.hpp
@@ -0,0 +1,367 @@
+//==------- sort_helper.hpp - helper functions to do group sorting----------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//==------------------------------------------------------------------------==// + +#pragma once +#include "group_helper.hpp" +#include + +#if defined(__SPIR__) || defined(__SPIRV__) +template +void bubble_sort(Tp *first, const size_t beg, const size_t end, Compare comp) { + if (beg < end) { + Tp temp; + for (size_t i = beg; i < end; ++i) + for (size_t j = i + 1; j < end; ++j) { + if (!comp(first[i], first[j])) { + temp = first[i]; + first[i] = first[j]; + first[j] = temp; + } + } + } +} + +// widx: work-item id with a work-group +// chunks: number of sorted chunks waiting to be merged +// n: total number of elements waiting to be sorted +// msize: number of elements in a chunk ready to be merged +template +void merge(Tp *din, Tp *dout, size_t widx, size_t msize, size_t chunks, + size_t n, Compare comp) { + if (2 * widx >= chunks) + return; + + size_t beg1 = 2 * widx * msize; + size_t end1 = beg1 + msize; + size_t beg2, end2; + if (end1 >= n) + end1 = beg2 = end2 = n; + else { + beg2 = end1; + end2 = beg2 + msize; + if (end2 >= n) + end2 = n; + } + + size_t output_idx = 2 * widx * msize; + while ((beg1 != end1) && (beg2 != end2)) { + if (comp(din[beg1], din[beg2])) + dout[output_idx++] = din[beg1++]; + else + dout[output_idx++] = din[beg2++]; + } + + while (beg1 != end1) + dout[output_idx++] = din[beg1++]; + while (beg2 != end2) + dout[output_idx++] = din[beg2++]; +} + +template +void merge_sort(Tp *first, uint32_t n, uint8_t *scratch, Compare comp) { + const size_t idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + const size_t chunk_size = (n - 1) / wg_size + 1; + + const size_t bubble_beg = (idx * chunk_size) >= n ? n : idx * chunk_size; + const size_t bubble_end = + ((idx + 1) * chunk_size) > n ? 
n : (idx + 1) * chunk_size; + bubble_sort(first, bubble_beg, bubble_end, comp); + group_barrier(); + Tp *scratch1 = reinterpret_cast(scratch); + bool data_in_scratch = false; + // We have wg_size chunks here, each chunk has chunk_size elements which + // are sorted. The last chunck's element number may be smaller. + size_t chunks_to_merge = (n - 1) / chunk_size + 1; + size_t merge_size = chunk_size; + while (chunks_to_merge > 1) { + // workitem 0 will merge chunk 0, 1. + // workitem 1 will merge chunk 2, 3. + // workitem idx will merge chunk 2 * idx and 2 * idx + 1 + Tp *data_in = data_in_scratch ? scratch1 : first; + Tp *data_out = data_in_scratch ? first : scratch1; + merge(data_in, data_out, idx, merge_size, chunks_to_merge, n, + comp); + group_barrier(); + chunks_to_merge = (chunks_to_merge - 1) / 2 + 1; + merge_size <<= 1; + data_in_scratch = !data_in_scratch; + } + if (data_in_scratch) { + for (size_t i = idx * chunk_size; i < bubble_end; ++i) + first[i] = scratch1[i]; + group_barrier(); + } +} + +// Each work-item holds some input elements located in private memory and apply +// group sorting to all work-items' input. The sorted data will be copied back +// to each work-item's private memory. 
+// Assumption about scratch memory size:
+// scratch_size >= n * wg_size * sizeof(Tp) * 2
+// The first n * wg_size elements of scratch collect every work-item's input,
+// the second n * wg_size elements are handed to merge_sort as its own scratch.
+// "close" layout: work-item idx receives the n consecutive sorted elements
+// [idx * n, idx * n + n).
+template <typename Tp, typename Compare>
+void private_merge_sort_close(Tp *first, uint32_t n, uint8_t *scratch,
+                              Compare comp) {
+  // No elements per work-item: nothing to gather, sort or copy back. n is
+  // already required to be the same in every work-item (it sizes the shared
+  // buffer), so the whole group leaves here and no barrier is skipped by only
+  // part of it.
+  if (n == 0)
+    return;
+
+  const size_t idx = __get_wg_local_linear_id();
+  const size_t wg_size = __get_wg_local_range();
+  Tp *temp_buffer = reinterpret_cast<Tp *>(scratch);
+  for (size_t i = 0; i < n; ++i)
+    temp_buffer[idx * n + i] = first[i];
+
+  group_barrier();
+  // do group sorting for whole input data
+  merge_sort(temp_buffer, n * wg_size,
+             reinterpret_cast<uint8_t *>(temp_buffer + n * wg_size), comp);
+
+  for (size_t i = 0; i < n; ++i)
+    first[i] = temp_buffer[idx * n + i];
+}
+
+// Same gathering and sorting as private_merge_sort_close, with the same
+// scratch size requirement, but with the "spread" layout: work-item idx
+// receives the sorted elements at positions i * wg_size + idx for i in [0, n).
+template <typename Tp, typename Compare>
+void private_merge_sort_spread(Tp *first, uint32_t n, uint8_t *scratch,
+                               Compare comp) {
+  // See private_merge_sort_close: uniform early exit for empty input.
+  if (n == 0)
+    return;
+
+  const size_t idx = __get_wg_local_linear_id();
+  const size_t wg_size = __get_wg_local_range();
+  Tp *temp_buffer = reinterpret_cast<Tp *>(scratch);
+  for (size_t i = 0; i < n; ++i)
+    temp_buffer[idx * n + i] = first[i];
+
+  group_barrier();
+  // do group sorting for whole input data
+  merge_sort(temp_buffer, n * wg_size,
+             reinterpret_cast<uint8_t *>(temp_buffer + n * wg_size), comp);
+
+  for (size_t i = 0; i < n; ++i)
+    first[i] = temp_buffer[i * wg_size + idx];
+}
+
+// sub group sort implementation, each work-item holds an element, the total
+// number of input elements is work group size.
+// Assumption about scratch memory size: +// scratch_size >= wg_size * sizeof(Tp) * 2 +template +Tp sub_group_merge_sort(Tp value, uint8_t *scratch, Compare comp) { + const size_t idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + Tp *temp_buffer = reinterpret_cast(scratch); + temp_buffer[idx] = value; + + group_barrier(); + merge_sort(temp_buffer, wg_size, + reinterpret_cast(temp_buffer + wg_size), comp); + return temp_buffer[idx]; +} + +template +void bubble_sort_key_value_stable(KeyT *keys, ValT *vals, const size_t beg, + const size_t end, Compare comp) { + if (beg < end) { + KeyT temp_key; + ValT temp_val; + size_t swaps; + do { + swaps = 0; + for (size_t i = beg; i < (end - 1); ++i) { + if (!comp(keys[i], keys[i + 1])) { + temp_key = keys[i]; + keys[i] = keys[i + 1]; + keys[i + 1] = temp_key; + temp_val = vals[i]; + vals[i] = vals[i + 1]; + vals[i + 1] = temp_val; + swaps += 1; + } + } + } while (swaps != 0); + } +} + +template +void merge_key_value(KeyT *keys_in, KeyT *keys_out, ValT *vals_in, + ValT *vals_out, size_t widx, size_t msize, size_t chunks, + size_t n, Compare comp) { + if (2 * widx >= chunks) + return; + + size_t beg1 = 2 * widx * msize; + size_t end1 = beg1 + msize; + size_t beg2, end2; + if (end1 >= n) + end1 = beg2 = end2 = n; + else { + beg2 = end1; + end2 = beg2 + msize; + if (end2 >= n) + end2 = n; + } + + size_t output_idx = 2 * widx * msize; + while ((beg1 != end1) && (beg2 != end2)) { + KeyT key_temp; + ValT val_temp; + if (comp(keys_in[beg1], keys_in[beg2])) { + key_temp = keys_in[beg1]; + val_temp = vals_in[beg1]; + ++beg1; + } else { + key_temp = keys_in[beg2]; + val_temp = vals_in[beg2]; + ++beg2; + } + keys_out[output_idx] = key_temp; + vals_out[output_idx] = val_temp; + ++output_idx; + } + + while (beg1 != end1) { + keys_out[output_idx] = keys_in[beg1]; + vals_out[output_idx] = vals_in[beg1]; + ++beg1; + ++output_idx; + } + while (beg2 != end2) { + keys_out[output_idx] = keys_in[beg2]; + 
vals_out[output_idx] = vals_in[beg2]; + ++beg2; + ++output_idx; + } +} + +// We have following assumption for scratch memory size for key-value +// group sort: size of scratch > n * (sizeof(KeyT) + sizeof(ValT)) + +// max(alignof(KeyT), alignof(ValT)). +template +void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch, + Compare comp) { + const size_t idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + size_t chunk_size = (n - 1) / wg_size + 1; + size_t bubble_beg, bubble_end; + bubble_beg = (idx * chunk_size) >= n ? n : idx * chunk_size; + bubble_end = ((idx + 1) * chunk_size) > n ? n : (idx + 1) * chunk_size; + bubble_sort_key_value_stable(keys, vals, bubble_beg, bubble_end, comp); + group_barrier(); + bool data_in_scratch = false; + KeyT *scratch_keys = reinterpret_cast(scratch); + ValT *scratch_vals = nullptr; + uint64_t val_offset = reinterpret_cast(scratch + sizeof(KeyT) * n); + uint64_t temp1 = val_offset % alignof(ValT); + if (temp1) + scratch_vals = reinterpret_cast(val_offset + alignof(ValT) - temp1); + else + scratch_vals = reinterpret_cast(val_offset); + + size_t chunks_to_merge = (n - 1) / chunk_size + 1; + size_t merge_size = chunk_size; + while (chunks_to_merge > 1) { + // workitem 0 will merge chunk 0, 1. + // workitem 1 will merge chunk 2, 3. + // workitem idx will merge chunk 2 * idx and 2 * idx + 1 + KeyT *keys_in = data_in_scratch ? scratch_keys : keys; + KeyT *keys_out = data_in_scratch ? keys : scratch_keys; + ValT *vals_in = data_in_scratch ? scratch_vals : vals; + ValT *vals_out = data_in_scratch ? 
vals : scratch_vals; + merge_key_value(keys_in, keys_out, vals_in, vals_out, + idx, merge_size, chunks_to_merge, n, + comp); + group_barrier(); + chunks_to_merge = (chunks_to_merge - 1) / 2 + 1; + merge_size <<= 1; + data_in_scratch = !data_in_scratch; + } + + if (data_in_scratch) { + for (size_t i = idx * chunk_size; i < bubble_end; ++i) { + keys[i] = scratch_keys[i]; + vals[i] = scratch_vals[i]; + } + group_barrier(); + } +} + +// Each work-item holds fixed-size input keys/values located in private memory +// and apply group sorting to all work-items' input. The sorted data will be +// copied back to each work-item's private memory. +// Assumption about scratch memory size: +// scratch_size >= 2 *(n * wg_size * (sizeof(KeyT) + sizeof(ValT))) + \ +// max(alignof(KeyT), alignof(ValT)) +// The scrach memory alignment is max(alignof(KeyT), alignof(ValT)) +template +static void private_merge_sort_key_value_helper(KeyT *keys, ValT *vals, + size_t n, uint8_t *scratch, + Compare comp, KeyT **keys_back, + ValT **vals_back) { + const size_t local_idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + uint64_t temp_val_beg = 0, temp_key_beg = 0, internal_scratch = 0; + KeyT *keys_ptr = nullptr; + ValT *vals_ptr = nullptr; + uint8_t *scratch_ptr = nullptr; + + if (local_idx == 0) { + uint64_t temp_val_unaligned = + reinterpret_cast(scratch + 2 * wg_size * n * sizeof(KeyT)); + uint64_t temp1 = temp_val_unaligned % alignof(ValT); + temp_val_beg = (temp1 != 0) ? 
(temp_val_unaligned + alignof(ValT) - temp1) + : temp_val_unaligned; + temp_val_beg += sizeof(ValT) * n * wg_size; + temp_key_beg = reinterpret_cast(scratch); + internal_scratch = temp_key_beg + sizeof(KeyT) * n * wg_size; + } + + keys_ptr = reinterpret_cast(group_broadcast(temp_key_beg)); + vals_ptr = reinterpret_cast(group_broadcast(temp_val_beg)); + scratch_ptr = reinterpret_cast(group_broadcast(internal_scratch)); + *keys_back = keys_ptr; + *vals_back = vals_ptr; + + for (size_t i = 0; i < n; ++i) { + keys_ptr[local_idx * n + i] = keys[i]; + vals_ptr[local_idx * n + i] = vals[i]; + } + + group_barrier(); + + merge_sort_key_value(keys_ptr, vals_ptr, n * wg_size, scratch_ptr, comp); +} + +template +void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n, + uint8_t *scratch, Compare comp) { + + KeyT *keys_back = nullptr; + ValT *vals_back = nullptr; + private_merge_sort_key_value_helper(keys, vals, n, scratch, comp, &keys_back, + &vals_back); + + const size_t local_idx = __get_wg_local_linear_id(); + for (size_t i = 0; i < n; ++i) { + keys[i] = keys_back[local_idx * n + i]; + vals[i] = vals_back[local_idx * n + i]; + } +} + +template +void private_merge_sort_key_value_spread(KeyT *keys, ValT *vals, size_t n, + uint8_t *scratch, Compare comp) { + + KeyT *keys_back = nullptr; + ValT *vals_back = nullptr; + private_merge_sort_key_value_helper(keys, vals, n, scratch, comp, &keys_back, + &vals_back); + + const size_t local_idx = __get_wg_local_linear_id(); + const size_t wg_size = __get_wg_local_range(); + + for (size_t i = 0; i < n; ++i) { + keys[i] = keys_back[wg_size * i + local_idx]; + vals[i] = vals_back[wg_size * i + local_idx]; + } +} + +#endif // __SPIR__ || __SPIRV__ diff --git a/libdevice/spirv_decls.hpp b/libdevice/spirv_decls.hpp new file mode 100644 index 0000000000000..7e48f4104be0e --- /dev/null +++ b/libdevice/spirv_decls.hpp @@ -0,0 +1,88 @@ +//==-------------- atomic.hpp - support of atomic operations ---------------==// +// +// Part 
of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#pragma once + +#include + +#include "device.h" + +#if defined(__SPIR__) || defined(__SPIRV__) + +#define SPIR_GLOBAL __attribute__((opencl_global)) + +namespace __spv { +struct Scope { + + enum Flag : uint32_t { + CrossDevice = 0, + Device = 1, + Workgroup = 2, + Subgroup = 3, + Invocation = 4, + }; + + constexpr Scope(Flag flag) : flag_value(flag) {} + + constexpr operator uint32_t() const { return flag_value; } + + Flag flag_value; +}; + +struct MemorySemanticsMask { + + enum Flag : uint32_t { + None = 0x0, + Acquire = 0x2, + Release = 0x4, + AcquireRelease = 0x8, + SequentiallyConsistent = 0x10, + UniformMemory = 0x40, + SubgroupMemory = 0x80, + WorkgroupMemory = 0x100, + CrossWorkgroupMemory = 0x200, + AtomicCounterMemory = 0x400, + ImageMemory = 0x800, + }; + + constexpr MemorySemanticsMask(Flag flag) : flag_value(flag) {} + + constexpr operator uint32_t() const { return flag_value; } + + Flag flag_value; +}; +} // namespace __spv + +extern DEVICE_EXTERNAL void __spirv_ControlBarrier(__spv::Scope::Flag, + __spv::Scope::Flag Memory, + uint32_t Semantics); + +extern DEVICE_EXTERNAL int +__spirv_AtomicCompareExchange(int SPIR_GLOBAL *, __spv::Scope::Flag, + __spv::MemorySemanticsMask::Flag, + __spv::MemorySemanticsMask::Flag, int, int); + +extern DEVICE_EXTERNAL int +__spirv_AtomicCompareExchange(int *, __spv::Scope::Flag, + __spv::MemorySemanticsMask::Flag, + __spv::MemorySemanticsMask::Flag, int, int); + +extern DEVICE_EXTERNAL int __spirv_AtomicLoad(const int SPIR_GLOBAL *, + __spv::Scope::Flag, + __spv::MemorySemanticsMask::Flag); + +extern DEVICE_EXTERNAL void +__spirv_AtomicStore(int SPIR_GLOBAL *, __spv::Scope::Flag, + __spv::MemorySemanticsMask::Flag, int); + +extern 
DEVICE_EXTERNAL void +__spirv_AtomicStore(int *, __spv::Scope::Flag, __spv::MemorySemanticsMask::Flag, + int); + +extern DEVICE_EXTERNAL uint64_t __spirv_GroupBroadcast(__spv::Scope::Flag, + uint64_t, uint64_t); +#endif // __SPIR__ || __SPIRV__ diff --git a/llvm/include/llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h b/llvm/include/llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h index c9b737e2d053a..d107661d02e5c 100644 --- a/llvm/include/llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h +++ b/llvm/include/llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h @@ -36,6 +36,7 @@ enum class DeviceLibExt : std::uint32_t { cl_intel_devicelib_imf_fp64, cl_intel_devicelib_imf_bf16, cl_intel_devicelib_bfloat16, + cl_intel_devicelib_gsort, }; uint32_t getSYCLDeviceLibReqMask(const Module &M); diff --git a/llvm/lib/SYCLLowerIR/SYCLDeviceLibReqMask.cpp b/llvm/lib/SYCLLowerIR/SYCLDeviceLibReqMask.cpp index 12914d3763521..518a4966731e7 100644 --- a/llvm/lib/SYCLLowerIR/SYCLDeviceLibReqMask.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLDeviceLibReqMask.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include "llvm/SYCLLowerIR/SYCLDeviceLibReqMask.h" +#include "llvm/ADT/StringRef.h" #include "llvm/IR/Module.h" #include "llvm/TargetParser/Triple.h" @@ -732,6 +733,15 @@ SYCLDeviceLibFuncMap SDLMap = { DeviceLibExt::cl_intel_devicelib_bfloat16}, }; +// TODO: more robust check for all group sort fallback devicelib functions. +static bool checkGroupSortFallback(const StringRef &FuncName) { + if (FuncName.starts_with("__devicelib_default_work_group_") || + FuncName.starts_with("__devicelib_default_sub_group_")) + return true; + else + return false; +} + // Each fallback device library corresponds to one bit in "require mask" which // is an unsigned int32. getDeviceLibBit checks which fallback device library // is required for FuncName and returns the corresponding bit. 
The corresponding @@ -746,6 +756,7 @@ SYCLDeviceLibFuncMap SDLMap = { // cl_intel_devicelib_imf_fp64: 0x80 // cl_intel_devicelib_imf_bf16: 0x100 // cl_intel_devicelib_bfloat16: 0x200 +// cl_intel_devicelib_gsort: 0x400 uint32_t getDeviceLibBits(const std::string &FuncName) { auto DeviceLibFuncIter = SDLMap.find(FuncName); return ((DeviceLibFuncIter == SDLMap.end()) @@ -770,6 +781,10 @@ uint32_t llvm::getSYCLDeviceLibReqMask(const Module &M) { if (SF.getName().starts_with(DEVICELIB_FUNC_PREFIX) && SF.isDeclaration()) { assert(SF.getCallingConv() == CallingConv::SPIR_FUNC); uint32_t DeviceLibBits = getDeviceLibBits(SF.getName().str()); + if (!DeviceLibBits) { + if (checkGroupSortFallback(SF.getName())) + DeviceLibBits = 0x400; + } ReqMask |= DeviceLibBits; } } diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index 841ef9f562db5..fe2371e4baf1d 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -1207,6 +1207,8 @@ static const std::map> {nullptr, "libsycl-fallback-complex-fp64.spv"}}, {DeviceLibExt::cl_intel_devicelib_cstring, {nullptr, "libsycl-fallback-cstring.spv"}}, + {DeviceLibExt::cl_intel_devicelib_gsort, + {nullptr, "libsycl-fallback-gsort.spv"}}, {DeviceLibExt::cl_intel_devicelib_imf, {nullptr, "libsycl-fallback-imf.spv"}}, {DeviceLibExt::cl_intel_devicelib_imf_fp64, @@ -1239,6 +1241,7 @@ static const std::map DeviceLibExtensionStrs = { {DeviceLibExt::cl_intel_devicelib_complex_fp64, "cl_intel_devicelib_complex_fp64"}, {DeviceLibExt::cl_intel_devicelib_cstring, "cl_intel_devicelib_cstring"}, + {DeviceLibExt::cl_intel_devicelib_gsort, "cl_intel_devicelib_gsort"}, {DeviceLibExt::cl_intel_devicelib_imf, "cl_intel_devicelib_imf"}, {DeviceLibExt::cl_intel_devicelib_imf_fp64, "cl_intel_devicelib_imf_fp64"}, {DeviceLibExt::cl_intel_devicelib_imf_bf16, "cl_intel_devicelib_imf_bf16"}, @@ -1586,6 +1589,7 @@ 
getDeviceLibPrograms(const ContextImplPtr Context, {DeviceLibExt::cl_intel_devicelib_complex, false}, {DeviceLibExt::cl_intel_devicelib_complex_fp64, false}, {DeviceLibExt::cl_intel_devicelib_cstring, false}, + {DeviceLibExt::cl_intel_devicelib_gsort, false}, {DeviceLibExt::cl_intel_devicelib_imf, false}, {DeviceLibExt::cl_intel_devicelib_imf_fp64, false}, {DeviceLibExt::cl_intel_devicelib_imf_bf16, false}, diff --git a/sycl/source/detail/program_manager/program_manager.hpp b/sycl/source/detail/program_manager/program_manager.hpp index 14467a1dd26b8..fa71dcb76a9ee 100644 --- a/sycl/source/detail/program_manager/program_manager.hpp +++ b/sycl/source/detail/program_manager/program_manager.hpp @@ -85,6 +85,7 @@ enum class DeviceLibExt : std::uint32_t { cl_intel_devicelib_imf_fp64, cl_intel_devicelib_imf_bf16, cl_intel_devicelib_bfloat16, + cl_intel_devicelib_gsort, }; enum class SanitizerType { diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp new file mode 100644 index 0000000000000..5bc5dc0e9f2b3 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_KV_sort_p1p1_p1.hpp @@ -0,0 +1,129 @@ +#include "group_sort.hpp" + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_joint_sort_ascending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_joint_sort_descending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_p1f32_u32_p1i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_p1f32_u32_p1i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_joint_sort_ascending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp new file mode 100644 index 0000000000000..2b84ce68da2b3 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p1.hpp @@ -0,0 +1,186 @@ +#pragma once +#include + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + 
+SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_joint_sort_descending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void 
+__devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp new file mode 100644 index 0000000000000..a43f191c2b8b4 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p1p3.hpp @@ -0,0 +1,186 @@ +#pragma once +#include + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t 
*scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +#else +extern "C" 
void +__devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void 
+__devicelib_default_work_group_joint_sort_descending_p1u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p1u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp new file mode 100644 index 0000000000000..180ddeecffcfb --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p1.hpp @@ -0,0 +1,186 @@ +#pragma once +#include + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t 
*scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" 
void +__devicelib_default_work_group_joint_sort_descending_p3i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i64_u32_p1i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3f32_u32_p1i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u8_u32_p1i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u16_u32_p1i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void 
+__devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u32_u32_p1i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u64_u32_p1i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp new file mode 100644 index 0000000000000..ff80e75a8fa49 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_joint_sort_p3p3.hpp @@ -0,0 +1,186 @@ +#pragma once +#include + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t 
*scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch); + +#else +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" 
void +__devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i16_u32_p3i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i64_u32_p3i8( + int64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3i8_u32_p3i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( + float *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u8_u32_p3i8( + uint8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u16_u32_p3i8( + uint16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void 
+__devicelib_default_work_group_joint_sort_descending_p3u32_u32_p3i8( + uint32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_joint_sort_descending_p3u64_u32_p3i8( + uint64_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp new file mode 100644 index 0000000000000..cf40a5d23578a --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p1.hpp @@ -0,0 +1,1189 @@ +#include "group_sort.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + 
+__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1f32_u32_p1i8( + uint8_t 
*keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i32_u32_p1i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i32_u32_p1i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i64_u32_p1i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p1i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1f32_u32_p1i8( + int8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u8_u32_p1i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u8_u32_p1i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i8_u32_p1i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i8_u32_p1i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t 
*scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u16_u32_p1i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u16_u32_p1i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i16_u32_p1i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i16_u32_p1i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p1i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p1i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i32_u32_p1i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i32_u32_p1i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u64_u32_p1i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u64_u32_p1i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i64_u32_p1i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i64_u32_p1i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1f32_u32_p1i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1f32_u32_p1i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u8_u32_p1i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u8_u32_p1i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i8_u32_p1i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i8_u32_p1i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u16_u32_p1i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u16_u32_p1i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i16_u32_p1i8( + int16_t *keys, int16_t *vals, uint32_t n, 
uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i16_u32_p1i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p1i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p1i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i32_u32_p1i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i32_u32_p1i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u64_u32_p1i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u64_u32_p1i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i64_u32_p1i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i64_u32_p1i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1f32_u32_p1i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1i16_p1f32_u32_p1i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i8_u32_p1i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i8_u32_p1i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u16_u32_p1i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u16_u32_p1i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i16_u32_p1i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i16_u32_p1i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i32_u32_p1i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i32_u32_p1i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u64_u32_p1i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u64_u32_p1i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p1i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p1i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u8_u32_p1i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u8_u32_p1i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i8_u32_p1i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i8_u32_p1i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u16_u32_p1i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u16_u32_p1i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i16_u32_p1i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i16_u32_p1i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p1i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p1i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i32_u32_p1i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i32_u32_p1i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u64_u32_p1i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u64_u32_p1i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p1i8( + int32_t *keys, int64_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p1i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u8_u32_p1i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u8_u32_p1i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i8_u32_p1i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i8_u32_p1i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u16_u32_p1i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u16_u32_p1i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i16_u32_p1i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i16_u32_p1i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p1i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p1i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i32_u32_p1i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i32_u32_p1i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u64_u32_p1i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u64_u32_p1i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i64_u32_p1i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i64_u32_p1i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u8_u32_p1i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u8_u32_p1i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i8_u32_p1i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i8_u32_p1i8( + int64_t *keys, int8_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u16_u32_p1i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u16_u32_p1i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i16_u32_p1i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i16_u32_p1i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p1i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p1i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i32_u32_p1i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i32_u32_p1i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u64_u32_p1i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u64_u32_p1i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p1i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p1i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u8_u32_p1i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i8_u32_p1i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u16_u32_p1i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i16_u32_p1i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, 
uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p1i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i32_u32_p1i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u64_u32_p1i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i64_u32_p1i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1f32_u32_p1i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u8_u32_p1i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i8_u32_p1i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u16_u32_p1i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i16_u32_p1i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p1i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i32_u32_p1i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i32_u32_p1i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t 
*scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u64_u32_p1i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i64_u32_p1i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i64_u32_p1i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1f32_u32_p1i8( + int8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u8_u32_p1i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u8_u32_p1i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i8_u32_p1i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i8_u32_p1i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u16_u32_p1i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u16_u32_p1i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i16_u32_p1i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i16_u32_p1i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p1i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p1i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i32_u32_p1i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i32_u32_p1i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u64_u32_p1i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u64_u32_p1i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i64_u32_p1i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i64_u32_p1i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1f32_u32_p1i8( + uint16_t 
*keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1f32_u32_p1i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u8_u32_p1i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u8_u32_p1i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i8_u32_p1i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i8_u32_p1i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u16_u32_p1i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u16_u32_p1i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i16_u32_p1i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i16_u32_p1i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p1i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p1i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i32_u32_p1i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i32_u32_p1i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u64_u32_p1i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u64_u32_p1i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i64_u32_p1i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i64_u32_p1i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1f32_u32_p1i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1f32_u32_p1i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u8_u32_p1i8( + uint32_t *keys, uint8_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i8_u32_p1i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i8_u32_p1i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u16_u32_p1i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u16_u32_p1i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i16_u32_p1i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i16_u32_p1i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p1i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i32_u32_p1i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i32_u32_p1i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u64_u32_p1i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u64_u32_p1i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i64_u32_p1i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i64_u32_p1i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u8_u32_p1i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u8_u32_p1i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i8_u32_p1i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i8_u32_p1i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u16_u32_p1i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u16_u32_p1i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i16_u32_p1i8( + int32_t *keys, int16_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i16_u32_p1i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p1i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p1i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i32_u32_p1i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i32_u32_p1i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u64_u32_p1i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u64_u32_p1i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i64_u32_p1i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i64_u32_p1i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u8_u32_p1i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u8_u32_p1i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i8_u32_p1i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i8_u32_p1i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u16_u32_p1i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u16_u32_p1i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i16_u32_p1i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i16_u32_p1i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p1i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p1i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i32_u32_p1i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i32_u32_p1i8( + uint64_t *keys, 
int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u64_u32_p1i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u64_u32_p1i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i64_u32_p1i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i64_u32_p1i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u8_u32_p1i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u8_u32_p1i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i8_u32_p1i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i8_u32_p1i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u16_u32_p1i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u16_u32_p1i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i16_u32_p1i8(
+    int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i16_u32_p1i8(
+    int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p1i8(
+    int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p1i8(
+    int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i32_u32_p1i8(
+    int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i32_u32_p1i8(
+    int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u64_u32_p1i8(
+    int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u64_u32_p1i8(
+    int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i64_u32_p1i8(
+    int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i64_u32_p1i8(
+    int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch);
+
+using namespace sycl;
+/*
+ * Host-side check of a work-group key/value sort that operates on private
+ * (per-work-item) arrays.
+ *
+ * The function copies the NUM input keys/values, builds the expected order on
+ * the host with std::stable_sort comparing keys only (descending when DES is
+ * defined, ascending otherwise), and then launches a single work-group of
+ * WG_SZ work-items. Each work-item loads NUM / WG_SZ consecutive pairs,
+ * starting at get_local_linear_id() * num_per_work_item, into private arrays,
+ * calls gsh(pkeys, pvals, num_per_work_item, scratch_ptr) and stores the
+ * private arrays back to the same slice of the output buffers.
+ *
+ * The outputs are finally compared against the host reference. Sorted element
+ * idx is expected at output[idx] by default, or at
+ * output[(idx % WG_SZ) * num_per_work_item + idx / WG_SZ] when SPREAD is
+ * defined. The first mismatching idx is printed and assert(!fails) fires.
+ *
+ * q    - queue used for the device scratch allocation and the kernel launch.
+ * keys - NUM input keys; only read, the kernel works on a local copy.
+ * vals - NUM input values, paired with keys by index.
+ * gsh  - callable invoked in the kernel with (key array, value array,
+ *        element count, scratch pointer).
+ *
+ * NOTE(review): `template` below has no parameter list and `std::vector>` /
+ * `std::tuple` have no template arguments in this text - presumably the
+ * angle-bracket contents were lost when the patch was extracted (buffer and
+ * parallel_for may be affected too). KeyT, ValT, WG_SZ, NUM and SortHelper
+ * are used as template parameters; confirm their declaration order against
+ * the original header before relying on it.
+ * NOTE(review): the size passed for the scratch allocation is
+ * 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) plus the larger of the two
+ * alignments; the requirement behind that formula is not visible here.
+ */
+template
+void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM],
+                                     const ValT vals[NUM], SortHelper gsh) {
+  static_assert((NUM % WG_SZ == 0),
+                "Input number must be divisible by work group size!");
+  // Work on mutable local copies so the caller's const inputs stay untouched.
+  KeyT input_keys[NUM];
+  ValT input_vals[NUM];
+  memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT));
+  memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT));
+  size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) +
+                        std::max(alignof(KeyT), alignof(ValT));
+  uint8_t *scratch_ptr = // NOTE(review): result is never checked for nullptr
+      (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q);
+  const static size_t wg_size = WG_SZ;
+  constexpr size_t num_per_work_item = NUM / WG_SZ;
+  KeyT output_keys[NUM]; // host storage behind okeys_buf
+  ValT output_vals[NUM]; // host storage behind ovals_buf
+  std::vector> sorted_vec; // host reference: (key, value) pairs
+  for (size_t idx = 0; idx < NUM; ++idx)
+    sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx]));
+#ifdef DES
+  auto kv_tuple_comp = [](const std::tuple &t1,
+                          const std::tuple &t2) {
+    return std::get<0>(t1) > std::get<0>(t2);
+  };
+#else
+  auto kv_tuple_comp = [](const std::tuple &t1,
+                          const std::tuple &t2) {
+    return std::get<0>(t1) < std::get<0>(t2);
+  };
+#endif
+  std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp);
+
+  /*for (size_t idx = 0; idx < NUM; ++idx) {
+    std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " <<
+    (int)std::get<1>(sorted_vec[idx]) << std::endl;
+  }*/
+  // Global range == local range: exactly one work-group of wg_size items.
+  nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size)));
+  { // outputs are read on the host only after the buffers in this scope are gone
+    buffer ikeys_buf(input_keys, NUM);
+    buffer ivals_buf(input_vals, NUM);
+    buffer okeys_buf(output_keys, NUM);
+    buffer ovals_buf(output_vals, NUM);
+    q.submit([&](auto &h) {
+       accessor ikeys_acc{ikeys_buf, h};
+       accessor ivals_acc{ivals_buf, h};
+       accessor okeys_acc{okeys_buf, h};
+       accessor ovals_acc{ovals_buf, h};
+       h.parallel_for(num_items, [=](nd_item<1> i) {
+         KeyT pkeys[num_per_work_item];
+         ValT pvals[num_per_work_item];
+         // Copy this work-item's slice of the global input into fixed-size
+         for (size_t idx = 0; idx < num_per_work_item; ++idx) { // private arrays
+           pkeys[idx] =
+               ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx];
+           pvals[idx] =
+               ivals_acc[i.get_local_linear_id() * num_per_work_item + idx];
+         }
+         // Helper under test; its result is read back from pkeys/pvals below.
+         gsh(pkeys, pvals, num_per_work_item, scratch_ptr);
+         // Store the private arrays into this work-item's output slice.
+         for (size_t idx = 0; idx < num_per_work_item; ++idx) {
+           okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] =
+               pkeys[idx];
+           ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] =
+               pvals[idx];
+         }
+       });
+     }).wait();
+  }
+
+  /* for (size_t idx = 0; idx < NUM; ++idx) {
+     std::cout << "key: " << (int)(input_keys[idx]) << " val: " <<
+     (int)(input_vals[idx]) << std::endl;
+   }*/
+
+  sycl::free(scratch_ptr, q);
+  bool fails = false;
+#ifdef SPREAD
+  for (size_t idx = 0; idx < NUM; ++idx) {
+    size_t idx1 = idx / WG_SZ; // slot inside the owning work-item's slice
+    size_t idx2 = idx % WG_SZ; // work-item whose slice holds sorted element idx
+    if ((output_keys[idx2 * num_per_work_item + idx1] !=
+         std::get<0>(sorted_vec[idx])) ||
+        (output_vals[idx2 * num_per_work_item + idx1] !=
+         std::get<1>(sorted_vec[idx]))) {
+      std::cout << "idx: " << idx << std::endl;
+      fails = true;
+      break;
+    }
+  }
+#else
+  for (size_t idx = 0; idx < NUM; ++idx) {
+    if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) ||
+        (output_vals[idx] != std::get<1>(sorted_vec[idx]))) {
+      std::cout << "idx: " << idx << std::endl;
+      fails = true;
+      break;
+    }
+  }
+#endif
+  assert(!fails); // NOTE(review): no-op when NDEBUG is defined
+}
diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp
new file mode 100644
index 0000000000000..fbed182a3926e
--- /dev/null
+++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_KV_sort_p1p1_p3.hpp
@@ -0,0 +1,1192 @@
+#include "group_sort.hpp"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p3i8(
+    uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u8_u32_p3i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i8_u32_p3i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i8_u32_p3i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u16_u32_p3i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u16_u32_p3i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i16_u32_p3i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i16_u32_p3i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p3i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p3i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i32_u32_p3i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i32_u32_p3i8( + uint8_t 
*keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u64_u32_p3i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p3i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i64_u32_p3i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p3i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u8_p1f32_u32_p3i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u8_u32_p3i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u8_u32_p3i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i8_u32_p3i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i8_u32_p3i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u16_u32_p3i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u16_u32_p3i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i16_u32_p3i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i16_u32_p3i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p3i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p3i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i32_u32_p3i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i32_u32_p3i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u64_u32_p3i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1u64_u32_p3i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i64_u32_p3i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p3i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t 
*scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_p1f32_u32_p3i8( + int8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u8_u32_p3i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u8_u32_p3i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i8_u32_p3i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i8_u32_p3i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u16_u32_p3i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u16_u32_p3i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i16_u32_p3i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i16_u32_p3i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p3i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p3i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i32_u32_p3i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i32_u32_p3i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u64_u32_p3i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u64_u32_p3i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i64_u32_p3i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i64_u32_p3i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u16_p1f32_u32_p3i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u16_p1f32_u32_p3i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u8_u32_p3i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u8_u32_p3i8( + int16_t *keys, uint8_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i8_u32_p3i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i8_u32_p3i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u16_u32_p3i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u16_u32_p3i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i16_u32_p3i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i16_u32_p3i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p3i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p3i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i32_u32_p3i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i32_u32_p3i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u64_u32_p3i8(
+    int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u64_u32_p3i8(
+    int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i64_u32_p3i8(
+    int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i64_u32_p3i8(
+    int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_ascending_p1i16_p1f32_u32_p3i8(
+    int16_t *keys, float *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_descending_p1i16_p1f32_u32_p3i8(
+    int16_t *keys, float *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u8_u32_p3i8(
+    uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u8_u32_p3i8(
+    uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch);
+
+// uint32_t keys with int8_t values. The value pointer is tagged p1i8, like the
+// value pointer of the other key/value declarations visible in this header;
+// only the trailing scratch pointer carries the p3i8 tag.
+// NOTE(review): confirm against the symbol exported by the device library.
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i8_u32_p3i8(
+    uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i8_u32_p3i8(
+    uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch);
+
+__DPCPP_SYCL_EXTERNAL extern "C" void
+__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u16_u32_p3i8(
+    uint32_t *keys, uint16_t *vals, uint32_t n,
uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u16_u32_p3i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i16_u32_p3i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i16_u32_p3i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p3i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p3i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i32_u32_p3i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i32_u32_p3i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u64_u32_p3i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u32_p1u64_u32_p3i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p3i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p3i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u8_u32_p3i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u8_u32_p3i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i8_u32_p3i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i8_u32_p3i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u16_u32_p3i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u16_u32_p3i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i16_u32_p3i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i16_u32_p3i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p3i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p3i8( + int32_t *keys, uint32_t *vals, uint32_t 
n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i32_u32_p3i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i32_u32_p3i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u64_u32_p3i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1u64_u32_p3i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p3i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p3i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u8_u32_p3i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u8_u32_p3i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i8_u32_p3i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i8_u32_p3i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u16_u32_p3i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u16_u32_p3i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i16_u32_p3i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i16_u32_p3i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p3i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p3i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i32_u32_p3i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i32_u32_p3i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u64_u32_p3i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1u64_u32_p3i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i64_u32_p3i8( + uint64_t *keys, int64_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1u64_p1i64_u32_p3i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u8_u32_p3i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u8_u32_p3i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i8_u32_p3i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i8_u32_p3i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u16_u32_p3i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u16_u32_p3i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i16_u32_p3i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i16_u32_p3i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p3i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p3i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i32_u32_p3i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i32_u32_p3i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u64_u32_p3i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1u64_u32_p3i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p3i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p3i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u8_u32_p3i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u8_u32_p3i8( + uint8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i8_u32_p3i8( + uint8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i8_u32_p3i8( + uint8_t *keys, int8_t *vals, uint32_t n, 
uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u16_u32_p3i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u16_u32_p3i8( + uint8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i16_u32_p3i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i16_u32_p3i8( + uint8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p3i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p3i8( + uint8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i32_u32_p3i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i32_u32_p3i8( + uint8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u64_u32_p3i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u64_u32_p3i8( + uint8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i64_u32_p3i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i64_u32_p3i8( + uint8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1f32_u32_p3i8( + uint8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u8_u32_p3i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u8_u32_p3i8( + int8_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i8_u32_p3i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i8_u32_p3i8( + int8_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u16_u32_p3i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u16_u32_p3i8( + int8_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i16_u32_p3i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i16_u32_p3i8( + int8_t *keys, int16_t *vals, uint32_t n, uint8_t 
*scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p3i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p3i8( + int8_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i32_u32_p3i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i32_u32_p3i8( + int8_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u64_u32_p3i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u64_u32_p3i8( + int8_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i64_u32_p3i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i64_u32_p3i8( + int8_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1f32_u32_p3i8( + int8_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u8_u32_p3i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u8_u32_p3i8( + uint16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i8_u32_p3i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i8_u32_p3i8( + uint16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u16_u32_p3i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u16_u32_p3i8( + uint16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i16_u32_p3i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i16_u32_p3i8( + uint16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p3i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p3i8( + uint16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i32_u32_p3i8( + uint16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i32_u32_p3i8( + uint16_t *keys, 
int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u64_u32_p3i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u64_u32_p3i8( + uint16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i64_u32_p3i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i64_u32_p3i8( + uint16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1f32_u32_p3i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u16_p1f32_u32_p3i8( + uint16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u8_u32_p3i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u8_u32_p3i8( + int16_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i8_u32_p3i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i8_u32_p3i8( + int16_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u16_u32_p3i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u16_u32_p3i8( + int16_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i16_u32_p3i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i16_u32_p3i8( + int16_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p3i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p3i8( + int16_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i32_u32_p3i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i32_u32_p3i8( + int16_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u64_u32_p3i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u64_u32_p3i8( + int16_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i64_u32_p3i8( + int16_t *keys, int64_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i64_u32_p3i8( + int16_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1f32_u32_p3i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_p1f32_u32_p3i8( + int16_t *keys, float *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u8_u32_p3i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u8_u32_p3i8( + uint32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); +/* NOTE(review): the next two declarations (uint32_t keys, int8_t values) tag the value pointer as p3i8, while the surrounding declarations use p1<type> for the value pointer (e.g. p1u32_p1u8 just above). Confirm the symbol names against the device library before relying on them. */ +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p3i8_u32_p3i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p3i8_u32_p3i8( + uint32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u16_u32_p3i8( + uint32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u16_u32_p3i8( + uint32_t *keys, uint16_t *vals, 
uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i16_u32_p3i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i16_u32_p3i8( + uint32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p3i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p3i8( + uint32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i32_u32_p3i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i32_u32_p3i8( + uint32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u64_u32_p3i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u64_u32_p3i8( + uint32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i64_u32_p3i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i64_u32_p3i8( + uint32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u8_u32_p3i8( + int32_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u8_u32_p3i8( + int32_t *keys, 
uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i8_u32_p3i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i8_u32_p3i8( + int32_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u16_u32_p3i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u16_u32_p3i8( + int32_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i16_u32_p3i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i16_u32_p3i8( + int32_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p3i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p3i8( + int32_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i32_u32_p3i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i32_u32_p3i8( + int32_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u64_u32_p3i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u64_u32_p3i8( + int32_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i64_u32_p3i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i64_u32_p3i8( + int32_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u8_u32_p3i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u8_u32_p3i8( + uint64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i8_u32_p3i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i8_u32_p3i8( + uint64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u16_u32_p3i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u16_u32_p3i8( + uint64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i16_u32_p3i8( + uint64_t *keys, int16_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i16_u32_p3i8( + uint64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p3i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p3i8( + uint64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i32_u32_p3i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i32_u32_p3i8( + uint64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u64_u32_p3i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u64_u32_p3i8( + uint64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i64_u32_p3i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i64_u32_p3i8( + uint64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u8_u32_p3i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u8_u32_p3i8( + int64_t *keys, uint8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i8_u32_p3i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i8_u32_p3i8( + int64_t *keys, int8_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u16_u32_p3i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u16_u32_p3i8( + int64_t *keys, uint16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i16_u32_p3i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i16_u32_p3i8( + int64_t *keys, int16_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p3i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p3i8( + int64_t *keys, uint32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i32_u32_p3i8( + int64_t *keys, int32_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i32_u32_p3i8( + int64_t *keys, int32_t 
*vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u64_u32_p3i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u64_u32_p3i8( + int64_t *keys, uint64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i64_u32_p3i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +__DPCPP_SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i64_u32_p3i8( + int64_t *keys, int64_t *vals, uint32_t n, uint8_t *scratch); + +using namespace sycl; +/* Runs one work-group key/value private-sort check. Host side: copies keys/vals and builds the expected order with std::stable_sort comparing the key only (descending when DES is defined, ascending otherwise). Device side: each work-item copies NUM / WG_SZ consecutive pairs into its own arrays, calls gsh(pkeys, pvals, count, scratch) with a byte pointer taken from a local_accessor, and writes its arrays back to the same positions. The output is then compared element by element with the expected order (index mapping chosen by SPREAD); the first mismatching index is printed and assert(!fails) fires. NUM must be divisible by WG_SZ (static_assert). */ +template +void test_work_group_KV_private_sort(sycl::queue &q, const KeyT keys[NUM], + const ValT vals[NUM], SortHelper gsh) { + static_assert((NUM % WG_SZ == 0), + "Input number must be divisible by work group size!"); + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy(&input_keys[0], &keys[0], NUM * sizeof(KeyT)); + memcpy(&input_vals[0], &vals[0], NUM * sizeof(ValT)); + // Round the scratch size up to a multiple of 8 bytes; the local_accessor below is created with scratch_size >> 3 elements. 
+ size_t scratch_size = 2 * NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + scratch_size = (((scratch_size - 1) >> 3) + 1) << 3; + + const static size_t wg_size = WG_SZ; + constexpr size_t num_per_work_item = NUM / WG_SZ; + KeyT output_keys[NUM]; + ValT output_vals[NUM]; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES /* DES: expected order is descending by key; the #else branch is ascending */ + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + /*for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + }*/ + /* Global range and local range are both wg_size, so the kernel runs as exactly one work-group. */ + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + buffer okeys_buf(output_keys, NUM); + buffer ovals_buf(output_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + accessor okeys_acc{okeys_buf, h}; + accessor ovals_acc{ovals_buf, h}; + local_accessor scratch_acc(scratch_size >> 3, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + KeyT pkeys[num_per_work_item]; + ValT pvals[num_per_work_item]; + // copy from global input to fix-size private array. 
+ for (size_t idx = 0; idx < num_per_work_item; ++idx) { + pkeys[idx] = + ikeys_acc[i.get_local_linear_id() * num_per_work_item + idx]; + pvals[idx] = + ivals_acc[i.get_local_linear_id() * num_per_work_item + idx]; + } + /* Pass the helper a byte pointer to the scratch storage obtained from the local accessor. */ + uint8_t *scratch_ptr = reinterpret_cast( + scratch_acc.template get_multi_ptr().get()); + gsh(pkeys, pvals, num_per_work_item, scratch_ptr); + + for (size_t idx = 0; idx < num_per_work_item; ++idx) { + okeys_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pkeys[idx]; + ovals_acc[i.get_local_linear_id() * num_per_work_item + idx] = + pvals[idx]; + } + }); + }).wait(); + } + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + + bool fails = false; +#ifdef SPREAD /* expected element idx is looked up in work-item (idx % WG_SZ) at slot (idx / WG_SZ) */ + for (size_t idx = 0; idx < NUM; ++idx) { + size_t idx1 = idx / WG_SZ; + size_t idx2 = idx % WG_SZ; + if ((output_keys[idx2 * num_per_work_item + idx1] != + std::get<0>(sorted_vec[idx])) || + (output_vals[idx2 * num_per_work_item + idx1] != + std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } +#else /* expected element idx is looked up at flat output index idx */ + for (size_t idx = 0; idx < NUM; ++idx) { + if ((output_keys[idx] != std::get<0>(sorted_vec[idx])) || + (output_vals[idx] != std::get<1>(sorted_vec[idx]))) { + std::cout << "idx: " << idx << std::endl; + fails = true; + break; + } + } +#endif + assert(!fails); +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp new file mode 100644 index 0000000000000..33618b3e1682a --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_private_sort_p1p1.hpp @@ -0,0 +1,124 @@ +#pragma once +#include +#ifdef 
__SYCL_DEVICE_ONLY__ /* when defined: the sort entry points are only declared here */ +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void 
+__devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); + +SYCL_EXTERNAL extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch); +#else +extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p1i8( + int8_t *first, 
uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p1i8( + int8_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p1i8( + int16_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +extern "C" void +__devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p1i8( + int32_t *first, uint32_t n, uint8_t *scratch) { + return; +} + +#endif diff --git a/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp b/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp new file mode 100644 index 0000000000000..b1fbf85551c97 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/group_sort.hpp @@ -0,0 +1,8 @@ +#pragma once +#include + +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || 
defined(__SPIRV__)) +#define __DEVICE_CODE 1 +#else +#define __DEVICE_CODE 0 +#endif /* __DEVICE_CODE is 1 only when __SYCL_DEVICE_ONLY__ is defined together with __SPIR__ or __SPIRV__, and 0 otherwise */ diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp new file mode 100644 index 0000000000000..c2774191593ea --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_KV_sort_p1p1_p1.cpp @@ -0,0 +1,1025 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_joint_KV_sort_p1p1_p1.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +/* Checks one key/value joint-sort helper. Copies keys and vals into local arrays, builds the expected order on the host by stable-sorting (key, value) tuples on the key alone (descending when DES is defined, ascending otherwise), then launches a single work-group of WG_SZ work-items in which each item calls gsh(key pointer, value pointer, NUM, scratch). Afterwards the local arrays are compared element by element with the host reference and the outcome is asserted. q: queue used for the launch and for the device scratch allocation; keys, vals: NUM-element inputs, only copied from, never written; gsh: callable invoked inside the kernel. NOTE(review): the template parameter list is missing from this text; KeyT, ValT, WG_SZ, NUM and SortHelper are presumably template parameters, confirm against the original file. */ template +void test_work_group_KV_joint_sort(sycl::queue &q, KeyT keys[NUM], + ValT vals[NUM], SortHelper gsh) { + + KeyT input_keys[NUM]; + ValT input_vals[NUM]; + memcpy((void *)input_keys, (void *)keys, NUM * sizeof(KeyT)); + memcpy((void *)input_vals, (void *)vals, NUM * sizeof(ValT)); + /* scratch bytes: NUM keys plus NUM values, plus max(alignof(KeyT), alignof(ValT)) extra; presumably slack for aligning the second array, confirm */ size_t scratch_size = NUM * (sizeof(KeyT) + sizeof(ValT)) + + std::max(alignof(KeyT), alignof(ValT)); + uint8_t *scratch_ptr = + (uint8_t *)aligned_alloc_device(alignof(KeyT), scratch_size, q); + const static size_t wg_size = WG_SZ; + std::vector> sorted_vec; + for (size_t idx = 0; idx < NUM; ++idx) + sorted_vec.push_back(std::make_tuple(input_keys[idx], input_vals[idx])); +#ifdef DES + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) > std::get<0>(t2); + }; +#else + auto kv_tuple_comp = [](const std::tuple &t1, + const std::tuple &t2) { + return std::get<0>(t1) < std::get<0>(t2); + }; +#endif + /* the comparator looks at the key only; stable_sort keeps equal-key tuples in input order */ std::stable_sort(sorted_vec.begin(), sorted_vec.end(), kv_tuple_comp); + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)std::get<0>(sorted_vec[idx]) << " val: " << + (int)std::get<1>(sorted_vec[idx]) << std::endl; + } */ + + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { 
+ buffer ikeys_buf(input_keys, NUM); + buffer ivals_buf(input_vals, NUM); + q.submit([&](auto &h) { + accessor ikeys_acc{ikeys_buf, h}; + accessor ivals_acc{ivals_buf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + /* every work-item passes identical arguments: the two accessors' pointers, NUM and the one device scratch allocation */ gsh(ikeys_acc.template get_multi_ptr().get(), + ivals_acc.template get_multi_ptr().get(), + NUM, scratch_ptr); + }); + }).wait(); + } + + /* for (size_t idx = 0; idx < NUM; ++idx) { + std::cout << "key: " << (int)(input_keys[idx]) << " val: " << + (int)(input_vals[idx]) << std::endl; + }*/ + + /* scratch is released before the check so a failing assert does not skip the free */ sycl::free(scratch_ptr, q); + bool fails = false; + /* compares the local arrays that backed the two buffers above with the host reference; values are checked as well as keys, so equal keys must keep their input order */ for (size_t idx = 0; idx < NUM; ++idx) { + if ((input_keys[idx] != std::get<0>(sorted_vec[idx])) || + (input_vals[idx] != std::get<1>(sorted_vec[idx]))) { + fails = true; + std::cout << idx << std::endl; + break; + } + } + assert(!fails); +} + +/* Test entry point. Each braced scope defines NUM-element key and value arrays plus lambdas that forward to a __devicelib_default_work_group_joint_sort_{ascending,descending}_* entry point (descending when DES is defined; the call is compiled only when __DEVICE_CODE is non-zero), then runs test_work_group_KV_joint_sort for each combination and prints a pass line. Returns 0. */ int main() { + queue q; + + { + constexpr static int NUM = 23; + uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 88, 2, 51, 91, 81, 122, 22}; + int8_t ikeys1[NUM] = {0, 1, -2, 1, 122, -123, 99, -91, 9, 12, 13, 46, + 13, 13, 9, 5, 77, 81, -100, 35, -64, 22, 23}; + uint8_t ivals[NUM] = {99, 32, 1, 2, 67, 123, 253, 35, + 111, 77, 165, 145, 254, 11, 161, 96, + 165, 100, 64, 90, 255, 147, 135}; + int8_t ivals2[NUM] = {-1, 23, 0, 123, 99, 44, 8, 11, -67, -54, -113, 7, + 1, 81, -81, 21, 25, -38, 66, 99, -121, 34, 45}; + + uint16_t ivals3[NUM] = {36882, 47565, 20664, 59517, 55773, 5858, + 30720, 64786, 42129, 13618, 62202, 16225, + 54751, 38268, 25563, 44332, 45475, 12550, + 5478, 3301, 3779, 25518, 6659}; + + int16_t ivals4[NUM] = {3882, 7565, -20664, 9517, -5773, 5858, + -30720, 86, 429, 13618, 2202, -16225, + 751, -368, 25563, -4332, -5475, 12550, + 5478, 3301, 3779, 25518, 6659}; + + uint32_t ivals5[NUM] = { + 2, 771, 76, 450, 76421894, 273377, + 85040, 831870667, 402825730, 2774821, 10786, 47164, + 1951118976, 75033606, 35755, 93312, 21, 3257266819, + 1065029990, 139884, 11355, 1464548796, 403290}; + + int32_t ivals6[NUM] = { + 2, 
771, -76, 450, 76421894, 273377, + 85040, 831870667, -402825730, -2774821, 10786, 47164, + 1951118976, -75033606, 35755, 93312, 21, 57266819, + -1065029990, 139884, -11355, 1464548796, -403290}; + + uint64_t ivals7[NUM] = { + 2, 771, 76, 1112450, 76421894, + 898273377, 66585040, 11831870667, 402825730, 2774821, + 10786, 47164, 1951118976, 75033606, 99935755, + 9331211, 21, 3257266819, 10650299901112, 1224139884, + 9837411355, 1464548796, 403290}; + + int64_t ivals8[NUM] = { + 2, 771, 76, 1112450, 76421894, + 898273377, 66585040, -11831870667, -402825730, 2774821, + 10786, 47164, 1951118976, 75033606, 10, + 0, 21, -3257266819, -10650299901112, 76, + -9837411355, 1464548796, 403290}; + + float ivals9[NUM] = { + 1.628561f, 2.998057f, 0.082604f, 0.0f, 12.330798f, + -1.350443f, 0.437885f, 0.017387f, 0.474454f, -0.718838f, + 98.150388f, 0.732236f, 0.519963f, -0.332644f, 0.648420f, + 0.578913f, -0.853190f, -910.141650f, 110.037210f, 0.434222f, + -0.343777f, 0.346011f, 0.767590f, + }; + + auto work_group_sorter = [](uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter1 = [](uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter2 = [](uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_joint_sort_ascending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter3 = [](uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter4 = [](uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter5 = [](uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter6 = [](uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter7 = [](uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter8 = [](uint8_t *keys, float *vals, uint32_t 
n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_p1f32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_p1f32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter9 = [](int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter10 = [](int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter11 = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter12 = [](int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter13 = [](int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_joint_sort_ascending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + auto work_group_sorter14 = [](int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals, work_group_sorter9); + std::cout + << "KV joint sort (Key: int8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." 
+ << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals2, work_group_sorter13); + std::cout + << "KV joint sort (Key: int8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout + << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout + << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout + << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout + << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals3, work_group_sorter10); + std::cout << "KV joint sort (Key: int8_t, Val: uint16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." 
+ << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals4, work_group_sorter14); + std::cout + << "KV joint sort (Key: int8_t, Val: int16_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout + << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout + << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout + << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout + << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals5, work_group_sorter11); + std::cout << "KV joint sort (Key: int8_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout + << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout + << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout + << "KV joint sort (Key: int8_t, Val: uint64_t) pass." 
+ << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout + << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys1, ivals7, work_group_sorter12); + std::cout << "KV joint sort (Key: int8_t, Val: uint64_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout + << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout + << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout + << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout + << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort (Key: uint8_t, Val: uint8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." 
+ << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals2, work_group_sorter1); + std::cout + << "KV joint sort (Key: uint8_t, Val: int8_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals3, work_group_sorter2); + std::cout << "KV joint sort (Key: uint8_t, Val: uint16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." 
<< std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals4, work_group_sorter3); + std::cout << "KV joint sort (Key: uint8_t, Val: int16_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals5, work_group_sorter4); + std::cout << "KV joint sort (Key: uint8_t, Val: uint32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." 
<< std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals6, work_group_sorter5); + std::cout << "KV joint sort (Key: uint8_t, Val: int32_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals7, work_group_sorter6); + std::cout << "KV joint sort (Key: uint8_t, Val: uint64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." 
<< std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals8, work_group_sorter7); + std::cout << "KV joint sort (Key: uint8_t, Val: int64_t) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals9, work_group_sorter8); + std::cout << "KV joint sort (Key: uint8_t, Val: float) pass." 
<< std::endl; + } + + { + constexpr static int NUM = 21; + uint32_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, + 23, 36, 2, 111, 91, 88, 2, 51, 95431, 881}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525}; + auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." + << std::endl; + } + + { + constexpr static int NUM = 32; + uint32_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 36, 2, 111, 91, + 88, 2, 51, 95431, 881, 99183, 31, 142, + 416, 701, 699, 1024, 8912, 0, 7981, 17}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, 435, + 91111, 777, 165, 145, 2456, 88811, 761, 96, + 765, 10000, 6364, 90, 525, 882, 1, 2423, + 9, 4324, 9123, 0, 1232, 777, 555, 314159}; + auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { + +#if __DEVICE_CODE +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif + }; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." 
+ << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." + << std::endl; + + test_work_group_KV_joint_sort( + q, ikeys, ivals, work_group_sorter); + std::cout << "KV joint sort (Key: uint32_t, Val: uint32_t) pass." + << std::endl; + } + + return 0; +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp new file mode 100644 index 0000000000000..936a00603cf58 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p1.cpp @@ -0,0 +1,256 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include "group_joint_sort_p1p1.hpp" +#include +#include +#include +#include +#include +using namespace sycl; + +/* Checks one joint-sort helper. Sorts a host copy of input with std::sort as the reference, then launches a single work-group of WG_SZ work-items in which each item calls gsh(pointer from the input accessor, NUM, scratch bytes); the scratch bytes are the storage of a second buffer created over a zero-initialised NUM-element Ty array. Afterwards input is compared with the reference (read back to front when DES is defined) and the outcome is asserted. q: queue used for the launch; input: NUM-element array that is expected to end up sorted; gsh: callable invoked inside the kernel. NOTE(review): the template parameter list is missing from this text; Ty, WG_SZ, NUM and SortHelper are presumably template parameters, confirm against the original file. */ template +void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { + Ty scratch[NUM] = { + 0, + }; + Ty result[NUM]; + memcpy(result, input, sizeof(Ty) * NUM); + std::sort(&result[0], &result[NUM]); + const static size_t wg_size = WG_SZ; + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ibuf(input, NUM); + buffer obuf(scratch, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + accessor out_acc{obuf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + Ty *optr = + out_acc.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); + gsh(in_acc.template get_multi_ptr().get(), NUM, + by); + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) 
{ +#ifdef DES + /* the reference is ascending, so it is indexed from the end for a descending run */ if (input[idx] != result[NUM - 1 - idx]) { +#else + if (input[idx] != result[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 19; + int8_t a[NUM] = {-1, 11, 10, 9, 3, 100, 34, 8, 10, 77, + -93, 23, 36, 2, 111, 91, 88, -2, -25}; + auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i8_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 19; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, + 293, 23, 36, 2, 111, 91, 88, -2, 525}; + auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 21; + int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, 293, + 23, 36, 2, 111, 91, 88, -2, 525, -12, 525}; + auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i16_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i16_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, + 
1000, 77, 293, 23, 36, 2, 111, 91, + 88, -2, 525, -12, 525, -99999999, 19928348493}; + auto work_group_sorter = [](int64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i64_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i64_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + float a[NUM] = {-1.25f, 11.4643f, 1.45f, -9.98f, 13.665f, + 100.0f, 34.625f, 8.125f, 1000.12f, 77.91f, + 293.33f, 23.4f, -36.6f, 2.5f, 111.11f, + 91.889f, 88.88f, -2.98f, 525.25f, -12.11f, + 525.0f, -9999999.9f, 19928348493.123f}; + auto work_group_sorter = [](float *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1f32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint8_t a[NUM] = {234, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 201, 211, 77, 8, 88, 19, 0}; + auto work_group_sorter = [](uint8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u8_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint16_t a[NUM] = {11234, 11, 1, 119, 3, 100, 341, 8, + 121, 77, 125, 23, 3226, 2, 111, 911, + 201, 211, 77, 8, 11188, 19, 0}; + auto work_group_sorter = 
[](uint16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u16_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u16_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint32_t a[NUM] = {11234, 11, 1, 1193332332, 231, 100, 341, 8, + 121, 77, 125, 32, 3226, 2, 111, 911, + 9912201, 211, 711117, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint64_t a[NUM] = {0x112A111111FFEEFF, + 0xAACC11, + 0x1, + 0x1193332332, + 0x231, + 0xAA, + 0xFCCCA341, + 0x8, + 0x121, + 0x987777777, + 0x81, + 0x20, + 0x3226, + 0x2, + 0x8FFFFFFFFF111, + 0x911, + 0xAAAA9912201, + 0x211, + 0x711117, + 0x8, + 0xABABABABCC, + 0x13, + 0}; + auto work_group_sorter = [](uint64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u64_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u64_u32_p1i8 passes" << std::endl; + } + + return 0; +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp new file mode 100644 index 0000000000000..f3fa3bed5f572 --- 
/dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p1p3.cpp @@ -0,0 +1,253 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include "group_joint_sort_p1p3.hpp" +#include +#include +#include +#include +#include +using namespace sycl; + +// For __devicelib_default_work_group_xxx_p1*_u32_p3i8, the scratch memory is +// in shared local memory. +template +void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { + Ty result[NUM]; + memcpy(result, input, sizeof(Ty) * NUM); + std::sort(&result[0], &result[NUM]); + const static size_t wg_size = WG_SZ; + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ibuf(input, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + local_accessor slm(NUM, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + Ty *optr = slm.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); + gsh(in_acc.template get_multi_ptr().get(), NUM, + by); + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (input[idx] != result[NUM - 1 - idx]) { +#else + if (input[idx] != result[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 19; + int8_t a[NUM] = {-1, 11, 10, 9, 3, 100, 34, 8, 10, 77, + -93, 23, 36, 2, 111, 91, 88, -2, -25}; + auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i8_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i8_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, 
work_group_sorter); + std::cout << "work group joint sort p1i8_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 19; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, + 293, 23, 36, 2, 111, 91, 88, -2, 525}; + auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 21; + int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, 293, + 23, 36, 2, 111, 91, 88, -2, 525, -12, 525}; + auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i16_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i16_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i16_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, + 1000, 77, 293, 23, 36, 2, 111, 91, + 88, -2, 525, -12, 525, -99999999, 19928348493}; + auto work_group_sorter = [](int64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1i64_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1i64_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1i64_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + float a[NUM] = {-1.25f, 11.4643f, 1.45f, -9.98f, 13.665f, + 100.0f, 34.625f, 8.125f, 1000.12f, 
77.91f, + 293.33f, 23.4f, -36.6f, 2.5f, 111.11f, + 91.889f, 88.88f, -2.98f, 525.25f, -12.11f, + 525.0f, -9999999.9f, 19928348493.123f}; + auto work_group_sorter = [](float *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1f32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1f32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1f32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint8_t a[NUM] = {234, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 201, 211, 77, 8, 88, 19, 0}; + auto work_group_sorter = [](uint8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u8_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u8_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u8_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint16_t a[NUM] = {11234, 11, 1, 119, 3, 100, 341, 8, + 121, 77, 125, 23, 3226, 2, 111, 911, + 201, 211, 77, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u16_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u16_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u16_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint32_t a[NUM] = {11234, 11, 1, 1193332332, 231, 100, 341, 8, + 121, 77, 125, 32, 3226, 2, 111, 911, + 9912201, 211, 711117, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint32_t *first, uint32_t n, 
uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint64_t a[NUM] = {0x112A111111FFEEFF, + 0xAACC11, + 0x1, + 0x1193332332, + 0x231, + 0xAA, + 0xFCCCA341, + 0x8, + 0x121, + 0x987777777, + 0x81, + 0x20, + 0x3226, + 0x2, + 0x8FFFFFFFFF111, + 0x911, + 0xAAAA9912201, + 0x211, + 0x711117, + 0x8, + 0xABABABABCC, + 0x13, + 0}; + auto work_group_sorter = [](uint64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p1u64_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p1u64_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p1u64_u32_p3i8 passes" << std::endl; + } + + return 0; +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp new file mode 100644 index 0000000000000..d8bf629ed7fc1 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p1.cpp @@ -0,0 +1,293 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include "group_joint_sort_p3p1.hpp" +#include +#include +#include +#include +#include +using namespace sycl; + +template +void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { + Ty scratch[NUM] = { + 0, + }; + Ty 
result[NUM]; + memcpy(result, input, sizeof(Ty) * NUM); + std::sort(&result[0], &result[NUM]); + const static size_t wg_size = WG_SZ; + constexpr size_t copy_sz_per_item = NUM / WG_SZ; + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ibuf(input, NUM); + buffer obuf(scratch, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + local_accessor slm(NUM, h); + accessor out_acc{obuf, h}; + h.parallel_for(num_items, [=](nd_item<1> i) { + // Copy input to local shared memory for sort. + if constexpr (copy_sz_per_item == 0) { + if (i.get_global_id() < NUM) + slm[i.get_global_id()] = in_acc[i.get_global_id()]; + } else { + for (size_t idx = 0; idx < copy_sz_per_item; ++idx) { + slm[copy_sz_per_item * i.get_global_id() + idx] = + in_acc[copy_sz_per_item * i.get_global_id() + idx]; + } + + if (i.get_global_id() < (NUM - copy_sz_per_item * WG_SZ)) { + slm[copy_sz_per_item * WG_SZ + i.get_global_id()] = + in_acc[copy_sz_per_item * WG_SZ + i.get_global_id()]; + } + } + + group_barrier(i.get_group()); + Ty *optr = + out_acc.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); + gsh(slm.template get_multi_ptr().get(), NUM, + by); + + // Copy sorted data in shared memory to in_acc + if constexpr (copy_sz_per_item == 0) { + if (i.get_global_id() < NUM) + in_acc[i.get_global_id()] = slm[i.get_global_id()]; + } else { + for (size_t idx = 0; idx < copy_sz_per_item; ++idx) { + in_acc[copy_sz_per_item * i.get_global_id() + idx] = + slm[copy_sz_per_item * i.get_global_id() + idx]; + } + + if (i.get_global_id() < (NUM - copy_sz_per_item * WG_SZ)) { + in_acc[copy_sz_per_item * WG_SZ + i.get_global_id()] = + slm[copy_sz_per_item * WG_SZ + i.get_global_id()]; + } + } + + group_barrier(i.get_group()); + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (input[idx] != result[NUM - 1 - idx]) { +#else + if (input[idx] != result[idx]) { +#endif + fails = true; + break; + } + } + 
assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 19; + int8_t a[NUM] = {-1, 11, 10, 9, 3, 100, 34, 8, 10, 77, + -93, 23, 36, 2, 111, 91, 88, -2, -25}; + auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i8_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i8_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 19; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, + 293, 23, 36, 2, 111, 91, 88, -2, 525}; + auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 21; + int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, 293, + 23, 36, 2, 111, 91, 88, -2, 525, -12, 525}; + auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i16_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i16_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, + 1000, 77, 293, 23, 36, 2, 111, 91, + 88, -2, 525, -12, 525, -99999999, 19928348493}; + auto work_group_sorter = [](int64_t *first, uint32_t n, 
uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i64_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i64_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + float a[NUM] = {-1.25f, 11.4643f, 1.45f, -9.98f, 13.665f, + 100.0f, 34.625f, 8.125f, 1000.12f, 77.91f, + 293.33f, 23.4f, -36.6f, 2.5f, 111.11f, + 91.889f, 88.88f, -2.98f, 525.25f, -12.11f, + 525.0f, -9999999.9f, 19928348493.123f}; + auto work_group_sorter = [](float *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3f32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint8_t a[NUM] = {234, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 201, 211, 77, 8, 88, 19, 0}; + auto work_group_sorter = [](uint8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u8_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u8_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint16_t a[NUM] = {11234, 11, 1, 119, 3, 100, 341, 8, + 121, 77, 125, 23, 3226, 2, 111, 911, + 201, 211, 77, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u16_u32_p1i8( + first, 
n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u16_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint32_t a[NUM] = {11234, 11, 1, 1193332332, 231, 100, 341, 8, + 121, 77, 125, 32, 3226, 2, 111, 911, + 9912201, 211, 711117, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u32_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u32_u32_p1i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint64_t a[NUM] = {0x112A111111FFEEFF, + 0xAACC11, + 0x1, + 0x1193332332, + 0x231, + 0xAA, + 0xFCCCA341, + 0x8, + 0x121, + 0x987777777, + 0x81, + 0x20, + 0x3226, + 0x2, + 0x8FFFFFFFFF111, + 0x911, + 0xAAAA9912201, + 0x211, + 0x711117, + 0x8, + 0xABABABABCC, + 0x13, + 0}; + auto work_group_sorter = [](uint64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u64_u32_p1i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p1i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u64_u32_p1i8 passes" << std::endl; + } + + return 0; +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp new file mode 100644 index 0000000000000..f7462362eba13 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_joint_sort_p3p3.cpp @@ -0,0 +1,289 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + 
+// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: cuda || hip + +#include "group_joint_sort_p3p3.hpp" +#include +#include +#include +#include +#include +using namespace sycl; + +template +void test_work_group_joint_sort(sycl::queue &q, Ty input[NUM], SortHelper gsh) { + Ty result[NUM]; + memcpy(result, input, sizeof(Ty) * NUM); + std::sort(&result[0], &result[NUM]); + const static size_t wg_size = WG_SZ; + constexpr size_t copy_sz_per_item = NUM / WG_SZ; + nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); + { + buffer ibuf(input, NUM); + q.submit([&](auto &h) { + accessor in_acc{ibuf, h}; + local_accessor slm(NUM, h); + local_accessor scratch(NUM, h); + h.parallel_for(num_items, [=](nd_item<1> i) { + // Copy input to local shared memory for sort. + if constexpr (copy_sz_per_item == 0) { + if (i.get_global_id() < NUM) + slm[i.get_global_id()] = in_acc[i.get_global_id()]; + } else { + for (size_t idx = 0; idx < copy_sz_per_item; ++idx) { + slm[copy_sz_per_item * i.get_global_id() + idx] = + in_acc[copy_sz_per_item * i.get_global_id() + idx]; + } + + if (i.get_global_id() < (NUM - copy_sz_per_item * WG_SZ)) { + slm[copy_sz_per_item * WG_SZ + i.get_global_id()] = + in_acc[copy_sz_per_item * WG_SZ + i.get_global_id()]; + } + } + + group_barrier(i.get_group()); + Ty *optr = + scratch.template get_multi_ptr().get(); + uint8_t *by = reinterpret_cast(optr); + gsh(slm.template get_multi_ptr().get(), NUM, + by); + + // Copy sorted data in shared memory to in_acc + if constexpr (copy_sz_per_item == 0) { + if (i.get_global_id() < NUM) + in_acc[i.get_global_id()] = slm[i.get_global_id()]; + } else { + for (size_t idx = 0; idx < copy_sz_per_item; ++idx) { + in_acc[copy_sz_per_item * i.get_global_id() + idx] = + slm[copy_sz_per_item * i.get_global_id() + idx]; + } + + if 
(i.get_global_id() < (NUM - copy_sz_per_item * WG_SZ)) { + in_acc[copy_sz_per_item * WG_SZ + i.get_global_id()] = + slm[copy_sz_per_item * WG_SZ + i.get_global_id()]; + } + } + + group_barrier(i.get_group()); + }); + }).wait(); + } + + bool fails = false; + for (size_t idx = 0; idx < NUM; ++idx) { +#ifdef DES + if (input[idx] != result[NUM - 1 - idx]) { +#else + if (input[idx] != result[idx]) { +#endif + fails = true; + break; + } + } + assert(!fails); +} + +int main() { + queue q; + + { + constexpr static int NUM = 19; + int8_t a[NUM] = {-1, 11, 10, 9, 3, 100, 34, 8, 10, 77, + -93, 23, 36, 2, 111, 91, 88, -2, -25}; + auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i8_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i8_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i8_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 19; + int32_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, + 293, 23, 36, 2, 111, 91, 88, -2, 525}; + auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 21; + int16_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, 1000, 77, 293, + 23, 36, 2, 111, 91, 88, -2, 525, -12, 525}; + auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i16_u32_p3i8( + first, n, scratch); +#else + 
__devicelib_default_work_group_joint_sort_ascending_p3i16_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i16_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + int64_t a[NUM] = {-1, 11, 1, 9, 3, 100, 34, 8, + 1000, 77, 293, 23, 36, 2, 111, 91, + 88, -2, 525, -12, 525, -99999999, 19928348493}; + auto work_group_sorter = [](int64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3i64_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3i64_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3i64_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + float a[NUM] = {-1.25f, 11.4643f, 1.45f, -9.98f, 13.665f, + 100.0f, 34.625f, 8.125f, 1000.12f, 77.91f, + 293.33f, 23.4f, -36.6f, 2.5f, 111.11f, + 91.889f, 88.88f, -2.98f, 525.25f, -12.11f, + 525.0f, -9999999.9f, 19928348493.123f}; + auto work_group_sorter = [](float *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3f32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3f32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3f32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint8_t a[NUM] = {234, 11, 1, 9, 3, 100, 34, 8, 121, 77, 125, 23, + 36, 2, 111, 91, 201, 211, 77, 8, 88, 19, 0}; + auto work_group_sorter = [](uint8_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u8_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u8_u32_p3i8( + first, n, scratch); +#endif + }; + 
test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u8_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint16_t a[NUM] = {11234, 11, 1, 119, 3, 100, 341, 8, + 121, 77, 125, 23, 3226, 2, 111, 911, + 201, 211, 77, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint16_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u16_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u16_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u16_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint32_t a[NUM] = {11234, 11, 1, 1193332332, 231, 100, 341, 8, + 121, 77, 125, 32, 3226, 2, 111, 911, + 9912201, 211, 711117, 8, 11188, 19, 0}; + auto work_group_sorter = [](uint32_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u32_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u32_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, work_group_sorter); + std::cout << "work group joint sort p3u32_u32_p3i8 passes" << std::endl; + } + + { + constexpr static int NUM = 23; + uint64_t a[NUM] = {0x112A111111FFEEFF, + 0xAACC11, + 0x1, + 0x1193332332, + 0x231, + 0xAA, + 0xFCCCA341, + 0x8, + 0x121, + 0x987777777, + 0x81, + 0x20, + 0x3226, + 0x2, + 0x8FFFFFFFFF111, + 0x911, + 0xAAAA9912201, + 0x211, + 0x711117, + 0x8, + 0xABABABABCC, + 0x13, + 0}; + auto work_group_sorter = [](uint64_t *first, uint32_t n, uint8_t *scratch) { +#ifdef DES + __devicelib_default_work_group_joint_sort_descending_p3u64_u32_p3i8( + first, n, scratch); +#else + __devicelib_default_work_group_joint_sort_ascending_p3u64_u32_p3i8( + first, n, scratch); +#endif + }; + test_work_group_joint_sort( + q, a, 
work_group_sorter); + std::cout << "work group joint sort p3u64_u32_p3i8 passes" << std::endl; + } + + return 0; +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp new file mode 100644 index 0000000000000..6c9a65b4e6bc7 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i16.cpp @@ -0,0 +1,606 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" + +int main() { + queue q; + { + constexpr static int NUM = 35; + int16_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 
28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 
16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u8_u32_p1i8( + keys, 
vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i16_u32_p1i8( + keys, vals, n, 
scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u64_u32_p1i8( + keys, vals, n, 
scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp new file mode 100644 index 0000000000000..7294e6b9345b0 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i32.cpp @@ -0,0 +1,693 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + int32_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, -2, -62451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, -7712423, 0, 0, -181, 17, + 15, -101, 44, 103934, 1, 11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 
12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 
15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD 
+ __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." 
<< std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp new file mode 100644 index 0000000000000..de39f4527c0e5 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i64.cpp @@ -0,0 +1,693 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + int64_t ikeys[NUM] = { + 1, 11, -1, 9, 3, 100, -34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + -881112, 2, 6662451, 213, 199181, -3183, 310910, 11242, + 3216, 1, -199, 7712423, 0, 0, 181, 17, + 15, -101, 44, 103934, 1, -11, 193, 
213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, 
-11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE 
+#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV 
private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp new file mode 100644 index 0000000000000..96712fbb80317 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_i8.cpp @@ -0,0 +1,607 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 35; + int8_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 57541, 
22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 
10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u8_u32_p1i8( + 
keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); 
+#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp new file mode 100644 index 0000000000000..19f8335c2c731 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i16.cpp @@ -0,0 +1,606 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + { + constexpr static int NUM = 35; + int16_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 
7586, + 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 
10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1i16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1i16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp new file mode 100644 index 0000000000000..dd460a3e5075e --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i32.cpp @@ -0,0 +1,693 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + int32_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, -2, -62451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, -7712423, 0, 0, -181, 17, + 15, -101, 44, 103934, 1, 11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 
55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 
15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD 
+ __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + 
__devicelib_default_work_group_private_sort_spread_descending_p1i32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." 
<< std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp new file mode 100644 index 0000000000000..a21e93caff47a --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i64.cpp @@ -0,0 +1,693 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + int64_t ikeys[NUM] = { + 1, 11, -1, 9, 3, 100, -34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + -881112, 2, 6662451, 213, 199181, -3183, 310910, 11242, + 3216, 1, -199, 7712423, 0, 0, 181, 17, + 15, -101, 44, 103934, 1, -11, 
193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 
777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if 
__DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE 
+#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE 
+#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + 
std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp new file mode 100644 index 0000000000000..41f00cf73c99f --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_i8.cpp @@ -0,0 +1,607 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +/* Driver for the int8_t-key variant of the work-group private key/value sort test: declares one int8_t key array and value arrays of each 8/16/32/64-bit signed and unsigned integer type, wraps the matching __devicelib_default_work_group_private_sort_* functions in lambdas, then calls test_work_group_KV_private_sort for a series of sizes and prints a pass line after each call. */ int main() { + queue q; + + { + constexpr static int NUM = 35; + int8_t ikeys[NUM] = {1, -11, 1, 9, -3, 100, 34, 8, 121, + -77, 125, 23, 36, -2, -111, 91, 88, -2, + 51, -23, -81, 83, 31, 42, 2, 1, -99, + 124, 12, 0, -81, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586,
+ 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL,
10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + /* Each work_group_sorterN lambda forwards (keys, vals, n, scratch) to one __devicelib_default_work_group_private_sort_* function whose name suffix matches the lambda's key and value pointer types. DES selects the descending variant (otherwise ascending) and SPREAD selects the spread variant (otherwise close). When __DEVICE_CODE evaluates to 0 the lambda body is empty. NOTE(review): this unnumbered lambda and the ivals array do not appear among the call arguments visible later in this function, and work_group_sorter5 has the same signature and calls the same functions - confirm whether both are needed. */ auto work_group_sorter = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](int8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u8_u32_p3i8( +
keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](int8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](int8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](int8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i16_u32_p3i8( + keys, vals, n, scratch);
+#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](int8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](int8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](int8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](int8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1i8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1i8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1i8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1i8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + /* NUM and NUM1..NUM7 are the sizes printed in the pass messages below; each is at most NUM (35), the declared length of ikeys and of every ivals array, and each printed WG value divides its size evenly. Every pass line is written unconditionally once the preceding call returns. NOTE(review): test_work_group_KV_private_sort comes from the included header and is not visible here, so how it reports a mismatch should be confirmed there. */ constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass."
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass."
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass."
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass."
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass."
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp new file mode 100644 index 0000000000000..5927c08582621 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u16.cpp @@ -0,0 +1,607 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 35; + uint16_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 12, 0, 181, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788,
7586, + 57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 
10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp new file mode 100644 index 0000000000000..4787de440b25e --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u32.cpp @@ -0,0 +1,693 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + uint32_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, 2, 6662451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, 7712423, 0, 0, 181, 17, + 15, 101, 44, 103934, 1, 11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 
55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 
15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p3i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p3i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p3i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p3i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef 
SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef 
SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." 
<< std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp new file mode 100644 index 0000000000000..f99b2f44215fd --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u64.cpp @@ -0,0 +1,693 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + uint64_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, 2, 6662451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, 7712423, 0, 0, 181, 17, + 15, 101, 44, 103934, 1, 11, 193, 
213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, 
-11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE 
+#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV 
private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp new file mode 100644 index 0000000000000..b42ea4acc0736 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_p3_u8.cpp @@ -0,0 +1,607 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p3.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 35; + uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 12, 0, 181, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 
57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 
10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u8_u32_p3i8( 
+ keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i8_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i16_u32_p3i8( + keys, vals, n, 
scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i16_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#if SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i32_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif 
+#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint8_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i64_u32_p3i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp new file mode 100644 index 0000000000000..6bd24951d2fc7 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u16.cpp @@ -0,0 +1,607 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 35; + uint16_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 12, 0, 181, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 
57541, 22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 
10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint16_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint16_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint16_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint16_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1u16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint16_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint16_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint16_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + 
__devicelib_default_work_group_private_sort_close_descending_p1u16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint16_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u16_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 32; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 16 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 32 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp new file mode 100644 index 0000000000000..b515a5a32ca91 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u32.cpp @@ -0,0 +1,693 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + uint32_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, 2, 6662451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, 7712423, 0, 0, 181, 17, + 15, 101, 44, 103934, 1, 11, 193, 213}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 
12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, -11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 
15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint32_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint32_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint32_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint32_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef 
SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint32_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint32_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint32_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef 
SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint32_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u32_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." 
<< std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." << std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." 
<< std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." << std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp new file mode 100644 index 0000000000000..39ca18929bd01 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u64.cpp @@ -0,0 +1,693 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 40; + uint64_t ikeys[NUM] = { + 1, 11, 1, 9, 3, 100, 34, 8, + 121, 77, 125, 23, 222336, 2, 111, 91, + 881112, 2, 6662451, 213, 199181, 3183, 310910, 11242, + 3216, 1, 199, 7712423, 0, 0, 181, 17, + 15, 101, 44, 103934, 1, 11, 193, 213}; + 
uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, + 16, 14, 24, 88, 76, 96, 76, 100, 63, 90, + 52, 82, 1, 22, 9, 225, 127, 0, 12, 128, + 3, 102, 200, 111, 123, 15, 45, 66, 123, 91}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, 77, + 112, -91, 11, 12, 3, 71, -66, 121, 18, 14, + 21, -22, 54, 88, -81, 31, 23, 53, 97, 103, + 71, 83, 97, 37, -41, -71, 112, -121, 98, 78}; + uint16_t ivals3[NUM] = { + 28831, 23870, 54250, 5022, 9571, 60147, 9554, 18818, 28689, 18229, + 40512, 23200, 40454, 24841, 43251, 63264, 29448, 45917, 882, 30788, + 7586, 57541, 22108, 59535, 31880, 7152, 63919, 58703, 14686, 29914, + 5872, 35868, 51479, 22721, 50927, 55094, 2341, 12, 23411, 9812}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, + 1617, 25472, 26763, -5982, 24791, 27189, 22911, 22502, + 15801, 25326, -2196, 9205, -10418, 20464, -16616, -11285, + 7249, 22866, 30574, -1298, 31351, 28252, 21322, -10072, + 7874, -26785, 22016, -12421, 0, 9999, 8888, -7777}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301, 999999999, + 777777777, 44444444, 1000, 9124}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798, + 912342, -88888888, 777777777, 
-11111111, 0}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, + 16873442000174444235ULL, + 5261140449171682429ULL, + 6274209348061756377ULL, + 17881284978944229367ULL, + 4701456380424752599ULL, + 6241062870187348613ULL, + 8972524466137433448ULL, + 468629112944127776ULL, + 17523909311582893643ULL, + 17447772191733931166ULL, + 14311152396797789854ULL, + 9265327272409079662ULL, + 9958475911404242556ULL, + 15359829357736068242ULL, + 11416531655738519189ULL, + 16839972510321195914ULL, + 1927049095689256442ULL, + 3356565661065236628ULL, + 1065114285796701562ULL, + 7071763288904613033ULL, + 16473053015775147286ULL, + 10317354477399696817ULL, + 16005969584273256379ULL, + 15391010921709289298ULL, + 17671303749287233862ULL, + 8028596930095411867ULL, + 10265936863337610975ULL, + 17403745948359398749ULL, + 8504886230707230194ULL, + 12855085916215721214ULL, + 5562885793068933146ULL, + 1508385574711135517ULL, + 5953119477818575536ULL, + 9165320150094769334ULL, + 0, + 11, + 324, + 943534, + 930525}; + + int64_t ivals8[NUM] = {2944696543084623337, + 137239101631340692, + 4370169869966467498, + 3842452903439153631, + -795080033670806202, + 3023506421574592237, + -4142692575864168559, + 1716333381567689984, + 1591746912204250089, + -1974664662220599925, + 3144022139297218102, + -371429365537296255, + 4202906659701034264, + 3878513012313576184, + -3425767072006791628, + -2929337291418891626, + 1880013370888913338, + 1498977159463939728, + -2928775660744278650, + 4074214991200977615, + 4291797122649374026, + -763110360214750992, + 2883673064242977727, + 4270151072450399778, + 1408696225027958214, + 1264214335825459628, + -4152065441956669638, + 2684706424226400837, + 569335474272794084, + -2798088842177577224, + 814002749054152728, + 2517003920904582842, + 4089891582575745386, + 705067059635512048, + -2500935118374519236, + 0, + -1, + 234, + 45, + -1232143254}; + + auto work_group_sorter = [](uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE 
+#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint64_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint64_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint64_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint64_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint64_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint64_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES 
+#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint64_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter8 = [](uint64_t *keys, int64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u64_p1i64_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + constexpr static int NUM1 = 36; + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV 
private sort NUM = " << NUM1 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 9 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 18 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals1, work_group_sorter1); + std::cout << "KV private sort NUM = " << NUM1 + << ", WG = 36 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 4 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 20 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 40 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl; + } +} diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp new file mode 100644 index 0000000000000..62e752934dfbd --- /dev/null +++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_KV_sort_u8.cpp @@ -0,0 +1,607 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -o %t.out +// RUN: %{run} %t.out + +// RUN: %{build} -DDES -DSPREAD -fsycl-device-lib-jit-link -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: cuda || hip || cpu + +#include "group_private_KV_sort_p1p1_p1.hpp" + +int main() { + queue q; + + { + constexpr static int NUM = 35; + uint8_t ikeys[NUM] = {1, 11, 1, 9, 3, 100, 34, 8, 121, + 77, 125, 23, 36, 2, 111, 91, 88, 2, + 51, 213, 181, 183, 31, 142, 216, 1, 199, + 124, 12, 0, 181, 17, 15, 101, 44}; + uint32_t ivals[NUM] = {99, 32, 1, 2, 67, 9123, 453, + 435, 91111, 777, 165, 145, 2456, 88811, + 761, 96, 765, 10000, 6364, 90, 525, + 882, 1, 2423, 9, 4324, 9123, 0, + 1232, 777, 555, 314159, 905, 9831, 84341}; + uint8_t ivals1[NUM] = {99, 32, 1, 2, 67, 91, 45, 43, 91, 77, 16, 14, + 24, 88, 76, 96, 76, 100, 63, 90, 52, 82, 1, 22, + 9, 225, 127, 0, 12, 128, 3, 102, 200, 111, 123}; + int8_t ivals2[NUM] = {-99, 127, -121, 100, 9, 5, 12, 35, -98, + 77, 112, -91, 11, 12, 3, 71, -66, 121, + 18, 14, 21, -22, 54, 88, -81, 31, 23, + 53, 97, 103, 71, 83, 97, 37, -41}; + uint16_t ivals3[NUM] = {28831, 23870, 54250, 5022, 9571, 60147, 9554, + 18818, 28689, 18229, 40512, 23200, 40454, 24841, + 43251, 63264, 29448, 45917, 882, 30788, 7586, + 57541, 
22108, 59535, 31880, 7152, 63919, 58703, + 14686, 29914, 5872, 35868, 51479, 22721, 50927}; + int16_t ivals4[NUM] = { + 2798, -13656, 1592, 3992, -25870, 25172, 7761, -18347, 1617, + 25472, 26763, -5982, 24791, 27189, 22911, 22502, 15801, 25326, + -2196, 9205, -10418, 20464, -16616, -11285, 7249, 22866, 30574, + -1298, 31351, 28252, 21322, -10072, 7874, -26785, 22016}; + + uint32_t ivals5[NUM] = { + 2238578408, 102907035, 2316773768, 617902655, 532045482, 73173328, + 1862406505, 142735533, 3494078873, 610196959, 4210902254, 1863122236, + 1257721692, 30008197, 3199012044, 3503276708, 3504950001, 1240383071, + 2463430884, 904104390, 4044803029, 3164373711, 1586440767, 1999536602, + 3377662770, 927279985, 1740225703, 1133653675, 3975816601, 260339911, + 1115507520, 2279020820, 4289105012, 692964674, 53775301}; + + int32_t ivals6[NUM] = { + 507394811, 1949685322, 1624859474, -940434061, -1440675113, + -2002743224, 369969519, 840772268, 224522238, 296113452, + -714007528, 480713824, 665592454, 1696360848, 780843358, + -1901994531, 1667711523, 1390737696, 1357434904, -290165630, + 305128121, 1301489180, 630469211, -1385846315, 809333959, + 1098974670, 56900257, 876775101, -1496897817, 1172877939, + 1528916082, 559152364, 749878571, 2071902702, -430851798}; + + uint64_t ivals7[NUM] = {7916688577774406903ULL, 16873442000174444235ULL, + 5261140449171682429ULL, 6274209348061756377ULL, + 17881284978944229367ULL, 4701456380424752599ULL, + 6241062870187348613ULL, 8972524466137433448ULL, + 468629112944127776ULL, 17523909311582893643ULL, + 17447772191733931166ULL, 14311152396797789854ULL, + 9265327272409079662ULL, 9958475911404242556ULL, + 15359829357736068242ULL, 11416531655738519189ULL, + 16839972510321195914ULL, 1927049095689256442ULL, + 3356565661065236628ULL, 1065114285796701562ULL, + 7071763288904613033ULL, 16473053015775147286ULL, + 10317354477399696817ULL, 16005969584273256379ULL, + 15391010921709289298ULL, 17671303749287233862ULL, + 8028596930095411867ULL, 
10265936863337610975ULL, + 17403745948359398749ULL, 8504886230707230194ULL, + 12855085916215721214ULL, 5562885793068933146ULL, + 1508385574711135517ULL, 5953119477818575536ULL, + 9165320150094769334ULL}; + + int64_t ivals8[NUM] = { + 2944696543084623337, 137239101631340692, 4370169869966467498, + 3842452903439153631, -795080033670806202, 3023506421574592237, + -4142692575864168559, 1716333381567689984, 1591746912204250089, + -1974664662220599925, 3144022139297218102, -371429365537296255, + 4202906659701034264, 3878513012313576184, -3425767072006791628, + -2929337291418891626, 1880013370888913338, 1498977159463939728, + -2928775660744278650, 4074214991200977615, 4291797122649374026, + -763110360214750992, 2883673064242977727, 4270151072450399778, + 1408696225027958214, 1264214335825459628, -4152065441956669638, + 2684706424226400837, 569335474272794084, -2798088842177577224, + 814002749054152728, 2517003920904582842, 4089891582575745386, + 705067059635512048, -2500935118374519236}; + + auto work_group_sorter = [](uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter1 = [](uint8_t *keys, uint8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u8_u32_p1i8( 
+ keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter2 = [](uint8_t *keys, int8_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i8_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter3 = [](uint8_t *keys, uint16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter4 = [](uint8_t *keys, int16_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i16_u32_p1i8( + keys, vals, n, 
scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i16_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter5 = [](uint8_t *keys, uint32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#if SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter6 = [](uint8_t *keys, int32_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#else +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i32_u32_p1i8( + keys, vals, n, scratch); +#endif +#endif +#endif + }; + + auto work_group_sorter7 = [](uint8_t *keys, uint64_t *vals, uint32_t n, + uint8_t *scratch) { +#if __DEVICE_CODE +#ifdef DES +#ifdef SPREAD + __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#else + __devicelib_default_work_group_private_sort_close_descending_p1u8_p1u64_u32_p1i8( + keys, vals, n, scratch); +#endif 
+#else // !DES: ascending variants
+#ifdef SPREAD
+      __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1u64_u32_p1i8(
+          keys, vals, n, scratch);
+#else // !SPREAD: "close" variant
+      __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1u64_u32_p1i8(
+          keys, vals, n, scratch);
+#endif // SPREAD
+#endif // DES
+#endif // __DEVICE_CODE
+    };
+
+    auto work_group_sorter8 = [](uint8_t *keys, int64_t *vals, uint32_t n,
+                                 uint8_t *scratch) { // uint8_t keys with int64_t values
+#if __DEVICE_CODE // lambda body is empty unless __DEVICE_CODE is nonzero
+#ifdef DES // descending variants
+#ifdef SPREAD
+      __devicelib_default_work_group_private_sort_spread_descending_p1u8_p1i64_u32_p1i8(
+          keys, vals, n, scratch);
+#else // !SPREAD: "close" variant
+      __devicelib_default_work_group_private_sort_close_descending_p1u8_p1i64_u32_p1i8(
+          keys, vals, n, scratch);
+#endif // SPREAD
+#else // !DES: ascending variants
+#ifdef SPREAD
+      __devicelib_default_work_group_private_sort_spread_ascending_p1u8_p1i64_u32_p1i8(
+          keys, vals, n, scratch);
+#else // !SPREAD: "close" variant
+      __devicelib_default_work_group_private_sort_close_ascending_p1u8_p1i64_u32_p1i8(
+          keys, vals, n, scratch);
+#endif // SPREAD
+#endif // DES
+#endif // __DEVICE_CODE
+    };
+
+    constexpr static int NUM1 = 32; // element count printed by the ivals1 / work_group_sorter1 runs below
+    test_work_group_KV_private_sort(
+        q, ikeys, ivals1, work_group_sorter1);
+    std::cout << "KV private sort NUM = " << NUM1
+              << ", WG = 1 pass." << std::endl;
+
+    test_work_group_KV_private_sort(
+        q, ikeys, ivals1, work_group_sorter1);
+    std::cout << "KV private sort NUM = " << NUM1
+              << ", WG = 2 pass." << std::endl;
+
+    test_work_group_KV_private_sort(
+        q, ikeys, ivals1, work_group_sorter1);
+    std::cout << "KV private sort NUM = " << NUM1
+              << ", WG = 4 pass." << std::endl;
+
+    test_work_group_KV_private_sort(
+        q, ikeys, ivals1, work_group_sorter1);
+    std::cout << "KV private sort NUM = " << NUM1
+              << ", WG = 8 pass." << std::endl;
+
+    test_work_group_KV_private_sort(
+        q, ikeys, ivals1, work_group_sorter1);
+    std::cout << "KV private sort NUM = " << NUM1
+              << ", WG = 16 pass." << std::endl;
+
+    test_work_group_KV_private_sort(
+        q, ikeys, ivals1, work_group_sorter1);
+    std::cout << "KV private sort NUM = " << NUM1
+              << ", WG = 32 pass."
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals2, work_group_sorter2); + std::cout << "KV private sort NUM = " << NUM + << ", WG = 35 pass." << std::endl; + + constexpr static int NUM2 = 24; + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 6 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 8 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 12 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals3, work_group_sorter3); + std::cout << "KV private sort NUM = " << NUM2 + << ", WG = 24 pass." 
<< std::endl; + + constexpr static int NUM3 = 20; + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 4 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals4, work_group_sorter4); + std::cout << "KV private sort NUM = " << NUM3 + << ", WG = 20 pass." << std::endl; + + constexpr static int NUM4 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals5, work_group_sorter5); + std::cout << "KV private sort NUM = " << NUM4 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM5 = 25; + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals6, work_group_sorter6); + std::cout << "KV private sort NUM = " << NUM5 + << ", WG = 25 pass." << std::endl; + + constexpr static int NUM6 = 30; + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 2 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 5 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 6 pass." 
<< std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 10 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 15 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals7, work_group_sorter7); + std::cout << "KV private sort NUM = " << NUM6 + << ", WG = 30 pass." << std::endl; + + constexpr static int NUM7 = 21; + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 1 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 3 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 7 pass." << std::endl; + + test_work_group_KV_private_sort( + q, ikeys, ivals8, work_group_sorter8); + std::cout << "KV private sort NUM = " << NUM7 + << ", WG = 21 pass." 
<< std::endl;
+  }
+}
diff --git a/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp
new file mode 100644
index 0000000000000..e1c7eb6b8dc83
--- /dev/null
+++ b/sycl/test-e2e/DeviceLib/group_sort/workgroup_private_sort_p1p1.cpp
@@ -0,0 +1,195 @@
+// RUN: %{build} -o %t.out
+// RUN: %{run} %t.out
+
+// RUN: %{build} -DCLOSE -o %t.out
+// RUN: %{run} %t.out
+
+// RUN: %{build} -fsycl-device-lib-jit-link -o %t.out
+// RUN: %{run} %t.out
+
+// RUN: %{build} -DCLOSE -fsycl-device-lib-jit-link -o %t.out
+// RUN: %{run} %t.out
+
+// RUN: %{build} -DDES -o %t.out
+// RUN: %{run} %t.out
+
+// RUN: %{build} -DDES -DCLOSE -o %t.out
+// RUN: %{run} %t.out
+
+// RUN: %{build} -DDES -DCLOSE -fsycl-device-lib-jit-link -o %t.out
+// RUN: %{run} %t.out
+//
+// UNSUPPORTED: cuda || hip
+
+#include "group_private_sort_p1p1.hpp"
+#include
+#include
+#include
+#include
+#include
+using namespace sycl;
+template
+void test_work_group_private_sort(sycl::queue &q, Ty input[NUM],
+                                  SortHelper gsh) { // sorts `input` on the device via `gsh` and checks it against a host std::sort
+  static_assert(NUM % WG_SZ == 0,
+                "Input size must be divisible by Work group size!");
+  // Scratch memory size >= NUM * sizeof(Ty) * 2
+  Ty scratch[NUM * 2] = {
+      0,
+  };
+  Ty result[NUM];
+  Ty reference[NUM];
+  memcpy(reference, input, sizeof(Ty) * NUM);
+#ifdef DES // host reference is sorted descending when DES is defined
+  std::sort(&reference[0], &reference[NUM], std::greater());
+#else // !DES: host reference is sorted ascending
+  std::sort(&reference[0], &reference[NUM]);
+#endif // DES
+  const static size_t wg_size = WG_SZ;
+  constexpr size_t num_per_work_item = NUM / WG_SZ;
+  nd_range<1> num_items((range<1>(wg_size)), (range<1>(wg_size))); // global size == local size: a single work-group
+
+  {
+    buffer ibuf(input, NUM);
+    buffer obuf(scratch, NUM * 2); // scratch storage; handed to gsh as a byte pointer in the kernel
+    buffer rbuf(result, NUM);
+    q.submit([&](auto &h) {
+       accessor in_acc{ibuf, h};
+       accessor out_acc{obuf, h};
+       accessor re_acc{rbuf, h};
+       h.parallel_for(num_items, [=](nd_item<1> i) {
+         Ty private_buf[num_per_work_item];
+         for (size_t idx = 0; idx < num_per_work_item; ++idx) // each work-item copies its own contiguous slice of the input
+           private_buf[idx] =
+               in_acc[i.get_local_linear_id() * num_per_work_item + idx];
+         group_barrier(i.get_group());
+         Ty *optr =
+             out_acc.template get_multi_ptr().get();
+         uint8_t *by = reinterpret_cast(optr);
+         gsh(private_buf, num_per_work_item, by); // sorter under test; its output is read back from private_buf below
+         for (size_t idx = 0; idx < num_per_work_item; ++idx) // write this work-item's slice to the same positions of result
+           re_acc[i.get_local_linear_id() * num_per_work_item + idx] =
+               private_buf[idx];
+       });
+     }).wait();
+  }
+
+  bool fails = false;
+
+#ifdef CLOSE // close variant: result must equal the sorted reference element for element
+  for (size_t idx = 0; idx < NUM; ++idx) {
+    if (result[idx] != reference[idx]) {
+      fails = true;
+      break;
+    }
+  }
+#else // !CLOSE (spread variant): sorted element idx is expected in work-item (idx % WG_SZ), slot (idx / WG_SZ)
+  for (size_t idx = 0; idx < NUM; ++idx) {
+    size_t idx1 = idx % WG_SZ;
+    size_t idx2 = idx / WG_SZ;
+    if (reference[idx] != result[idx1 * num_per_work_item + idx2]) {
+      fails = true;
+      break;
+    }
+  }
+#endif // CLOSE
+
+  assert(!fails); // NOTE(review): has no effect if NDEBUG is defined - confirm the test build keeps asserts enabled
+}
+
+int main() {
+  queue q;
+
+  {
+    constexpr static int NUM = 24;
+    int8_t a[NUM] = {-1,  11,  10, 9,  3, 100, 34, 8,  10, 77,  10, 103,
+                     -12, -93, 23, 36, 2, 111, 91, 88, -2, -25, 98, -111};
+    auto work_group_sorter = [](int8_t *first, uint32_t n, uint8_t *scratch) {
+#ifdef CLOSE // "close" variants
+#ifdef DES
+      __devicelib_default_work_group_private_sort_close_descending_p1i8_u32_p1i8(
+          first, n, scratch);
+#else // !DES: ascending
+      __devicelib_default_work_group_private_sort_close_ascending_p1i8_u32_p1i8(
+          first, n, scratch);
+#endif // DES
+#else // !CLOSE: "spread" variants
+#ifdef DES
+      __devicelib_default_work_group_private_sort_spread_descending_p1i8_u32_p1i8(
+          first, n, scratch);
+#else // !DES: ascending
+      __devicelib_default_work_group_private_sort_spread_ascending_p1i8_u32_p1i8(
+          first, n, scratch);
+#endif // DES
+#endif // CLOSE
+    };
+    test_work_group_private_sort(
+        q, a, work_group_sorter);
+    std::cout << "work group private sort p1i8_u32_p1i8 passes" << std::endl;
+  }
+
+  {
+    constexpr static int NUM = 32;
+    int16_t a[NUM] = {2162,   29891,  14709, -20987, -30051, -26861, 5629,
+                      -11244, 25702,  29438, 22560,  -15282, 27812,  28455,
+                      26871,  -22327, 6495,  23519,  19389,  26328,  13253,
+                      -24369, -1616,  3278,  5624,   -6317,  -3669,  11874,
+                      -46,    -4717,  -27449, -9790};
+    auto work_group_sorter = [](int16_t *first, uint32_t n, uint8_t *scratch) {
+#ifdef CLOSE // "close" variants
+#ifdef DES
+      __devicelib_default_work_group_private_sort_close_descending_p1i16_u32_p1i8(
+          first, n, scratch);
+#else // !DES: ascending
+      __devicelib_default_work_group_private_sort_close_ascending_p1i16_u32_p1i8(
+          first, n, scratch);
+#endif // DES
+#else // !CLOSE: "spread" variants
+#ifdef DES
+      __devicelib_default_work_group_private_sort_spread_descending_p1i16_u32_p1i8(
+          first, n, scratch);
+#else // !DES: ascending
+      __devicelib_default_work_group_private_sort_spread_ascending_p1i16_u32_p1i8(
+          first, n, scratch);
+#endif // DES
+#endif // CLOSE
+    };
+    test_work_group_private_sort(
+        q, a, work_group_sorter);
+    std::cout << "work group private sort p1i16_u32_p1i8 passes" << std::endl;
+  }
+
+  {
+    constexpr static int NUM = 32;
+    int32_t a[NUM] = {1319329913,  -390041276,  -2040725419, -217333100,
+                      -900793956,  -2138508211, 769705434,   122767310,
+                      -1918605668, -16813517,   1616926513,  -2141526068,
+                      631985359,   541606467,   662050754,   140359040,
+                      1834119354,  1910851165,  809736505,   451506849,
+                      -1713169862, -1916401837, 1490159094,  -2066441094,
+                      -332318833,  -1550930943, 1763101596,  500568854,
+                      -1574546569, -596440302,  1522396193,  -980468122};
+    auto work_group_sorter = [](int32_t *first, uint32_t n, uint8_t *scratch) {
+#ifdef CLOSE // "close" variants
+#ifdef DES
+      __devicelib_default_work_group_private_sort_close_descending_p1i32_u32_p1i8(
+          first, n, scratch);
+#else // !DES: ascending
+      __devicelib_default_work_group_private_sort_close_ascending_p1i32_u32_p1i8(
+          first, n, scratch);
+#endif // DES
+#else // !CLOSE: "spread" variants
+#ifdef DES
+      __devicelib_default_work_group_private_sort_spread_descending_p1i32_u32_p1i8(
+          first, n, scratch);
+#else // !DES: ascending
+      __devicelib_default_work_group_private_sort_spread_ascending_p1i32_u32_p1i8(
+          first, n, scratch);
+#endif // DES
+#endif // CLOSE
+    };
+    test_work_group_private_sort(
+        q, a, work_group_sorter);
+    std::cout << "work group private sort p1i32_u32_p1i8 passes" << std::endl;
+  }
+}