
Commit 25d09e9

pauleonix authored and bernhardmgruber committed
Use BlockLoadToShared in DeviceMerge
1 parent 6bd1de3 commit 25d09e9

3 files changed: +214, -47 lines

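For orientation: the agent and dispatch code touched below sit behind the public cub::DeviceMerge interface. A minimal host-side sketch of a call that would exercise this path is shown here; the buffer names, the int key type, and the explicit comparator are illustrative assumptions rather than part of this commit, while the two-phase temp-storage convention is the standard pattern for CUB device algorithms.

#include <cub/device/device_merge.cuh>

#include <cuda/std/functional>

#include <cstddef>

// Illustrative only: merge two sorted device ranges of ints into d_out.
void merge_keys_example(const int* d_keys1, int num1, const int* d_keys2, int num2, int* d_out, cudaStream_t stream)
{
  void* d_temp_storage           = nullptr;
  std::size_t temp_storage_bytes = 0;
  // First call: query how much temporary storage the merge needs.
  cub::DeviceMerge::MergeKeys(
    d_temp_storage, temp_storage_bytes, d_keys1, num1, d_keys2, num2, d_out, ::cuda::std::less<int>{}, stream);
  cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream);
  // Second call: perform the merge.
  cub::DeviceMerge::MergeKeys(
    d_temp_storage, temp_storage_bytes, d_keys1, num1, d_keys2, num2, d_out, ::cuda::std::less<int>{}, stream);
  cudaFreeAsync(d_temp_storage, stream);
}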

cub/cub/agent/agent_merge.cuh

Lines changed: 182 additions & 39 deletions
@@ -14,20 +14,34 @@
 #endif // no system header
 
 #include <cub/agent/agent_merge_sort.cuh>
-#include <cub/block/block_load.cuh>
+#include <cub/block/block_load_to_shared.cuh>
 #include <cub/block/block_merge_sort.cuh>
 #include <cub/block/block_store.cuh>
 #include <cub/iterator/cache_modified_input_iterator.cuh>
 #include <cub/util_namespace.cuh>
 #include <cub/util_type.cuh>
 
+#include <thrust/type_traits/is_contiguous_iterator.h>
+#include <thrust/type_traits/is_trivially_relocatable.h>
+#include <thrust/type_traits/unwrap_contiguous_iterator.h>
+
+#include <cuda/__memory/ptr_rebind.h>
 #include <cuda/std/__algorithm/max.h>
 #include <cuda/std/__algorithm/min.h>
+#include <cuda/std/__type_traits/conditional.h>
+#include <cuda/std/__type_traits/is_same.h>
+#include <cuda/std/__type_traits/is_trivially_copyable.h>
+#include <cuda/std/cstddef>
+#include <cuda/std/span>
 
 CUB_NAMESPACE_BEGIN
 namespace detail::merge
 {
-template <int ThreadsPerBlock, int ItemsPerThread, CacheLoadModifier LoadCacheModifier, BlockStoreAlgorithm StoreAlgorithm>
+template <int ThreadsPerBlock,
+          int ItemsPerThread,
+          CacheLoadModifier LoadCacheModifier,
+          BlockStoreAlgorithm StoreAlgorithm,
+          bool UseBlockLoadToShared = false>
 struct agent_policy_t
 {
   // do not change data member names, policy_wrapper_t depends on it
@@ -36,6 +50,7 @@ struct agent_policy_t
   static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
   static constexpr CacheLoadModifier LOAD_MODIFIER = LoadCacheModifier;
   static constexpr BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
+  static constexpr bool use_block_load_to_shared = UseBlockLoadToShared;
 };
 
 // TODO(bgruber): can we unify this one with AgentMerge in agent_merge_sort.cuh?
@@ -50,29 +65,78 @@ template <typename Policy,
           typename CompareOp>
 struct agent_t
 {
-  using policy = Policy;
-
-  // key and value type are taken from the first input sequence (consistent with old Thrust behavior)
-  using key_type = it_value_t<KeysIt1>;
-  using item_type = it_value_t<ItemsIt1>;
-  using block_store_keys = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
-  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
-
+  using policy = Policy;
   static constexpr int items_per_thread = Policy::ITEMS_PER_THREAD;
   static constexpr int threads_per_block = Policy::BLOCK_THREADS;
   static constexpr int items_per_tile = Policy::ITEMS_PER_TILE;
 
-  union temp_storages
+  // key and value type are taken from the first input sequence (consistent with old Thrust behavior)
+  using key_type = it_value_t<KeysIt1>;
+  using item_type = it_value_t<ItemsIt1>;
+
+  using block_load_to_shared = cub::detail::BlockLoadToShared<threads_per_block>;
+  using block_store_keys = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
+  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
+
+  template <typename ValueT, typename Iter1, typename Iter2>
+  static constexpr bool use_block_load_to_shared =
+    Policy::use_block_load_to_shared && (sizeof(ValueT) == alignof(ValueT))
+    && THRUST_NS_QUALIFIER::is_trivially_relocatable_v<ValueT> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter1> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter2>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter1>>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter2>>;
+
+  static constexpr bool keys_use_block_load_to_shared = use_block_load_to_shared<key_type, KeysIt1, KeysIt2>;
+  static constexpr bool items_use_block_load_to_shared = use_block_load_to_shared<item_type, ItemsIt1, ItemsIt2>;
+  static constexpr bool need_block_load_to_shared = keys_use_block_load_to_shared || items_use_block_load_to_shared;
+  static constexpr int load2sh_minimum_align = block_load_to_shared::template SharedBufferAlignBytes<char>();
+
+  struct empty_t
+  {
+    struct TempStorage
+    {};
+    _CCCL_DEVICE _CCCL_FORCEINLINE empty_t(TempStorage) {}
+  };
+
+  using optional_load2sh_t = ::cuda::std::conditional_t<need_block_load_to_shared, block_load_to_shared, empty_t>;
+
+  template <typename ValueT, bool UseBlockLoadToShared>
+  struct alignas(UseBlockLoadToShared ? block_load_to_shared::template SharedBufferAlignBytes<ValueT>()
+                                      : alignof(ValueT)) buffer_t
+  {
+    // Need extra bytes of padding for TMA because this static buffer has to hold the two dynamically sized buffers.
+    char c_array[UseBlockLoadToShared ? (block_load_to_shared::template SharedBufferSizeBytes<ValueT>(items_per_tile + 1)
+                                         + (alignof(ValueT) < load2sh_minimum_align ? 2 * load2sh_minimum_align : 0))
+                                      : sizeof(ValueT) * (items_per_tile + 1)];
+  };
+
+  struct temp_storages_bl2sh
+  {
+    union
+    {
+      typename block_store_keys::TempStorage store_keys;
+      typename block_store_items::TempStorage store_items;
+      buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+      buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+    };
+    typename block_load_to_shared::TempStorage load2sh;
+  };
+
+  union temp_storages_fallback
   {
     typename block_store_keys::TempStorage store_keys;
     typename block_store_items::TempStorage store_items;
 
-    // We could change SerialMerge to avoid reading one item out of bounds and drop the + 1 here. But that would
-    // introduce more branches (about 10% slower on 2^16 problem sizes on RTX 5090 in a first attempt)
-    key_type keys_shared[items_per_tile + 1];
-    item_type items_shared[items_per_tile + 1];
+    buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+    buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+
+    typename empty_t::TempStorage load2sh;
   };
 
+  using temp_storages =
+    ::cuda::std::conditional_t<need_block_load_to_shared, temp_storages_bl2sh, temp_storages_fallback>;
+
   struct TempStorage : Uninitialized<temp_storages>
   {};
 
@@ -121,18 +185,50 @@ struct agent_t
       _CCCL_ASSERT(keys1_count_tile + keys2_count_tile == num_remaining, "");
     }
 
-    key_type keys_loc[items_per_thread];
+    optional_load2sh_t load2sh{storage.load2sh};
+
+    key_type* keys1_shared;
+    key_type* keys2_shared;
+    int keys2_offset;
+    if constexpr (keys_use_block_load_to_shared)
+    {
+      ::cuda::std::span keys1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys1_in + keys1_beg),
+                                  static_cast<::cuda::std::size_t>(keys1_count_tile)};
+      ::cuda::std::span keys2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys2_in + keys2_beg),
+                                  static_cast<::cuda::std::size_t>(keys2_count_tile)};
+      ::cuda::std::span keys_buffers{storage.keys_shared.c_array};
+      auto keys1_buffer =
+        keys_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<key_type>(keys1_count_tile));
+      auto keys2_buffer =
+        keys_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<key_type>(keys2_count_tile));
+      _CCCL_ASSERT(keys1_buffer.end() <= keys2_buffer.begin(),
+                   "Keys buffer needs to be appropriately sized (internal)");
+      auto keys1_sh = load2sh.CopyAsync(keys1_buffer, keys1_src);
+      auto keys2_sh = load2sh.CopyAsync(keys2_buffer, keys2_src);
+      load2sh.Commit();
+      keys1_shared = data(keys1_sh);
+      keys2_shared = data(keys2_sh);
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = static_cast<int>(keys2_shared - keys1_shared);
+      load2sh.Wait();
+    }
+    else
     {
+      key_type keys_loc[items_per_thread];
       auto keys1_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(keys1_in);
       auto keys2_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(keys2_in);
       merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
         keys_loc, keys1_in_cm + keys1_beg, keys2_in_cm + keys2_beg, keys1_count_tile, keys2_count_tile);
-      merge_sort::reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
+      keys1_shared = &::cuda::ptr_rebind<key_type>(storage.keys_shared.c_array)[0];
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = keys1_count_tile;
+      keys2_shared = keys1_shared + keys2_offset;
+      merge_sort::reg_to_shared<threads_per_block>(keys1_shared, keys_loc);
       __syncthreads();
     }
 
-    // now find the merge path for each of thread.
-    // we can use int type here, because the number of items in shared memory is limited
+    // Now find the merge path for each of the threads.
+    // We can use int type here, because the number of items in shared memory is limited.
     int diag0_thread = items_per_thread * static_cast<int>(threadIdx.x);
     if constexpr (IsFullTile)
     {
@@ -144,24 +240,20 @@ struct agent_t
       diag0_thread = (::cuda::std::min) (diag0_thread, num_remaining);
     }
 
-    const int keys1_beg_thread = MergePath(
-      &storage.keys_shared[0],
-      &storage.keys_shared[keys1_count_tile],
-      keys1_count_tile,
-      keys2_count_tile,
-      diag0_thread,
-      compare_op);
+    const int keys1_beg_thread =
+      MergePath(keys1_shared, keys2_shared, keys1_count_tile, keys2_count_tile, diag0_thread, compare_op);
     const int keys2_beg_thread = diag0_thread - keys1_beg_thread;
 
     const int keys1_count_thread = keys1_count_tile - keys1_beg_thread;
     const int keys2_count_thread = keys2_count_tile - keys2_beg_thread;
 
     // perform serial merge
+    key_type keys_loc[items_per_thread];
     int indices[items_per_thread];
-    SerialMerge(
-      &storage.keys_shared[0],
+    cub::SerialMerge(
+      keys1_shared,
       keys1_beg_thread,
-      keys2_beg_thread + keys1_count_tile,
+      keys2_offset + keys2_beg_thread,
       keys1_count_thread,
       keys2_count_thread,
       keys_loc,
@@ -183,22 +275,73 @@ struct agent_t
     static constexpr bool have_items = !::cuda::std::is_same_v<item_type, NullType>;
     if constexpr (have_items)
     {
-      item_type items_loc[items_per_thread];
+      // Both of these are only needed when either keys or items or both use BlockLoadToShared introducing padding (that
+      // can differ between the keys and items)
+      [[maybe_unused]] const auto translate_indices = [&](int items2_offset) -> void {
+        const int diff = items2_offset - keys2_offset;
+        _CCCL_PRAGMA_UNROLL_FULL()
+        for (int i = 0; i < items_per_thread; ++i)
+        {
+          if (indices[i] >= keys2_offset)
+          {
+            indices[i] += diff;
+          }
+        }
+      };
+      // WAR for MSVC erroring ("declared but never referenced") despite [[maybe_unused]]
+      (void) translate_indices;
+
+      item_type* items1_shared;
+      if constexpr (keys_use_block_load_to_shared)
       {
-        auto items1_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items1_in);
-        auto items2_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items2_in);
-        merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
-          items_loc, items1_in_cm + keys1_beg, items2_in_cm + keys2_beg, keys1_count_tile, keys2_count_tile);
-        __syncthreads(); // block_store_keys above uses SMEM, so make sure all threads are done before we write to it
-        merge_sort::reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
+        ::cuda::std::span items1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items1_in + keys1_beg),
+                                     static_cast<::cuda::std::size_t>(keys1_count_tile)};
+        ::cuda::std::span items2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items2_in + keys2_beg),
+                                     static_cast<::cuda::std::size_t>(keys2_count_tile)};
+        ::cuda::std::span items_buffers{storage.items_shared.c_array};
+        auto items1_buffer =
+          items_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<item_type>(keys1_count_tile));
+        auto items2_buffer =
+          items_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<item_type>(keys2_count_tile));
+        _CCCL_ASSERT(items1_buffer.end() <= items2_buffer.begin(),
+                     "Items buffer needs to be appropriately sized (internal)");
+        // block_store_keys above uses shared memory, so make sure all threads are done before we write
         __syncthreads();
+        auto items1_sh = load2sh.CopyAsync(items1_buffer, items1_src);
+        auto items2_sh = load2sh.CopyAsync(items2_buffer, items2_src);
+        load2sh.Commit();
+        items1_shared = data(items1_sh);
+        item_type* items2_shared = data(items2_sh);
+        const int items2_offset = static_cast<int>(items2_shared - items1_shared);
+        translate_indices(items2_offset);
+        load2sh.Wait();
+      }
+      else
+      {
+        item_type items_loc[items_per_thread];
+        {
+          auto items1_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items1_in);
+          auto items2_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items2_in);
+          merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
+            items_loc, items1_in_cm + keys1_beg, items2_in_cm + keys2_beg, keys1_count_tile, keys2_count_tile);
+          __syncthreads(); // block_store_keys above uses SMEM, so make sure all threads are done before we write to it
+          items1_shared = &::cuda::ptr_rebind<item_type>(storage.items_shared.c_array)[0];
+          if constexpr (keys_use_block_load_to_shared)
+          {
+            const int items2_offset = keys1_count_tile;
+            translate_indices(items2_offset);
+          }
+          merge_sort::reg_to_shared<threads_per_block>(items1_shared, items_loc);
+          __syncthreads();
+        }
       }
 
       // gather items from shared mem
+      item_type items_loc[items_per_thread];
       _CCCL_PRAGMA_UNROLL_FULL()
       for (int i = 0; i < items_per_thread; ++i)
       {
-        items_loc[i] = storage.items_shared[indices[i]];
+        items_loc[i] = items1_shared[indices[i]];
       }
       __syncthreads();
 
@@ -222,11 +365,11 @@ struct agent_t
       static_cast<int>((::cuda::std::min) (static_cast<Offset>(items_per_tile), keys1_count + keys2_count - tile_base));
     if (items_in_tile == items_per_tile)
     {
-      consume_tile</* IsFullTile */ true>(tile_idx, tile_base, items_per_tile);
+      consume_tile</* IsFullTile = */ true>(tile_idx, tile_base, items_per_tile);
     }
     else
     {
-      consume_tile</* IsFullTile */ false>(tile_idx, tile_base, items_in_tile);
+      consume_tile</* IsFullTile = */ false>(tile_idx, tile_base, items_in_tile);
     }
   }
 };
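The core of the new path above is the CopyAsync / Commit / Wait sequence: both input ranges of a tile are staged into shared memory with block-wide asynchronous copies, and the block waits once before MergePath and SerialMerge consume the staged data. The following standalone kernel sketches the same overlap pattern with cooperative_groups::memcpy_async; it is an analogy only, not the BlockLoadToShared implementation, and the kernel name, tile size, and int key type are assumptions.

#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>

namespace cg = cooperative_groups;

// Assumes n1 + n2 <= TileSize; both ranges are staged back to back in shared memory.
template <int TileSize>
__global__ void staged_merge_sketch(const int* in1, int n1, const int* in2, int n2)
{
  __shared__ int smem[TileSize];
  auto block = cg::this_thread_block();

  cg::memcpy_async(block, smem, in1, sizeof(int) * n1);      // like load2sh.CopyAsync(buffer1, src1)
  cg::memcpy_async(block, smem + n1, in2, sizeof(int) * n2); // like load2sh.CopyAsync(buffer2, src2)
  cg::wait(block);                                           // like load2sh.Commit() followed by Wait()

  // ... MergePath / SerialMerge would consume smem here ...
}

The real code additionally hands out TMA-aligned sub-buffers of one byte array, which is why the second range may start at a padded offset and the gathered indices have to be translated by the actual pointer difference.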

cub/cub/device/dispatch/dispatch_merge.cuh

Lines changed: 21 additions & 7 deletions
@@ -28,22 +28,36 @@
 CUB_NAMESPACE_BEGIN
 namespace detail::merge
 {
+template <typename PolicyT>
+struct policy_noblockload2smem_t : PolicyT
+{
+  static constexpr bool use_block_load_to_shared = false;
+};
+
 inline constexpr int fallback_BLOCK_THREADS = 64;
 inline constexpr int fallback_ITEMS_PER_THREAD = 1;
 
 template <typename DefaultPolicy, class... Args>
 class choose_merge_agent
 {
-  using default_agent_t = agent_t<DefaultPolicy, Args...>;
-  using fallback_agent_t =
-    agent_t<policy_wrapper_t<DefaultPolicy, fallback_BLOCK_THREADS, fallback_ITEMS_PER_THREAD>, Args...>;
+  using default_load2sh_agent_t = agent_t<DefaultPolicy, Args...>;
+  using default_noload2sh_agent_t = agent_t<policy_noblockload2smem_t<DefaultPolicy>, Args...>;
+
+  using fallback_agent_t = agent_t<
+    policy_wrapper_t<policy_noblockload2smem_t<DefaultPolicy>, fallback_BLOCK_THREADS, fallback_ITEMS_PER_THREAD>,
+    Args...>;
 
-  // Use fallback if merge agent exceeds maximum shared memory, but the fallback agent still fits
-  static constexpr bool use_fallback = sizeof(typename default_agent_t::TempStorage) > max_smem_per_block
-                                    && sizeof(typename fallback_agent_t::TempStorage) <= max_smem_per_block;
+  static constexpr bool use_default_load2sh =
+    sizeof(typename default_load2sh_agent_t::TempStorage) <= max_smem_per_block;
+  // Use fallback if merge agent exceeds maximum shared memory, but the fallback agent still fits, else use
+  // vsmem-compatible version, so noload2sh
+  static constexpr bool use_fallback = sizeof(typename fallback_agent_t::TempStorage) <= max_smem_per_block;
 
 public:
-  using type = ::cuda::std::conditional_t<use_fallback, fallback_agent_t, default_agent_t>;
+  using type =
+    ::cuda::std::conditional_t<use_default_load2sh,
+                               default_load2sh_agent_t,
+                               ::cuda::std::conditional_t<use_fallback, fallback_agent_t, default_noload2sh_agent_t>>;
 };
 
 // Computes the merge path intersections at equally wide intervals. The approach is outlined in the paper:
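Restated compactly, the selection above prefers the BlockLoadToShared agent whenever its temp storage fits into the static shared-memory budget, otherwise takes the small fallback agent if that one fits, and only then the plain no-load2sh agent, which stays compatible with virtual shared memory. A standalone sketch of that three-way pick with made-up byte counts and a 48 KiB budget (all numbers and type names here are assumptions for illustration):

#include <cstddef>
#include <type_traits>

struct load2sh_agent {};
struct noload2sh_agent {};
struct fallback_agent {};

constexpr std::size_t smem_budget = 48 * 1024; // assumed per-block budget

// Same ordering as choose_merge_agent: load2sh if it fits, else fallback if it fits, else no-load2sh.
template <std::size_t Load2ShBytes, std::size_t FallbackBytes>
using chosen_agent_t =
  std::conditional_t<(Load2ShBytes <= smem_budget),
                     load2sh_agent,
                     std::conditional_t<(FallbackBytes <= smem_budget), fallback_agent, noload2sh_agent>>;

static_assert(std::is_same_v<chosen_agent_t<32 * 1024, 8 * 1024>, load2sh_agent>);
static_assert(std::is_same_v<chosen_agent_t<96 * 1024, 8 * 1024>, fallback_agent>);
static_assert(std::is_same_v<chosen_agent_t<96 * 1024, 64 * 1024>, noload2sh_agent>);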

cub/cub/device/dispatch/tuning/tuning_merge.cuh

Lines changed: 11 additions & 1 deletion
@@ -45,7 +45,17 @@ struct policy_hub
       agent_policy_t<512, Nominal4BItemsToItems<tune_type>(15), LOAD_DEFAULT, BLOCK_STORE_WARP_TRANSPOSE>;
   };
 
-  using max_policy = policy600;
+  struct policy800 : ChainedPolicy<800, policy800, policy600>
+  {
+    using merge_policy =
+      agent_policy_t<512,
+                     Nominal4BItemsToItems<tune_type>(15),
+                     LOAD_DEFAULT,
+                     BLOCK_STORE_WARP_TRANSPOSE,
+                     /* UseBlockLoadToShared = */ true>;
+  };
+
+  using max_policy = policy800;
 };
 } // namespace detail::merge
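The new policy800 chains onto policy600, so the UseBlockLoadToShared tuning only takes effect from SM 8.0 upward while older architectures keep the previous policy. A much simplified sketch of that chained selection follows; this is not CUB's ChainedPolicy, and the two policy structs are stand-ins.

struct policy600_sketch { static constexpr bool use_block_load_to_shared = false; };
struct policy800_sketch { static constexpr bool use_block_load_to_shared = true; };

// Pick the highest policy whose architecture bound does not exceed the target SM version.
constexpr bool uses_load2sh(int sm_arch)
{
  return sm_arch >= 800 ? policy800_sketch::use_block_load_to_shared
                        : policy600_sketch::use_block_load_to_shared;
}

static_assert(uses_load2sh(900));  // SM 9.0 (Hopper) takes the new path
static_assert(!uses_load2sh(700)); // SM 7.0 (Volta) keeps the old tuning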
