Commit 5de62bd

pauleonix committed
Use BlockLoadToShared in DeviceMerge
1 parent cf3b2d5 commit 5de62bd

File tree: 3 files changed, +224 -53 lines
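Before the diff itself, a note on the pattern it introduces: the merge agent now batches asynchronous global-to-shared copies through cub::detail::BlockLoadToShared. The kernel below is a minimal standalone sketch of that lifecycle as it appears in agent_merge.cuh (carve a byte buffer for each input range out of one shared char array, CopyAsync each source span, Commit the batch, Wait before reading). The interface is inferred from the calls in this commit, so treat the exact names and signatures as assumptions rather than documented API.

#include <cub/block/block_load_to_shared.cuh>

#include <cuda/std/cstddef>
#include <cuda/std/span>

#include <cstdio>

template <int BlockThreads, int MaxItems>
__global__ void load_two_ranges(const int* in1, int n1, const int* in2, int n2)
{
  using loader_t = cub::detail::BlockLoadToShared<BlockThreads>;

  __shared__ typename loader_t::TempStorage temp;
  // One char array backs both dynamically sized buffers; the extra alignment padding
  // mirrors what buffer_t in agent_merge.cuh reserves for TMA.
  __shared__ alignas(loader_t::template SharedBufferAlignBytes<int>())
    char smem[loader_t::template SharedBufferSizeBytes<int>(MaxItems)
              + 2 * loader_t::template SharedBufferAlignBytes<char>()];

  loader_t load2sh{temp};
  ::cuda::std::span buffers{smem};
  auto buf1 = buffers.first(loader_t::template SharedBufferSizeBytes<int>(n1));
  auto buf2 = buffers.last(loader_t::template SharedBufferSizeBytes<int>(n2));

  auto sh1 = load2sh.CopyAsync(buf1, ::cuda::std::span{in1, static_cast<::cuda::std::size_t>(n1)});
  auto sh2 = load2sh.CopyAsync(buf2, ::cuda::std::span{in2, static_cast<::cuda::std::size_t>(n2)});
  load2sh.Commit(); // start the batched copies
  // ... independent work can overlap with the in-flight copies here ...
  load2sh.Wait(); // afterwards, the returned typed spans are readable block-wide

  if (threadIdx.x == 0 && n1 > 0 && n2 > 0)
  {
    printf("%d %d\n", data(sh1)[0], data(sh2)[0]);
  }
}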

cub/cub/agent/agent_merge.cuh

Lines changed: 188 additions & 42 deletions
@@ -14,15 +14,25 @@
 #endif // no system header

 #include <cub/agent/agent_merge_sort.cuh>
-#include <cub/block/block_load.cuh>
+#include <cub/block/block_load_to_shared.cuh>
 #include <cub/block/block_merge_sort.cuh>
 #include <cub/block/block_store.cuh>
 #include <cub/iterator/cache_modified_input_iterator.cuh>
 #include <cub/util_namespace.cuh>
 #include <cub/util_type.cuh>

+#include <thrust/type_traits/is_contiguous_iterator.h>
+#include <thrust/type_traits/is_trivially_relocatable.h>
+#include <thrust/type_traits/unwrap_contiguous_iterator.h>
+
+#include <cuda/__memory/ptr_rebind.h>
 #include <cuda/std/__algorithm/max.h>
 #include <cuda/std/__algorithm/min.h>
+#include <cuda/std/__type_traits/conditional.h>
+#include <cuda/std/__type_traits/is_same.h>
+#include <cuda/std/__type_traits/is_trivially_copyable.h>
+#include <cuda/std/cstddef>
+#include <cuda/std/span>

 CUB_NAMESPACE_BEGIN
 namespace detail
@@ -33,7 +43,8 @@ template <int ThreadsPerBlock,
           int ItemsPerThread,
           BlockLoadAlgorithm LoadAlgorithm,
           CacheLoadModifier LoadCacheModifier,
-          BlockStoreAlgorithm StoreAlgorithm>
+          BlockStoreAlgorithm StoreAlgorithm,
+          bool UseBlockLoadToShared = false>
 struct agent_policy_t
 {
   // do not change data member names, policy_wrapper_t depends on it
@@ -43,6 +54,7 @@ struct agent_policy_t
   static constexpr BlockLoadAlgorithm LOAD_ALGORITHM = LoadAlgorithm;
   static constexpr CacheLoadModifier LOAD_MODIFIER = LoadCacheModifier;
   static constexpr BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
+  static constexpr bool use_block_load_to_shared = UseBlockLoadToShared;
 };

 // TODO(bgruber): can we unify this one with AgentMerge in agent_merge_sort.cuh?
@@ -54,48 +66,102 @@ template <typename Policy,
           typename KeysOutputIt,
           typename ItemsOutputIt,
           typename Offset,
-          typename CompareOp>
+          typename CompareOp,
+          bool AllowBlockLoadToShared>
 struct agent_t
 {
-  using policy = Policy;
+  using policy                           = Policy;
+  static constexpr int items_per_thread  = Policy::ITEMS_PER_THREAD;
+  static constexpr int threads_per_block = Policy::BLOCK_THREADS;
+  static constexpr Offset items_per_tile = Policy::ITEMS_PER_TILE;

   // key and value type are taken from the first input sequence (consistent with old Thrust behavior)
   using key_type  = it_value_t<KeysIt1>;
   using item_type = it_value_t<ItemsIt1>;

-  using keys_load_it1  = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt1>;
-  using keys_load_it2  = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt2>;
-  using items_load_it1 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt1>;
-  using items_load_it2 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt2>;
+  using block_load_to_shared = cub::detail::BlockLoadToShared<threads_per_block>;
+  using block_store_keys     = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
+  using block_store_items    = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
+
+  template <typename ValueT, typename Iter1, typename Iter2>
+  static constexpr bool use_block_load_to_shared =
+    Policy::use_block_load_to_shared && (sizeof(ValueT) == alignof(ValueT)) && AllowBlockLoadToShared
+    && THRUST_NS_QUALIFIER::is_trivially_relocatable_v<ValueT> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter1> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter2>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter1>>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter2>>;
+
+  static constexpr bool keys_use_block_load_to_shared  = use_block_load_to_shared<key_type, KeysIt1, KeysIt2>;
+  static constexpr bool items_use_block_load_to_shared = use_block_load_to_shared<item_type, ItemsIt1, ItemsIt2>;
+  static constexpr bool need_block_load_to_shared = keys_use_block_load_to_shared || items_use_block_load_to_shared;
+  static constexpr int load2sh_minimum_align      = block_load_to_shared::template SharedBufferAlignBytes<char>();
+
+  struct empty_t
+  {
+    struct TempStorage
+    {};
+    _CCCL_DEVICE _CCCL_FORCEINLINE empty_t(TempStorage) {}
+  };
+
+  using optional_load2sh_t = ::cuda::std::conditional_t<need_block_load_to_shared, block_load_to_shared, empty_t>;
+
+  using keys_load_it1 =
+    ::cuda::std::conditional_t<keys_use_block_load_to_shared,
+                               KeysIt1,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt1>>;
+  using keys_load_it2 =
+    ::cuda::std::conditional_t<keys_use_block_load_to_shared,
+                               KeysIt2,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt2>>;
+  using items_load_it1 =
+    ::cuda::std::conditional_t<items_use_block_load_to_shared,
+                               ItemsIt1,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt1>>;
+  using items_load_it2 =
+    ::cuda::std::conditional_t<items_use_block_load_to_shared,
+                               ItemsIt2,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt2>>;

-  using block_load_keys1  = typename BlockLoadType<Policy, keys_load_it1>::type;
-  using block_load_keys2  = typename BlockLoadType<Policy, keys_load_it2>::type;
-  using block_load_items1 = typename BlockLoadType<Policy, items_load_it1>::type;
-  using block_load_items2 = typename BlockLoadType<Policy, items_load_it2>::type;
+  template <typename ValueT, bool UseBlockLoadToShared>
+  struct alignas(UseBlockLoadToShared ? block_load_to_shared::template SharedBufferAlignBytes<ValueT>()
+                                      : alignof(ValueT)) buffer_t
+  {
+    // Need extra bytes of padding for TMA because this static buffer has to hold the two dynamically sized buffers.
+    char c_array[UseBlockLoadToShared
+                   ? (block_load_to_shared::template SharedBufferSizeBytes<ValueT>(items_per_tile + 1)
+                      + (alignof(ValueT) < load2sh_minimum_align ? 2 * load2sh_minimum_align : 0))
+                   : sizeof(ValueT) * (items_per_tile + 1)];
+  };

-  using block_store_keys  = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
-  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
+  struct temp_storages_bl2sh
+  {
+    union
+    {
+      typename block_store_keys::TempStorage store_keys;
+      typename block_store_items::TempStorage store_items;
+      buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+      buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+    };
+    typename block_load_to_shared::TempStorage load2sh;
+  };

-  union temp_storages
+  union temp_storages_fallback
   {
-    typename block_load_keys1::TempStorage load_keys1;
-    typename block_load_keys2::TempStorage load_keys2;
-    typename block_load_items1::TempStorage load_items1;
-    typename block_load_items2::TempStorage load_items2;
     typename block_store_keys::TempStorage store_keys;
     typename block_store_items::TempStorage store_items;

-    key_type keys_shared[Policy::ITEMS_PER_TILE + 1];
-    item_type items_shared[Policy::ITEMS_PER_TILE + 1];
+    buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+    buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+
+    typename empty_t::TempStorage load2sh;
   };

+  using temp_storages =
+    ::cuda::std::conditional_t<need_block_load_to_shared, temp_storages_bl2sh, temp_storages_fallback>;
+
   struct TempStorage : Uninitialized<temp_storages>
   {};

-  static constexpr int items_per_thread  = Policy::ITEMS_PER_THREAD;
-  static constexpr int threads_per_block = Policy::BLOCK_THREADS;
-  static constexpr Offset items_per_tile = Policy::ITEMS_PER_TILE;
-
   // Per thread data
   temp_storages& storage;
   keys_load_it1 keys1_in;
@@ -128,18 +194,49 @@ struct agent_t
     const int num_keys1 = static_cast<int>(keys1_end - keys1_beg);
     const int num_keys2 = static_cast<int>(keys2_end - keys2_beg);

-    key_type keys_loc[items_per_thread];
-    merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
-      keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
-    merge_sort::reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
-    __syncthreads();
+    optional_load2sh_t load2sh{storage.load2sh};
+
+    key_type* keys1_shared;
+    key_type* keys2_shared;
+    int keys2_offset;
+    if constexpr (keys_use_block_load_to_shared)
+    {
+      ::cuda::std::span keys1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys1_in + keys1_beg),
+                                  static_cast<::cuda::std::size_t>(num_keys1)};
+      ::cuda::std::span keys2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys2_in + keys2_beg),
+                                  static_cast<::cuda::std::size_t>(num_keys2)};
+      ::cuda::std::span keys_buffers{storage.keys_shared.c_array};
+      auto keys1_buffer = keys_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<key_type>(num_keys1));
+      auto keys2_buffer = keys_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<key_type>(num_keys2));
+      _CCCL_ASSERT(keys1_buffer.end() <= keys2_buffer.begin(),
+                   "Keys buffer needs to be appropriately sized (internal)");
+      auto keys1_sh = load2sh.CopyAsync(keys1_buffer, keys1_src);
+      auto keys2_sh = load2sh.CopyAsync(keys2_buffer, keys2_src);
+      load2sh.Commit();
+      keys1_shared = data(keys1_sh);
+      keys2_shared = data(keys2_sh);
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = static_cast<int>(keys2_shared - keys1_shared);
+      load2sh.Wait();
+    }
+    else
+    {
+      key_type keys_loc[items_per_thread];
+      merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
+        keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
+      keys1_shared = &::cuda::ptr_rebind<key_type>(storage.keys_shared.c_array)[0];
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = num_keys1;
+      keys2_shared = keys1_shared + keys2_offset;
+      merge_sort::reg_to_shared<threads_per_block>(keys1_shared, keys_loc);
+      __syncthreads();
+    }

     // use binary search in shared memory to find merge path for each of thread.
     // we can use int type here, because the number of items in shared memory is limited
     const int diag0_loc = (::cuda::std::min) (num_keys1 + num_keys2, static_cast<int>(items_per_thread * threadIdx.x));

-    const int keys1_beg_loc =
-      MergePath(&storage.keys_shared[0], &storage.keys_shared[num_keys1], num_keys1, num_keys2, diag0_loc, compare_op);
+    const int keys1_beg_loc = MergePath(keys1_shared, keys2_shared, num_keys1, num_keys2, diag0_loc, compare_op);
     const int keys1_end_loc = num_keys1;
     const int keys2_beg_loc = diag0_loc - keys1_beg_loc;
     const int keys2_end_loc = num_keys2;
@@ -148,11 +245,12 @@
     const int num_keys2_loc = keys2_end_loc - keys2_beg_loc;

     // perform serial merge
+    key_type keys_loc[items_per_thread];
     int indices[items_per_thread];
     cub::SerialMerge(
-      &storage.keys_shared[0],
+      keys1_shared,
       keys1_beg_loc,
-      keys2_beg_loc + num_keys1,
+      keys2_offset + keys2_beg_loc,
       num_keys1_loc,
       num_keys2_loc,
       keys_loc,
@@ -174,19 +272,67 @@ struct agent_t
     static constexpr bool have_items = !::cuda::std::is_same_v<item_type, NullType>;
     if constexpr (have_items)
     {
-      item_type items_loc[items_per_thread];
-      merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
-        items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
-      __syncthreads(); // block_store_keys above uses shared memory, so make sure all threads are done before we write
-                       // to it
-      merge_sort::reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
-      __syncthreads();
+      [[maybe_unused]] const auto translate_indices = [&](int items2_offset) -> void {
+        const int diff = items2_offset - keys2_offset;
+        _CCCL_PRAGMA_UNROLL_FULL()
+        for (int i = 0; i < items_per_thread; ++i)
+        {
+          if (indices[i] >= keys2_offset)
+          {
+            indices[i] += diff;
+          }
+        }
+      };
+
+      item_type* items1_shared;
+      int items2_offset;
+      if constexpr (items_use_block_load_to_shared)
+      {
+        ::cuda::std::span items1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items1_in + keys1_beg),
+                                     static_cast<::cuda::std::size_t>(num_keys1)};
+        ::cuda::std::span items2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items2_in + keys2_beg),
+                                     static_cast<::cuda::std::size_t>(num_keys2)};
+        ::cuda::std::span items_buffers{storage.items_shared.c_array};
+        auto items1_buffer =
+          items_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<item_type>(num_keys1));
+        auto items2_buffer =
+          items_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<item_type>(num_keys2));
+        _CCCL_ASSERT(items1_buffer.end() <= items2_buffer.begin(),
+                     "Items buffer needs to be appropriately sized (internal)");
+        // block_store_keys above uses shared memory, so make sure all threads are done before we write to it
+        __syncthreads();
+        auto items1_sh = load2sh.CopyAsync(items1_buffer, items1_src);
+        auto items2_sh = load2sh.CopyAsync(items2_buffer, items2_src);
+        load2sh.Commit();
+        items1_shared            = data(items1_sh);
+        item_type* items2_shared = data(items2_sh);
+        items2_offset = static_cast<int>(items2_shared - items1_shared);
+        translate_indices(items2_offset);
+        load2sh.Wait();
+      }
+      else
+      {
+        item_type items_loc[items_per_thread];
+        merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
+          items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
+        __syncthreads(); // block_store_keys above uses shared memory, so make sure all threads are done before we
+                         // write to it
+        items1_shared = &::cuda::ptr_rebind<item_type>(storage.items_shared.c_array)[0];
+        items2_offset = num_keys1;
+        if constexpr (keys_use_block_load_to_shared)
+        {
+          translate_indices(items2_offset);
+        }
+        merge_sort::reg_to_shared<threads_per_block>(items1_shared, items_loc);
+        __syncthreads();
+      }

       // gather items from shared mem
+      item_type items_loc[items_per_thread];
       _CCCL_PRAGMA_UNROLL_FULL()
       for (int i = 0; i < items_per_thread; ++i)
       {
-        items_loc[i] = storage.items_shared[indices[i]];
+        items_loc[i] = items1_shared[indices[i]];
       }
       __syncthreads();
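Aside: the translate_indices step above is plain offset arithmetic. SerialMerge produced indices into one concatenated keys buffer in which range 2 starts at keys2_offset; when the items copy places range 2 at a different items2_offset (for instance due to TMA alignment), every range-2 index shifts by the difference. A standalone rendering with made-up offsets and indices:

#include <cstdio>

int main()
{
  const int keys2_offset  = 5; // where keys range 2 started in the keys buffer
  const int items2_offset = 8; // where items range 2 starts in the items buffer
  int indices[4] = {1, 4, 5, 7}; // first two fall into range 1, last two into range 2

  const int diff = items2_offset - keys2_offset;
  for (int& idx : indices)
  {
    if (idx >= keys2_offset) // only range-2 indices move
    {
      idx += diff;
    }
  }
  for (int idx : indices)
  {
    std::printf("%d ", idx); // prints: 1 4 8 10
  }
  std::printf("\n");
  return 0;
}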

cub/cub/device/dispatch/dispatch_merge.cuh

Lines changed: 24 additions & 10 deletions
@@ -34,16 +34,30 @@ inline constexpr int fallback_ITEMS_PER_THREAD = 1;
 template <typename DefaultPolicy, class... Args>
 class choose_merge_agent
 {
-  using default_agent_t = agent_t<DefaultPolicy, Args...>;
+  using default_load2sh_agent_t   = agent_t<DefaultPolicy, Args..., /* AllowBlockLoadToShared = */ true>;
+  using default_noload2sh_agent_t = agent_t<DefaultPolicy, Args..., /* AllowBlockLoadToShared = */ false>;
+
+  // Disallow BlockLoadToShared in the fallback: we want to keep the fallback's TempStorage minimal, and
+  // BlockLoadToShared needs additional padding. The fallback's restricted tile size might also combine badly with the
+  // expensive mbarrier setup even when no padding is needed.
   using fallback_agent_t =
-    agent_t<policy_wrapper_t<DefaultPolicy, fallback_BLOCK_THREADS, fallback_ITEMS_PER_THREAD>, Args...>;
+    agent_t<policy_wrapper_t<DefaultPolicy, fallback_BLOCK_THREADS, fallback_ITEMS_PER_THREAD>, Args..., false>;

-  // Use fallback if merge agent exceeds maximum shared memory, but the fallback agent still fits
-  static constexpr bool use_fallback = sizeof(typename default_agent_t::TempStorage) > max_smem_per_block
-                                    && sizeof(typename fallback_agent_t::TempStorage) <= max_smem_per_block;
+  static constexpr bool use_default_load2sh =
+    sizeof(typename default_load2sh_agent_t::TempStorage) <= max_smem_per_block;
+  static constexpr bool use_default_noload2sh =
+    sizeof(typename default_noload2sh_agent_t::TempStorage) <= max_smem_per_block;
+  // Use the fallback if the default agent exceeds the maximum shared memory but the fallback agent still fits;
+  // otherwise use the vsmem-compatible variant, i.e. the noload2sh agent.
+  static constexpr bool use_fallback = sizeof(typename fallback_agent_t::TempStorage) <= max_smem_per_block;

 public:
-  using type = ::cuda::std::conditional_t<use_fallback, fallback_agent_t, default_agent_t>;
+  using type = ::cuda::std::conditional_t<
+    use_default_load2sh,
+    default_load2sh_agent_t,
+    ::cuda::std::conditional_t<use_default_noload2sh,
+                               default_noload2sh_agent_t,
+                               ::cuda::std::conditional_t<use_fallback, fallback_agent_t, default_noload2sh_agent_t>>>;
 };

 // Computes the merge path intersections at equally wide intervals. The approach is outlined in the paper:
@@ -143,11 +157,11 @@ __launch_bounds__(
   auto& temp_storage = vsmem_helper_t::get_temp_storage(shared_temp_storage, global_temp_storage);
   MergeAgent{
     temp_storage.Alias(),
-    try_make_cache_modified_iterator<MergePolicy::LOAD_MODIFIER>(keys1),
-    try_make_cache_modified_iterator<MergePolicy::LOAD_MODIFIER>(items1),
+    keys1,
+    items1,
     num_keys1,
-    try_make_cache_modified_iterator<MergePolicy::LOAD_MODIFIER>(keys2),
-    try_make_cache_modified_iterator<MergePolicy::LOAD_MODIFIER>(items2),
+    keys2,
+    items2,
     num_keys2,
     keys_result,
     items_result,
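A reduced compile-time model of the choose_merge_agent selection above, with made-up TempStorage sizes: prefer the load2sh agent when it fits into shared memory, then the plain default agent, then the fallback; when even the fallback does not fit, take the noload2sh default and let virtual shared memory handle it.

#include <cuda/std/type_traits>

struct load2sh_agent   { char storage[64 * 1024]; }; // hypothetical TempStorage sizes
struct noload2sh_agent { char storage[40 * 1024]; };
struct fallback_agent  { char storage[8 * 1024]; };

inline constexpr int max_smem_per_block = 48 * 1024;

inline constexpr bool use_load2sh   = sizeof(load2sh_agent) <= max_smem_per_block;
inline constexpr bool use_noload2sh = sizeof(noload2sh_agent) <= max_smem_per_block;
inline constexpr bool use_fallback  = sizeof(fallback_agent) <= max_smem_per_block;

using chosen_t = ::cuda::std::conditional_t<
  use_load2sh,
  load2sh_agent,
  ::cuda::std::conditional_t<use_noload2sh,
                             noload2sh_agent,
                             ::cuda::std::conditional_t<use_fallback, fallback_agent, noload2sh_agent>>>;

// The 64 KiB load2sh agent does not fit into 48 KiB, the 40 KiB variant does:
static_assert(::cuda::std::is_same_v<chosen_t, noload2sh_agent>);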

cub/cub/device/dispatch/tuning/tuning_merge.cuh

Lines changed: 12 additions & 1 deletion
@@ -83,7 +83,18 @@ struct policy_hub
                    BLOCK_STORE_WARP_TRANSPOSE>;
   };

-  using max_policy = policy600;
+  struct policy800 : ChainedPolicy<800, policy800, policy600>
+  {
+    using merge_policy =
+      agent_policy_t<512,
+                     Nominal4BItemsToItems<tune_type>(15),
+                     BLOCK_LOAD_WARP_TRANSPOSE,
+                     LOAD_DEFAULT,
+                     BLOCK_STORE_WARP_TRANSPOSE,
+                     /* UseBlockLoadToShared = */ true>;
+  };
+
+  using max_policy = policy800;
 };
 } // namespace merge
 } // namespace detail
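For scale: policy800 requests 512 threads and a nominal 15 items per thread for 4-byte types. Nominal4BItemsToItems then shrinks that count for larger element types; the helper below re-states the usual CUB scaling rule (scale by 4/sizeof(T), clamp between 1 and the nominal count) as an assumption, not a quote of the library's implementation.

constexpr int nominal_4b_items_to_items(int nominal, int size_of_t)
{
  const int scaled = nominal * 4 / size_of_t;
  return scaled < 1 ? 1 : (scaled > nominal ? nominal : scaled);
}

static_assert(nominal_4b_items_to_items(15, 4) == 15); // int/float: 512 * 15 items per tile
static_assert(nominal_4b_items_to_items(15, 8) == 7);  // double: 512 * 7 items per tile
static_assert(nominal_4b_items_to_items(15, 1) == 15); // char: clamped at the nominal count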
