Skip to content

Commit 5466563

Browse files
Add WarpReduce Device-Side Benchmarks (#6431)
Co-authored-by: Bernhard Manfred Gruber <[email protected]>
1 parent ef5398c commit 5466563

File tree

5 files changed

+150
-1
lines changed

5 files changed

+150
-1
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#pragma once
5+
6+
#include <cub/config.cuh>
7+
8+
#include <cuda_runtime_api.h>
9+
#include <device_side_benchmark.cuh>
10+
#include <nvbench_helper.cuh>
11+
12+
// Functor measured by the device-side benchmark: performs one warp-wide
// reduction of `thread_data` using cub::WarpReduce with the operator `op_t`
// (declared by the including translation unit, e.g. cuda::minimum<> or
// cuda::std::plus<>).
struct benchmark_op_t
{
  template <typename T>
  __device__ __forceinline__ T operator()(T thread_data) const
  {
    using WarpReduce  = cub::WarpReduce<T>;
    using TempStorage = typename WarpReduce::TempStorage;
    // One TempStorage slot per warp; 32 slots cover 32 warps * 32 threads =
    // 1024 threads, enough for any block size the benchmark launches.
    // NOTE(review): shared array is re-declared per call site; relies on the
    // compiler folding repeated instantiations in the unrolled loop — confirm.
    __shared__ TempStorage temp_storage[32];
    auto warp_id = threadIdx.x / 32;
    return WarpReduce{temp_storage[warp_id]}.Reduce(thread_data, op_t{});
  }
};
24+
25+
// Benchmark entry point: launches benchmark_kernel with enough resident
// blocks to saturate every SM and lets nvbench time the launch.
//
// Grid/block layout: 1D grid of `max_blocks_per_SM * num_SMs` blocks of 256
// threads; no dynamic shared memory; launches on the default stream (nvbench
// measures with exec_tag::gpu | no_batch, i.e. one timed launch per sample).
template <typename T>
void warp_reduce(nvbench::state& state, nvbench::type_list<T>)
{
  constexpr int block_size    = 256;
  constexpr int unroll_factor = 128; // compromise between compile time and noise
  const auto& kernel          = benchmark_kernel<block_size, unroll_factor, benchmark_op_t, T>;
  const int num_SMs           = state.get_device().value().get_number_of_sms();
  // Query how many blocks of this kernel fit on one SM (0 bytes of dynamic
  // shared memory), then size the grid to fill the whole device exactly.
  // (The previous revision also fetched the device id here, but it was never
  // used: the occupancy query operates on the current device.)
  int max_blocks_per_SM = 0;
  NVBENCH_CUDA_CALL_NOEXCEPT(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_SM, kernel, block_size, 0));
  const int grid_size = max_blocks_per_SM * num_SMs;
  state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch&) {
    kernel<<<grid_size, block_size>>>(benchmark_op_t{});
  });
}
40+
41+
// Register one benchmark instance per type in value_types; the type axis is
// labeled "T{ct}" (compile-time type parameter).
NVBENCH_BENCH_TYPES(warp_reduce, NVBENCH_TYPE_AXES(value_types)).set_name("warp_reduce").set_type_axes_names({"T{ct}"});
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include <nvbench_helper.cuh>
5+
6+
// Type axis for the warp-reduce *minimum* benchmark.
// complex types cannot be compared with operator<
// so, unlike the sum benchmark, this list omits them.
using value_types =
  nvbench::type_list<int8_t,
                     int16_t,
                     int32_t,
#if NVBENCH_HELPER_HAS_I128
                     int128_t,
#endif
#if _CCCL_HAS_NVFP16() && _CCCL_CTK_AT_LEAST(12, 2)
                     __half,
#endif
#if _CCCL_HAS_NVBF16() && _CCCL_CTK_AT_LEAST(12, 2)
                     __nv_bfloat16,
#endif
                     float,
                     double>;

// Reduction operator consumed by benchmark_op_t in warp_reduce_base.cuh,
// which is textually included below (x-macro style: base file expects
// `value_types` and `op_t` to be defined before inclusion).
using op_t = ::cuda::minimum<>;
#include "warp_reduce_base.cuh"
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include <nvbench_helper.cuh>
5+
6+
// Type axis for the warp-reduce *sum* benchmark; addition is defined for
// complex types, so they are included here (guarded by the same FP16/BF16
// availability checks as their scalar counterparts).
using value_types = nvbench::type_list<
  int8_t,
  int16_t,
  int32_t,
#if NVBENCH_HELPER_HAS_I128
  int128_t,
#endif
#if _CCCL_HAS_NVFP16() && _CCCL_CTK_AT_LEAST(12, 2)
  __half,
#endif
#if _CCCL_HAS_NVBF16() && _CCCL_CTK_AT_LEAST(12, 2)
  __nv_bfloat16,
#endif
  float,
  double,
#if _CCCL_HAS_NVFP16() && _CCCL_CTK_AT_LEAST(12, 2)
  cuda::std::complex<__half>,
#endif
#if _CCCL_HAS_NVBF16() && _CCCL_CTK_AT_LEAST(12, 2)
  cuda::std::complex<__nv_bfloat16>,
#endif
  cuda::std::complex<float>,
  cuda::std::complex<double>>;

// Reduction operator consumed by benchmark_op_t in warp_reduce_base.cuh,
// which is textually included below (x-macro style: base file expects
// `value_types` and `op_t` to be defined before inclusion).
using op_t = ::cuda::std::plus<>;
#include "warp_reduce_base.cuh"
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include <cuda/cmath>
5+
#include <cuda/ptx>
6+
#include <cuda/std/cstdint>
7+
#include <cuda/std/cstring>
8+
#include <cuda/utility>
9+
10+
// Produces a T whose bytes come from repeated reads of the clock special
// register, so the benchmark input cannot be constant-folded at compile time.
// The values are *not* statistically random — just unknown to the compiler.
template <typename T>
__device__ __forceinline__ static T generate_random_data()
{
  // Number of 32-bit words needed to cover sizeof(T), rounding up.
  constexpr auto num_words = cuda::ceil_div(sizeof(T), sizeof(uint32_t));
  uint32_t words[num_words];
  for (size_t word = 0; word < num_words; ++word)
  {
    words[word] = cuda::ptx::get_sreg_clock();
  }
  T result;
  // Bit-copy the word buffer into the result; memcpy avoids aliasing UB.
  ::cuda::std::memcpy(&result, words, sizeof(T));
  return result;
}
23+
24+
// 64-byte global scratch target for sink(); sized to hold any benchmarked
// value type. Only ever written if sink()'s smid check passes (practically
// never) — it exists purely so the store has a valid destination.
__device__ static int device_var[16];
25+
26+
// Optimization barrier: pretends to consume `value` so the compiler cannot
// eliminate the benchmarked computation as dead code. The comparison of the
// SM-id register against 0xFFFFFFFF is (practically) never true, so no store
// is issued at runtime, but since the compiler cannot prove that, it must
// still materialize `value`.
template <typename T>
__device__ __forceinline__ static void sink(T value)
{
  if (cuda::ptx::get_sreg_smid() == static_cast<uint32_t>(-1))
  {
    *reinterpret_cast<T*>(device_var) = value;
  }
}
34+
35+
// Driver kernel for device-side benchmarks: each thread generates one input
// value, applies `action` UnrollFactor times back-to-back — each iteration
// feeding its result into the next, forming a serial dependence chain — and
// finally sinks the result so the whole chain survives dead-code elimination.
// No dynamic shared memory; any 1D grid/block shape up to BlockThreads works.
template <int BlockThreads, int UnrollFactor, typename ActionT, typename T>
__launch_bounds__(BlockThreads) __global__ static void benchmark_kernel(_CCCL_GRID_CONSTANT const ActionT action)
{
  T value = generate_random_data<T>();
  // Compile-time unrolled repetition; the index is unused — only the chain
  // of applications matters.
  cuda::static_for<UnrollFactor>([&]([[maybe_unused]] auto iteration) {
    value = action(value);
  });
  sink(value);
}

nvbench_helper/nvbench_helper/nvbench_helper.cuh

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,17 @@ NVBENCH_DECLARE_TYPE_STRINGS(uint128_t, "U128", "uint128_t");
3333

3434
// Short alias used throughout the benchmark suite for single-precision
// complex; its nvbench display strings are declared below.
using complex = cuda::std::complex<float>;

// nvbench type-axis labels for extended floating-point and complex types.
// Each declaration maps a C++ type to (short label, long label) shown in
// benchmark output; FP16/BF16 entries are guarded by availability macros.
#if _CCCL_HAS_NVFP16()
NVBENCH_DECLARE_TYPE_STRINGS(__half, "Half", "half");
NVBENCH_DECLARE_TYPE_STRINGS(cuda::std::complex<__half>, "C16", "complex_half");
#endif
#if _CCCL_HAS_NVBF16()
NVBENCH_DECLARE_TYPE_STRINGS(__nv_bfloat16, "Bfloat16", "bfloat16");
NVBENCH_DECLARE_TYPE_STRINGS(cuda::std::complex<__nv_bfloat16>, "CB16", "complex_bfloat16");
#endif
NVBENCH_DECLARE_TYPE_STRINGS(complex, "C32", "complex32");
NVBENCH_DECLARE_TYPE_STRINGS(cuda::std::complex<double>, "C64", "complex64");

NVBENCH_DECLARE_TYPE_STRINGS(::cuda::std::false_type, "false", "false_type");
NVBENCH_DECLARE_TYPE_STRINGS(::cuda::std::true_type, "true", "true_type");
NVBENCH_DECLARE_TYPE_STRINGS(cub::ArgMin, "ArgMin", "cub::ArgMin");

0 commit comments

Comments
 (0)