diff --git a/sxt/algorithm/block/runlength_count.h b/sxt/algorithm/block/runlength_count.h index d621c534..2ec94d84 100644 --- a/sxt/algorithm/block/runlength_count.h +++ b/sxt/algorithm/block/runlength_count.h @@ -20,81 +20,83 @@ * This is a workaround to define _VSTD before including cub/cub.cuh. * It should be removed when we can upgrade to a newer version of CUDA. * - * We need to define _VSTD in order to use the clang version defined in + * We need to define _VSTD in order to use the clang version defined in * clang.nix and the CUDA toolkit version defined in cuda.nix. - * + * * _VSTD was deprecated and removed from the LLVM truck. * NVIDIA: https://github.com/NVIDIA/cccl/pull/1331 * LLVM: https://github.com/llvm/llvm-project/commit/683bc94e1637bd9bacc978f5dc3c79cfc8ff94b9 - * - * We cannot currently use any CUDA toolkit above 12.4.1 because the Kubernetes - * cluster currently cannot install a driver above 550. - * - * See CUDA toolkit and driver support: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html + * + * We cannot currently use any CUDA toolkit above 12.4.1 because the Kubernetes + * cluster currently cannot install a driver above 550. + * + * See CUDA toolkit and driver support: + * https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html */ #include <__config> + #define _VSTD std::_LIBCPP_ABI_NAMESPACE _LIBCPP_BEGIN_NAMESPACE_STD _LIBCPP_END_NAMESPACE_STD #include "cub/cub.cuh" #include "sxt/base/macro/cuda_callable.h" -namespace sxt::algbk { -//-------------------------------------------------------------------------------------------------- -// runlength_count -//-------------------------------------------------------------------------------------------------- -/** - * This is adapted from block_histogram_sort in CUB. See - * https://github.com/NVIDIA/cccl/blob/a51b1f8c75f8e577eeccc74b45f1ff16a2727265/cub/cub/block/specializations/block_histogram_sort.cuh - */ -template class runlength_count { - using Discontinuity = cub::BlockDiscontinuity; + namespace sxt::algbk { + //-------------------------------------------------------------------------------------------------- + // runlength_count + //-------------------------------------------------------------------------------------------------- + /** + * This is adapted from block_histogram_sort in CUB. See + * https://github.com/NVIDIA/cccl/blob/a51b1f8c75f8e577eeccc74b45f1ff16a2727265/cub/cub/block/specializations/block_histogram_sort.cuh + */ + template class runlength_count { + using Discontinuity = cub::BlockDiscontinuity; -public: - struct temp_storage { - typename Discontinuity::TempStorage discontinuity; - CounterT run_begin[NumBins]; - CounterT run_end[NumBins]; - }; + public: + struct temp_storage { + typename Discontinuity::TempStorage discontinuity; + CounterT run_begin[NumBins]; + CounterT run_end[NumBins]; + }; - CUDA_CALLABLE explicit runlength_count(temp_storage& storage) noexcept - : storage_{storage}, discontinuity_{storage.discontinuity} {} + CUDA_CALLABLE explicit runlength_count(temp_storage& storage) noexcept + : storage_{storage}, discontinuity_{storage.discontinuity} {} - /** - * If items holds sorted items across threads in a block, count and return a - * pointer to a table of the items' run lengths. - */ - template - CUDA_CALLABLE CounterT* count(T (&items)[ItemsPerThread]) noexcept { - auto thread_id = threadIdx.x; - for (unsigned i = thread_id; i < NumBins; i += NumThreads) { - storage_.run_begin[i] = NumThreads * ItemsPerThread; - storage_.run_end[i] = NumThreads * ItemsPerThread; - } - int flags[ItemsPerThread]; - auto flag_op = [&storage = storage_](T a, T b, int b_index) noexcept { - if (a != b) { - storage.run_begin[b] = static_cast(b_index); - storage.run_end[a] = static_cast(b_index); - return true; - } else { - return false; + /** + * If items holds sorted items across threads in a block, count and return a + * pointer to a table of the items' run lengths. + */ + template + CUDA_CALLABLE CounterT* count(T (&items)[ItemsPerThread]) noexcept { + auto thread_id = threadIdx.x; + for (unsigned i = thread_id; i < NumBins; i += NumThreads) { + storage_.run_begin[i] = NumThreads * ItemsPerThread; + storage_.run_end[i] = NumThreads * ItemsPerThread; } - }; - __syncthreads(); - discontinuity_.FlagHeads(flags, items, flag_op); - if (thread_id == 0) { - storage_.run_begin[items[0]] = 0; - } - __syncthreads(); - for (unsigned i = thread_id; i < NumBins; i += NumThreads) { - storage_.run_end[i] -= storage_.run_begin[i]; + int flags[ItemsPerThread]; + auto flag_op = [&storage = storage_](T a, T b, int b_index) noexcept { + if (a != b) { + storage.run_begin[b] = static_cast(b_index); + storage.run_end[a] = static_cast(b_index); + return true; + } else { + return false; + } + }; + __syncthreads(); + discontinuity_.FlagHeads(flags, items, flag_op); + if (thread_id == 0) { + storage_.run_begin[items[0]] = 0; + } + __syncthreads(); + for (unsigned i = thread_id; i < NumBins; i += NumThreads) { + storage_.run_end[i] -= storage_.run_begin[i]; + } + return storage_.run_end; } - return storage_.run_end; - } -private: - temp_storage& storage_; - Discontinuity discontinuity_; -}; + private: + temp_storage& storage_; + Discontinuity discontinuity_; + }; } // namespace sxt::algbk diff --git a/sxt/multiexp/base/scalar_array.cc b/sxt/multiexp/base/scalar_array.cc index 7cc4f271..21654d64 100644 --- a/sxt/multiexp/base/scalar_array.cc +++ b/sxt/multiexp/base/scalar_array.cc @@ -22,17 +22,18 @@ * This is a workaround to define _VSTD before including cub/cub.cuh. * It should be removed when we can upgrade to a newer version of CUDA. * - * We need to define _VSTD in order to use the clang version defined in + * We need to define _VSTD in order to use the clang version defined in * clang.nix and the CUDA toolkit version defined in cuda.nix. - * + * * _VSTD was deprecated and removed from the LLVM truck. * NVIDIA: https://github.com/NVIDIA/cccl/pull/1331 * LLVM: https://github.com/llvm/llvm-project/commit/683bc94e1637bd9bacc978f5dc3c79cfc8ff94b9 - * - * We cannot currently use any CUDA toolkit above 12.4.1 because the Kubernetes - * cluster currently cannot install a driver above 550. - * - * See CUDA toolkit and driver support: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html + * + * We cannot currently use any CUDA toolkit above 12.4.1 because the Kubernetes + * cluster currently cannot install a driver above 550. + * + * See CUDA toolkit and driver support: + * https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html */ #include <__config> #define _VSTD std::_LIBCPP_ABI_NAMESPACE @@ -50,100 +51,100 @@ _LIBCPP_BEGIN_NAMESPACE_STD _LIBCPP_END_NAMESPACE_STD #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" -namespace sxt::mtxb { -//-------------------------------------------------------------------------------------------------- -// scalar_blob -//-------------------------------------------------------------------------------------------------- -namespace { -template struct scalar_blob { - uint8_t data[NumBytes]; -}; -} // namespace + namespace sxt::mtxb { + //-------------------------------------------------------------------------------------------------- + // scalar_blob + //-------------------------------------------------------------------------------------------------- + namespace { + template struct scalar_blob { + uint8_t data[NumBytes]; + }; + } // namespace -//-------------------------------------------------------------------------------------------------- -// transpose_kernel -//-------------------------------------------------------------------------------------------------- -template -static __global__ void transpose_kernel(uint8_t* __restrict__ dst, - const scalar_blob* __restrict__ src, - unsigned n) noexcept { - using Scalar = scalar_blob; + //-------------------------------------------------------------------------------------------------- + // transpose_kernel + //-------------------------------------------------------------------------------------------------- + template + static __global__ void transpose_kernel(uint8_t* __restrict__ dst, + const scalar_blob* __restrict__ src, + unsigned n) noexcept { + using Scalar = scalar_blob; - auto byte_index = threadIdx.x; - auto tile_index = blockIdx.x; - auto output_index = blockIdx.y; - auto num_tiles = gridDim.x; - auto n_per_tile = basn::divide_up(basn::divide_up(n, NumBytes), num_tiles) * NumBytes; + auto byte_index = threadIdx.x; + auto tile_index = blockIdx.x; + auto output_index = blockIdx.y; + auto num_tiles = gridDim.x; + auto n_per_tile = basn::divide_up(basn::divide_up(n, NumBytes), num_tiles) * NumBytes; - auto first = tile_index * n_per_tile; - auto m = min(n_per_tile, n - first); + auto first = tile_index * n_per_tile; + auto m = min(n_per_tile, n - first); - // adjust pointers - src += first; - src += output_index * n; - dst += byte_index * n + first; - dst += output_index * NumBytes * n; + // adjust pointers + src += first; + src += output_index * n; + dst += byte_index * n + first; + dst += output_index * NumBytes * n; - // set up algorithm - using BlockExchange = cub::BlockExchange; - __shared__ typename BlockExchange::TempStorage temp_storage; + // set up algorithm + using BlockExchange = cub::BlockExchange; + __shared__ typename BlockExchange::TempStorage temp_storage; - // transpose - Scalar s; - unsigned out_first = 0; - for (unsigned i = byte_index; i < n_per_tile; i += NumBytes) { - if (i < m) { - s = src[i]; - } - BlockExchange(temp_storage).StripedToBlocked(s.data); - __syncthreads(); - for (unsigned j = 0; j < NumBytes; ++j) { - auto out_index = out_first + j; - if (out_index < m) { - dst[out_index] = s.data[j]; + // transpose + Scalar s; + unsigned out_first = 0; + for (unsigned i = byte_index; i < n_per_tile; i += NumBytes) { + if (i < m) { + s = src[i]; } + BlockExchange(temp_storage).StripedToBlocked(s.data); + __syncthreads(); + for (unsigned j = 0; j < NumBytes; ++j) { + auto out_index = out_first + j; + if (out_index < m) { + dst[out_index] = s.data[j]; + } + } + out_first += NumBytes; + __syncthreads(); } - out_first += NumBytes; - __syncthreads(); } -} -//-------------------------------------------------------------------------------------------------- -// transpose_scalars_to_device -//-------------------------------------------------------------------------------------------------- -xena::future<> transpose_scalars_to_device(basct::span array, - basct::cspan scalars, - unsigned element_num_bytes, unsigned n) noexcept { - auto num_outputs = static_cast(scalars.size()); - if (n == 0 || num_outputs == 0) { - co_return; - } - SXT_DEBUG_ASSERT( - // clang-format off + //-------------------------------------------------------------------------------------------------- + // transpose_scalars_to_device + //-------------------------------------------------------------------------------------------------- + xena::future<> transpose_scalars_to_device(basct::span array, + basct::cspan scalars, + unsigned element_num_bytes, unsigned n) noexcept { + auto num_outputs = static_cast(scalars.size()); + if (n == 0 || num_outputs == 0) { + co_return; + } + SXT_DEBUG_ASSERT( + // clang-format off array.size() == num_outputs * element_num_bytes * n && basdv::is_active_device_pointer(array.data()) && basdv::is_host_pointer(scalars[0]) - // clang-format on - ); - basdv::stream stream; - memr::async_device_resource resource{stream}; - memmg::managed_array array_p{array.size(), &resource}; - auto num_bytes_per_output = element_num_bytes * n; - for (size_t output_index = 0; output_index < num_outputs; ++output_index) { - basdv::async_copy_host_to_device( - basct::subspan(array_p, output_index * num_bytes_per_output, num_bytes_per_output), - basct::cspan{scalars[output_index], num_bytes_per_output}, stream); + // clang-format on + ); + basdv::stream stream; + memr::async_device_resource resource{stream}; + memmg::managed_array array_p{array.size(), &resource}; + auto num_bytes_per_output = element_num_bytes * n; + for (size_t output_index = 0; output_index < num_outputs; ++output_index) { + basdv::async_copy_host_to_device( + basct::subspan(array_p, output_index * num_bytes_per_output, num_bytes_per_output), + basct::cspan{scalars[output_index], num_bytes_per_output}, stream); + } + auto num_tiles = std::min(basn::divide_up(n, num_outputs * element_num_bytes), 64u); + auto num_bytes_log2 = basn::ceil_log2(element_num_bytes); + basn::constexpr_switch<6>( + num_bytes_log2, + [&](std::integral_constant) noexcept { + constexpr auto NumBytes = 1u << LogNumBytes; + SXT_DEBUG_ASSERT(NumBytes == element_num_bytes); + transpose_kernel<<>>( + array.data(), reinterpret_cast*>(array_p.data()), n); + }); + co_await xendv::await_stream(std::move(stream)); } - auto num_tiles = std::min(basn::divide_up(n, num_outputs * element_num_bytes), 64u); - auto num_bytes_log2 = basn::ceil_log2(element_num_bytes); - basn::constexpr_switch<6>( - num_bytes_log2, - [&](std::integral_constant) noexcept { - constexpr auto NumBytes = 1u << LogNumBytes; - SXT_DEBUG_ASSERT(NumBytes == element_num_bytes); - transpose_kernel<<>>( - array.data(), reinterpret_cast*>(array_p.data()), n); - }); - co_await xendv::await_stream(std::move(stream)); -} } // namespace sxt::mtxb diff --git a/sxt/multiexp/bucket_method2/multiproduct_table.cc b/sxt/multiexp/bucket_method2/multiproduct_table.cc index 54924cb5..57b68e71 100644 --- a/sxt/multiexp/bucket_method2/multiproduct_table.cc +++ b/sxt/multiexp/bucket_method2/multiproduct_table.cc @@ -20,17 +20,18 @@ * This is a workaround to define _VSTD before including cub/cub.cuh. * It should be removed when we can upgrade to a newer version of CUDA. * - * We need to define _VSTD in order to use the clang version defined in + * We need to define _VSTD in order to use the clang version defined in * clang.nix and the CUDA toolkit version defined in cuda.nix. - * + * * _VSTD was deprecated and removed from the LLVM truck. * NVIDIA: https://github.com/NVIDIA/cccl/pull/1331 * LLVM: https://github.com/llvm/llvm-project/commit/683bc94e1637bd9bacc978f5dc3c79cfc8ff94b9 - * - * We cannot currently use any CUDA toolkit above 12.4.1 because the Kubernetes - * cluster currently cannot install a driver above 550. - * - * See CUDA toolkit and driver support: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html + * + * We cannot currently use any CUDA toolkit above 12.4.1 because the Kubernetes + * cluster currently cannot install a driver above 550. + * + * See CUDA toolkit and driver support: + * https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html */ #include <__config> #define _VSTD std::_LIBCPP_ABI_NAMESPACE @@ -50,56 +51,55 @@ _LIBCPP_BEGIN_NAMESPACE_STD _LIBCPP_END_NAMESPACE_STD #include "sxt/multiexp/base/scalar_array.h" #include "sxt/multiexp/bucket_method2/multiproduct_table_kernel.h" -namespace sxt::mtxbk2 { -//-------------------------------------------------------------------------------------------------- -// make_multiproduct_table -//-------------------------------------------------------------------------------------------------- -xena::future<> make_multiproduct_table(basct::span bucket_prefix_counts, - basct::span indexes, - basct::cspan scalars, - unsigned element_num_bytes, unsigned bit_width, - unsigned n) noexcept { - auto num_outputs = scalars.size(); - auto num_buckets_per_digit = (1u << bit_width) - 1u; - auto num_digits = basn::divide_up(element_num_bytes * 8u, bit_width); - auto num_buckets_per_output = num_buckets_per_digit * num_digits; - auto num_buckets_total = num_buckets_per_output * num_outputs; - SXT_DEBUG_ASSERT(bucket_prefix_counts.size() == num_buckets_total && - indexes.size() == num_outputs * num_digits * n && - basdv::is_active_device_pointer(bucket_prefix_counts.data()) && - basdv::is_active_device_pointer(indexes.data())); + namespace sxt::mtxbk2 { + //-------------------------------------------------------------------------------------------------- + // make_multiproduct_table + //-------------------------------------------------------------------------------------------------- + xena::future<> make_multiproduct_table( + basct::span bucket_prefix_counts, basct::span indexes, + basct::cspan scalars, unsigned element_num_bytes, unsigned bit_width, + unsigned n) noexcept { + auto num_outputs = scalars.size(); + auto num_buckets_per_digit = (1u << bit_width) - 1u; + auto num_digits = basn::divide_up(element_num_bytes * 8u, bit_width); + auto num_buckets_per_output = num_buckets_per_digit * num_digits; + auto num_buckets_total = num_buckets_per_output * num_outputs; + SXT_DEBUG_ASSERT(bucket_prefix_counts.size() == num_buckets_total && + indexes.size() == num_outputs * num_digits * n && + basdv::is_active_device_pointer(bucket_prefix_counts.data()) && + basdv::is_active_device_pointer(indexes.data())); - // transpose scalars - basl::info("copying scalars to device"); - memmg::managed_array bytes{num_outputs * n * element_num_bytes, - memr::get_device_resource()}; - co_await mtxb::transpose_scalars_to_device(bytes, scalars, element_num_bytes, n); + // transpose scalars + basl::info("copying scalars to device"); + memmg::managed_array bytes{num_outputs * n * element_num_bytes, + memr::get_device_resource()}; + co_await mtxb::transpose_scalars_to_device(bytes, scalars, element_num_bytes, n); - // compute buckets - basl::info("computing multiproduct decomposition"); - SXT_RELEASE_ASSERT(bit_width == 8u, "only support bit_width == 8u for now"); - SXT_RELEASE_ASSERT(n <= max_multiexponentiation_length_v, "limit length for now"); - basdv::stream stream; - fit_multiproduct_table_kernel( - [&]( - std::integral_constant, - std::integral_constant) noexcept { - multiproduct_table_kernel - <<>>( - bucket_prefix_counts.data(), indexes.data(), bytes.data(), n); - }, - n); + // compute buckets + basl::info("computing multiproduct decomposition"); + SXT_RELEASE_ASSERT(bit_width == 8u, "only support bit_width == 8u for now"); + SXT_RELEASE_ASSERT(n <= max_multiexponentiation_length_v, "limit length for now"); + basdv::stream stream; + fit_multiproduct_table_kernel( + [&]( + std::integral_constant, + std::integral_constant) noexcept { + multiproduct_table_kernel + <<>>( + bucket_prefix_counts.data(), indexes.data(), bytes.data(), n); + }, + n); - // prefix sum - auto f = [bucket_prefix_counts = bucket_prefix_counts.data(), - num_buckets_per_digit = num_buckets_per_digit] __host__ - __device__(unsigned /*num_digits_total*/, unsigned index) noexcept { - auto counts = bucket_prefix_counts + index * num_buckets_per_digit; - for (unsigned i = 1; i < num_buckets_per_digit; ++i) { - counts[i] += counts[i - 1u]; - } - }; - algi::launch_for_each_kernel(stream, f, num_digits * num_outputs); - co_await xendv::await_stream(stream); -} + // prefix sum + auto f = [bucket_prefix_counts = bucket_prefix_counts.data(), + num_buckets_per_digit = num_buckets_per_digit] __host__ + __device__(unsigned /*num_digits_total*/, unsigned index) noexcept { + auto counts = bucket_prefix_counts + index * num_buckets_per_digit; + for (unsigned i = 1; i < num_buckets_per_digit; ++i) { + counts[i] += counts[i - 1u]; + } + }; + algi::launch_for_each_kernel(stream, f, num_digits * num_outputs); + co_await xendv::await_stream(stream); + } } // namespace sxt::mtxbk2 diff --git a/sxt/multiexp/bucket_method2/multiproduct_table_kernel.h b/sxt/multiexp/bucket_method2/multiproduct_table_kernel.h index 1e8da12c..3e8e5928 100644 --- a/sxt/multiexp/bucket_method2/multiproduct_table_kernel.h +++ b/sxt/multiexp/bucket_method2/multiproduct_table_kernel.h @@ -22,17 +22,18 @@ * This is a workaround to define _VSTD before including cub/cub.cuh. * It should be removed when we can upgrade to a newer version of CUDA. * - * We need to define _VSTD in order to use the clang version defined in + * We need to define _VSTD in order to use the clang version defined in * clang.nix and the CUDA toolkit version defined in cuda.nix. - * + * * _VSTD was deprecated and removed from the LLVM truck. * NVIDIA: https://github.com/NVIDIA/cccl/pull/1331 * LLVM: https://github.com/llvm/llvm-project/commit/683bc94e1637bd9bacc978f5dc3c79cfc8ff94b9 - * - * We cannot currently use any CUDA toolkit above 12.4.1 because the Kubernetes - * cluster currently cannot install a driver above 550. - * - * See CUDA toolkit and driver support: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html + * + * We cannot currently use any CUDA toolkit above 12.4.1 because the Kubernetes + * cluster currently cannot install a driver above 550. + * + * See CUDA toolkit and driver support: + * https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html */ #include <__config> #define _VSTD std::_LIBCPP_ABI_NAMESPACE @@ -45,81 +46,81 @@ _LIBCPP_BEGIN_NAMESPACE_STD _LIBCPP_END_NAMESPACE_STD #include "sxt/base/num/divide_up.h" #include "sxt/multiexp/bucket_method2/constants.h" -namespace sxt::mtxbk2 { -//-------------------------------------------------------------------------------------------------- -// multiproduct_table_kernel -//-------------------------------------------------------------------------------------------------- -template -__global__ void multiproduct_table_kernel(uint16_t* __restrict__ bucket_counts, - uint16_t* __restrict__ indexes, - const uint8_t* __restrict__ bytes, unsigned n) { - uint16_t thread_index = threadIdx.x; - auto digit_index = blockIdx.x; - auto output_index = blockIdx.y; - auto num_digits = gridDim.x; - auto num_buckets_per_digit = (1u << BitWidth) - 1u; + namespace sxt::mtxbk2 { + //-------------------------------------------------------------------------------------------------- + // multiproduct_table_kernel + //-------------------------------------------------------------------------------------------------- + template + __global__ void multiproduct_table_kernel(uint16_t* __restrict__ bucket_counts, + uint16_t* __restrict__ indexes, + const uint8_t* __restrict__ bytes, unsigned n) { + uint16_t thread_index = threadIdx.x; + auto digit_index = blockIdx.x; + auto output_index = blockIdx.y; + auto num_digits = gridDim.x; + auto num_buckets_per_digit = (1u << BitWidth) - 1u; - // algorithms and shared memory - using RadixSort = cub::BlockRadixSort; - using RunlengthCount = algbk::runlength_count; - __shared__ union { - RadixSort::TempStorage sort; - RunlengthCount::temp_storage count; - } temp_storage; + // algorithms and shared memory + using RadixSort = cub::BlockRadixSort; + using RunlengthCount = algbk::runlength_count; + __shared__ union { + RadixSort::TempStorage sort; + RunlengthCount::temp_storage count; + } temp_storage; - // adjust pointers - bucket_counts += digit_index * num_buckets_per_digit; - bucket_counts += output_index * num_digits * num_buckets_per_digit; - indexes += digit_index * n; - indexes += output_index * num_digits * n; - bytes += digit_index * n; - bytes += output_index * num_digits * n; + // adjust pointers + bucket_counts += digit_index * num_buckets_per_digit; + bucket_counts += output_index * num_digits * num_buckets_per_digit; + indexes += digit_index * n; + indexes += output_index * num_digits * n; + bytes += digit_index * n; + bytes += output_index * num_digits * n; - // load bytes - uint8_t keys[ItemsPerThread]; - uint16_t values[ItemsPerThread]; - for (uint16_t i = 0; i < ItemsPerThread; ++i) { - auto index = thread_index + i * NumThreads; - if (index < n) { - keys[i] = bytes[index]; - values[i] = index; - } else { - keys[i] = 0; - values[i] = 0; + // load bytes + uint8_t keys[ItemsPerThread]; + uint16_t values[ItemsPerThread]; + for (uint16_t i = 0; i < ItemsPerThread; ++i) { + auto index = thread_index + i * NumThreads; + if (index < n) { + keys[i] = bytes[index]; + values[i] = index; + } else { + keys[i] = 0; + values[i] = 0; + } } - } - // sort - RadixSort(temp_storage.sort).Sort(keys, values); - __syncthreads(); + // sort + RadixSort(temp_storage.sort).Sort(keys, values); + __syncthreads(); - // count - auto counts = RunlengthCount(temp_storage.count).count(keys); - __syncthreads(); + // count + auto counts = RunlengthCount(temp_storage.count).count(keys); + __syncthreads(); - // write counts - for (unsigned i = thread_index; i < num_buckets_per_digit; i += NumThreads) { - bucket_counts[i] = counts[i + 1]; - } + // write counts + for (unsigned i = thread_index; i < num_buckets_per_digit; i += NumThreads) { + bucket_counts[i] = counts[i + 1]; + } - // write indexes - auto zero_count = counts[0]; - for (unsigned i = 0; i < ItemsPerThread; ++i) { - auto index = i + thread_index * ItemsPerThread; - if (index >= zero_count) { - indexes[index - zero_count] = values[i]; + // write indexes + auto zero_count = counts[0]; + for (unsigned i = 0; i < ItemsPerThread; ++i) { + auto index = i + thread_index * ItemsPerThread; + if (index >= zero_count) { + indexes[index - zero_count] = values[i]; + } } } -} -//-------------------------------------------------------------------------------------------------- -// fit_multiproduct_table_kernel -//-------------------------------------------------------------------------------------------------- -template void fit_multiproduct_table_kernel(F f, unsigned n) noexcept { - SXT_RELEASE_ASSERT(n <= max_multiexponentiation_length_v); - basn::constexpr_switch<1, max_multiexponentiation_length_v / 128u + 1u>( - basn::divide_up(n, 128u), [&](std::integral_constant) noexcept { - return f(std::integral_constant{}, std::integral_constant{}); - }); -} + //-------------------------------------------------------------------------------------------------- + // fit_multiproduct_table_kernel + //-------------------------------------------------------------------------------------------------- + template void fit_multiproduct_table_kernel(F f, unsigned n) noexcept { + SXT_RELEASE_ASSERT(n <= max_multiexponentiation_length_v); + basn::constexpr_switch<1, max_multiexponentiation_length_v / 128u + 1u>( + basn::divide_up(n, 128u), [&](std::integral_constant) noexcept { + return f(std::integral_constant{}, std::integral_constant{}); + }); + } } // namespace sxt::mtxbk2