
Commit 9363b0e
Author: pauleonix
Use BlockLoadToShared in DeviceMerge
1 parent c4d38b7, commit 9363b0e

File tree: 3 files changed, +203 -46 lines changed


cub/cub/agent/agent_merge.cuh
Lines changed: 187 additions & 41 deletions
@@ -14,15 +14,24 @@
 #endif // no system header

 #include <cub/agent/agent_merge_sort.cuh>
-#include <cub/block/block_load.cuh>
+#include <cub/block/block_load_to_shared.cuh>
 #include <cub/block/block_merge_sort.cuh>
 #include <cub/block/block_store.cuh>
 #include <cub/iterator/cache_modified_input_iterator.cuh>
 #include <cub/util_namespace.cuh>
 #include <cub/util_type.cuh>

+#include <thrust/type_traits/is_contiguous_iterator.h>
+#include <thrust/type_traits/unwrap_contiguous_iterator.h>
+
+#include <cuda/__memory/ptr_rebind.h>
 #include <cuda/std/__algorithm/max.h>
 #include <cuda/std/__algorithm/min.h>
+#include <cuda/std/__type_traits/conditional.h>
+#include <cuda/std/__type_traits/is_same.h>
+#include <cuda/std/__type_traits/is_trivially_copyable.h>
+#include <cuda/std/cstddef>
+#include <cuda/std/span>

 CUB_NAMESPACE_BEGIN
 namespace detail
@@ -33,7 +42,8 @@ template <int ThreadsPerBlock,
           int ItemsPerThread,
           BlockLoadAlgorithm LoadAlgorithm,
           CacheLoadModifier LoadCacheModifier,
-          BlockStoreAlgorithm StoreAlgorithm>
+          BlockStoreAlgorithm StoreAlgorithm,
+          bool UseBlockLoadToShared = false>
 struct agent_policy_t
 {
   // do not change data member names, policy_wrapper_t depends on it
@@ -43,6 +53,7 @@ struct agent_policy_t
   static constexpr BlockLoadAlgorithm LOAD_ALGORITHM = LoadAlgorithm;
   static constexpr CacheLoadModifier LOAD_MODIFIER = LoadCacheModifier;
   static constexpr BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
+  static constexpr bool use_block_load_to_shared = UseBlockLoadToShared;
 };

 // TODO(bgruber): can we unify this one with AgentMerge in agent_merge_sort.cuh?
@@ -57,45 +68,98 @@ template <typename Policy,
           typename CompareOp>
 struct agent_t
 {
-  using policy = Policy;
+  using policy = Policy;
+  static constexpr int items_per_thread = Policy::ITEMS_PER_THREAD;
+  static constexpr int threads_per_block = Policy::BLOCK_THREADS;
+  static constexpr Offset items_per_tile = Policy::ITEMS_PER_TILE;

   // key and value type are taken from the first input sequence (consistent with old Thrust behavior)
   using key_type = it_value_t<KeysIt1>;
   using item_type = it_value_t<ItemsIt1>;

-  using keys_load_it1 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt1>;
-  using keys_load_it2 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt2>;
-  using items_load_it1 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt1>;
-  using items_load_it2 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt2>;
+  using block_load_to_shared = cub::detail::BlockLoadToShared<threads_per_block>;
+  using block_store_keys = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
+  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
+
+  template <typename ValueT, typename Iter1, typename Iter2>
+  static constexpr bool use_block_load_to_shared =
+    Policy::use_block_load_to_shared && (sizeof(ValueT) == alignof(ValueT))
+    && ::cuda::std::is_trivially_copyable_v<ValueT> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter1> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter2>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter1>>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter2>>;
+
+  static constexpr bool keys_use_block_load_to_shared = use_block_load_to_shared<key_type, KeysIt1, KeysIt2>;
+  static constexpr bool items_use_block_load_to_shared = use_block_load_to_shared<item_type, ItemsIt1, ItemsIt2>;
+  static constexpr bool need_block_load_to_shared = keys_use_block_load_to_shared || items_use_block_load_to_shared;
+  static constexpr int load2sh_minimum_align = block_load_to_shared::template SharedBufferAlignBytes<char>();
+
+  struct empty_t
+  {
+    struct TempStorage
+    {};
+    _CCCL_DEVICE _CCCL_FORCEINLINE empty_t(TempStorage) {}
+  };
+
+  using optional_load2sh_t = ::cuda::std::conditional_t<need_block_load_to_shared, block_load_to_shared, empty_t>;
+
+  using keys_load_it1 =
+    ::cuda::std::conditional_t<keys_use_block_load_to_shared,
+                               KeysIt1,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt1>>;
+  using keys_load_it2 =
+    ::cuda::std::conditional_t<keys_use_block_load_to_shared,
+                               KeysIt2,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt2>>;
+  using items_load_it1 =
+    ::cuda::std::conditional_t<items_use_block_load_to_shared,
+                               ItemsIt1,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt1>>;
+  using items_load_it2 =
+    ::cuda::std::conditional_t<items_use_block_load_to_shared,
+                               ItemsIt2,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt2>>;

-  using block_load_keys1 = typename BlockLoadType<Policy, keys_load_it1>::type;
-  using block_load_keys2 = typename BlockLoadType<Policy, keys_load_it2>::type;
-  using block_load_items1 = typename BlockLoadType<Policy, items_load_it1>::type;
-  using block_load_items2 = typename BlockLoadType<Policy, items_load_it2>::type;
+  template <typename ValueT, bool UseBlockLoadToShared>
+  struct alignas(UseBlockLoadToShared ? block_load_to_shared::template SharedBufferAlignBytes<ValueT>()
+                                      : alignof(ValueT)) buffer_t
+  {
+    // Need extra bytes of padding for TMA because this static buffer has to hold the two dynamically sized buffers.
+    char c_array[UseBlockLoadToShared ? (block_load_to_shared::template SharedBufferSizeBytes<ValueT>(items_per_tile + 1)
+                                         + 2 * load2sh_minimum_align)
+                                      : sizeof(ValueT) * (items_per_tile + 1)];
+  };

-  using block_store_keys = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
-  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
+  struct temp_storages_bl2sh
+  {
+    union
+    {
+      typename block_store_keys::TempStorage store_keys;
+      typename block_store_items::TempStorage store_items;
+      buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+      buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+    };
+    typename block_load_to_shared::TempStorage load2sh;
+  };

-  union temp_storages
+  union temp_storages_fallback
   {
-    typename block_load_keys1::TempStorage load_keys1;
-    typename block_load_keys2::TempStorage load_keys2;
-    typename block_load_items1::TempStorage load_items1;
-    typename block_load_items2::TempStorage load_items2;
     typename block_store_keys::TempStorage store_keys;
     typename block_store_items::TempStorage store_items;

-    key_type keys_shared[Policy::ITEMS_PER_TILE + 1];
-    item_type items_shared[Policy::ITEMS_PER_TILE + 1];
+    buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+    buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+
+    typename empty_t::TempStorage load2sh;
   };

+  using temp_storages =
+    ::cuda::std::conditional_t<need_block_load_to_shared, temp_storages_bl2sh, temp_storages_fallback>;
+
   struct TempStorage : Uninitialized<temp_storages>
   {};

-  static constexpr int items_per_thread = Policy::ITEMS_PER_THREAD;
-  static constexpr int threads_per_block = Policy::BLOCK_THREADS;
-  static constexpr Offset items_per_tile = Policy::ITEMS_PER_TILE;
-
   // Per thread data
   temp_storages& storage;
   keys_load_it1 keys1_in;
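Note on the hunk above: the typed keys_shared/items_shared arrays become buffer_t, an over-aligned raw byte buffer that can back either path. The bulk-copy path needs the alignment and extra padding reported by BlockLoadToShared (the two sub-buffers are placed dynamically inside it), while the fallback path simply reinterprets the bytes via cuda::ptr_rebind. A minimal standalone sketch of that pattern, with a placeholder value type, block size, and tile size, and assuming the SharedBufferAlignBytes/SharedBufferSizeBytes helpers behave exactly as this diff uses them:

// Sketch only: value_t, the block size, and tile_items are placeholders; the
// BlockLoadToShared helpers are assumed to work as used in the diff above.
#include <cub/block/block_load_to_shared.cuh>
#include <cuda/__memory/ptr_rebind.h>

using loader_t = cub::detail::BlockLoadToShared<256 /* threads per block */>;
using value_t  = float;
constexpr int tile_items = 256 * 16;

// Over-aligned, type-erased storage: sized for the bulk-copy path (two dynamically
// placed sub-buffers plus alignment padding) but reusable as a plain value_t array.
struct alignas(loader_t::SharedBufferAlignBytes<value_t>()) tile_buffer_t
{
  char bytes[loader_t::SharedBufferSizeBytes<value_t>(tile_items + 1)
             + 2 * loader_t::SharedBufferAlignBytes<char>()];
};

__device__ void fallback_view(tile_buffer_t& buf)
{
  // Fallback path: reinterpret the bytes as value_t, as the diff does with
  // ::cuda::ptr_rebind on storage.keys_shared.c_array.
  value_t* typed = &::cuda::ptr_rebind<value_t>(buf.bytes)[0];
  typed[threadIdx.x] = value_t{};
}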
@@ -128,18 +192,49 @@ struct agent_t
     const int num_keys1 = static_cast<int>(keys1_end - keys1_beg);
     const int num_keys2 = static_cast<int>(keys2_end - keys2_beg);

-    key_type keys_loc[items_per_thread];
-    merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
-      keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
-    merge_sort::reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
-    __syncthreads();
+    optional_load2sh_t load2sh{storage.load2sh};
+
+    key_type* keys1_shared;
+    key_type* keys2_shared;
+    int keys2_offset;
+    if constexpr (keys_use_block_load_to_shared)
+    {
+      ::cuda::std::span keys1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys1_in + keys1_beg),
+                                  static_cast<::cuda::std::size_t>(num_keys1)};
+      ::cuda::std::span keys2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys2_in + keys2_beg),
+                                  static_cast<::cuda::std::size_t>(num_keys2)};
+      ::cuda::std::span keys_buffers{storage.keys_shared.c_array};
+      auto keys1_buffer = keys_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<key_type>(num_keys1));
+      auto keys2_buffer = keys_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<key_type>(num_keys2));
+      _CCCL_ASSERT(keys1_buffer.end() <= keys2_buffer.begin(),
+                   "Keys buffer needs to be appropriately sized (internal)");
+      auto keys1_sh = load2sh.template CopyAsync<key_type>(keys1_buffer, keys1_src);
+      auto keys2_sh = load2sh.template CopyAsync<key_type>(keys2_buffer, keys2_src);
+      load2sh.Commit();
+      keys1_shared = data(keys1_sh);
+      keys2_shared = data(keys2_sh);
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = static_cast<int>(keys2_shared - keys1_shared);
+      load2sh.Wait();
+    }
+    else
+    {
+      key_type keys_loc[items_per_thread];
+      merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
+        keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
+      keys1_shared = &::cuda::ptr_rebind<key_type>(storage.keys_shared.c_array)[0];
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = num_keys1;
+      keys2_shared = keys1_shared + keys2_offset;
+      merge_sort::reg_to_shared<threads_per_block>(keys1_shared, keys_loc);
+      __syncthreads();
+    }

     // use binary search in shared memory to find merge path for each of thread.
     // we can use int type here, because the number of items in shared memory is limited
     const int diag0_loc = (::cuda::std::min) (num_keys1 + num_keys2, static_cast<int>(items_per_thread * threadIdx.x));

-    const int keys1_beg_loc =
-      MergePath(&storage.keys_shared[0], &storage.keys_shared[num_keys1], num_keys1, num_keys2, diag0_loc, compare_op);
+    const int keys1_beg_loc = MergePath(keys1_shared, keys2_shared, num_keys1, num_keys2, diag0_loc, compare_op);
     const int keys1_end_loc = num_keys1;
     const int keys2_beg_loc = diag0_loc - keys1_beg_loc;
     const int keys2_end_loc = num_keys2;
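The load-to-shared branch above follows one pattern: carve two byte sub-buffers out of the shared tile buffer (range 1 at the front, range 2 at the back), issue one CopyAsync per input range, Commit to launch the bulk copies, and Wait before reading the data. A condensed sketch of that sequence outside the agent, with placeholder names and sizes, assuming the BlockLoadToShared interface exactly as it appears in this diff:

// Sketch only: shows the CopyAsync/Commit/Wait sequence from the hunk above in
// isolation; n1/n2 and the static buffer size are placeholders.
#include <cub/block/block_load_to_shared.cuh>
#include <cuda/std/cstddef>
#include <cuda/std/span>

using loader_t = cub::detail::BlockLoadToShared<256>;

__device__ int sum_of_heads(const int* in1, int n1, const int* in2, int n2)
{
  __shared__ typename loader_t::TempStorage tmp;
  __shared__ alignas(loader_t::SharedBufferAlignBytes<int>()) char raw[8 * 1024];

  loader_t loader{tmp};
  ::cuda::std::span<char> bytes{raw};
  // Range 1 at the front, range 2 at the back, mirroring first()/last() above.
  auto buf1 = bytes.first(loader_t::SharedBufferSizeBytes<int>(n1));
  auto buf2 = bytes.last(loader_t::SharedBufferSizeBytes<int>(n2));

  auto sh1 = loader.CopyAsync<int>(buf1, ::cuda::std::span<const int>{in1, static_cast<::cuda::std::size_t>(n1)});
  auto sh2 = loader.CopyAsync<int>(buf2, ::cuda::std::span<const int>{in2, static_cast<::cuda::std::size_t>(n2)});
  loader.Commit(); // start the asynchronous bulk copies
  // ...independent work could overlap here...
  loader.Wait();   // the shared-memory views are valid for the whole block after this
  return sh1[0] + sh2[0]; // assumes n1 > 0 and n2 > 0
}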
@@ -148,11 +243,12 @@ struct agent_t
     const int num_keys2_loc = keys2_end_loc - keys2_beg_loc;

     // perform serial merge
+    key_type keys_loc[items_per_thread];
     int indices[items_per_thread];
     cub::SerialMerge(
-      &storage.keys_shared[0],
+      keys1_shared,
       keys1_beg_loc,
-      keys2_beg_loc + num_keys1,
+      keys2_offset + keys2_beg_loc,
       num_keys1_loc,
       num_keys2_loc,
       keys_loc,
@@ -174,19 +270,69 @@ struct agent_t
     static constexpr bool have_items = !::cuda::std::is_same_v<item_type, NullType>;
     if constexpr (have_items)
     {
-      item_type items_loc[items_per_thread];
-      merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
-        items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
-      __syncthreads(); // block_store_keys above uses shared memory, so make sure all threads are done before we write
-                       // to it
-      merge_sort::reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
-      __syncthreads();
+      const auto translate_indices = [&](int items2_offset) -> void {
+        const int diff = items2_offset - keys2_offset;
+        _CCCL_PRAGMA_UNROLL_FULL()
+        for (int i = 0; i < items_per_thread; ++i)
+        {
+          if (indices[i] >= keys2_offset)
+          {
+            indices[i] += diff;
+          }
+        }
+      };
+
+      item_type* items1_shared;
+      item_type* items2_shared;
+      int items2_offset;
+      if constexpr (keys_use_block_load_to_shared)
+      {
+        ::cuda::std::span items1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items1_in + keys1_beg),
+                                     static_cast<::cuda::std::size_t>(num_keys1)};
+        ::cuda::std::span items2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items2_in + keys2_beg),
+                                     static_cast<::cuda::std::size_t>(num_keys2)};
+        ::cuda::std::span items_buffers{storage.items_shared.c_array};
+        auto items1_buffer =
+          items_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<item_type>(num_keys1));
+        auto items2_buffer =
+          items_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<item_type>(num_keys2));
+        _CCCL_ASSERT(items1_buffer.end() <= items2_buffer.begin(),
+                     "Items buffer needs to be appropriately sized (internal)");
+        // block_store_keys above uses shared memory, so make sure all threads are done before we write
+        __syncthreads();
+        auto items1_sh = load2sh.template CopyAsync<item_type>(items1_buffer, items1_src);
+        auto items2_sh = load2sh.template CopyAsync<item_type>(items2_buffer, items2_src);
+        load2sh.Commit();
+        items1_shared = data(items1_sh);
+        items2_shared = data(items2_sh);
+        items2_offset = static_cast<int>(items2_shared - items1_shared);
+        translate_indices(items2_offset);
+        load2sh.Wait();
+      }
+      else
+      {
+        item_type items_loc[items_per_thread];
+        merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
+          items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
+        __syncthreads(); // block_store_keys above uses shared memory, so make sure all threads are done before we write
+                         // to it
+        items1_shared = &::cuda::ptr_rebind<item_type>(storage.items_shared.c_array)[0];
+        items2_offset = num_keys1;
+        items2_shared = items1_shared + items2_offset;
+        if constexpr (keys_use_block_load_to_shared)
+        {
+          translate_indices(items2_offset);
+        }
+        merge_sort::reg_to_shared<threads_per_block>(items1_shared, items_loc);
+        __syncthreads();
+      }

       // gather items from shared mem
+      item_type items_loc[items_per_thread];
       _CCCL_PRAGMA_UNROLL_FULL()
       for (int i = 0; i < items_per_thread; ++i)
       {
-        items_loc[i] = storage.items_shared[indices[i]];
+        items_loc[i] = items1_shared[indices[i]];
       }
       __syncthreads();

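The translate_indices lambda in the hunk above exists because SerialMerge produces indices into the keys buffer, where range 2 starts at keys2_offset, while the values can land at a different relative offset (items2_offset) in their own buffer; any index that points into range 2 therefore has to be shifted by the difference. A small illustration of that arithmetic, with made-up offsets:

// Illustration only, with hypothetical offsets.
int translate(int idx, int keys2_offset, int items2_offset)
{
  // Indices below keys2_offset address range 1 and keep their position;
  // indices at or above it address range 2 and are shifted by the difference.
  return idx < keys2_offset ? idx : idx + (items2_offset - keys2_offset);
}
// Example: keys2_offset = 100, items2_offset = 128.
// idx = 42  -> 42   (an element of range 1, unchanged)
// idx = 105 -> 133  (6th element of range 2, moved to its items-buffer position)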
cub/cub/device/dispatch/dispatch_merge.cuh
Lines changed: 4 additions & 4 deletions
@@ -143,11 +143,11 @@ __launch_bounds__(
   auto& temp_storage = vsmem_helper_t::get_temp_storage(shared_temp_storage, global_temp_storage);
   MergeAgent{
     temp_storage.Alias(),
-    try_make_cache_modified_iterator<MergePolicy::LOAD_MODIFIER>(keys1),
-    try_make_cache_modified_iterator<MergePolicy::LOAD_MODIFIER>(items1),
+    keys1,
+    items1,
     num_keys1,
-    try_make_cache_modified_iterator<MergePolicy::LOAD_MODIFIER>(keys2),
-    try_make_cache_modified_iterator<MergePolicy::LOAD_MODIFIER>(items2),
+    keys2,
+    items2,
     num_keys2,
     keys_result,
     items_result,

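The kernel now passes the raw keys1/items1/keys2/items2 iterators because the choice of load iterator moved into the agent (see agent_merge.cuh above): the bulk-copy path needs the original contiguous iterators so they can be unwrapped to plain pointers, while the fallback path still wraps them in a cache-modified iterator. A sketch of that per-path selection in isolation, reusing the names from the agent diff rather than introducing new API:

// Sketch only: restates the conditional_t selection made inside the agent above;
// try_make_cache_modified_iterator_t is the helper used there (assumed to live in
// cub::detail, pulled in here via the agent header itself).
#include <cub/agent/agent_merge.cuh>

template <bool UseBulkCopy, cub::CacheLoadModifier Modifier, typename Iterator>
using merge_load_iterator_t =
  ::cuda::std::conditional_t<UseBulkCopy,
                             Iterator, // keep the raw contiguous iterator so it can be unwrapped to a pointer
                             cub::detail::try_make_cache_modified_iterator_t<Modifier, Iterator>>;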
cub/cub/device/dispatch/tuning/tuning_merge.cuh
Lines changed: 12 additions & 1 deletion
@@ -83,7 +83,18 @@ struct policy_hub
                   BLOCK_STORE_WARP_TRANSPOSE>;
   };

-  using max_policy = policy600;
+  struct policy800 : ChainedPolicy<800, policy800, policy600>
+  {
+    using merge_policy =
+      agent_policy_t<512,
+                     Nominal4BItemsToItems<tune_type>(15),
+                     BLOCK_LOAD_WARP_TRANSPOSE,
+                     LOAD_DEFAULT,
+                     BLOCK_STORE_WARP_TRANSPOSE,
+                     /* UseBlockLoadToShared = */ true>;
+  };
+
+  using max_policy = policy800;
 };
 } // namespace merge
 } // namespace detail

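policy800 only changes how the merge kernel is tuned and dispatched; callers keep using the existing cub::DeviceMerge entry points, and the BlockLoadToShared path is picked automatically when the code runs on SM 8.0 or newer and the inputs qualify (contiguous iterators, trivially copyable types, as gated in agent_merge.cuh). A minimal usage sketch, assuming the established two-phase DeviceMerge::MergeKeys interface and hypothetical device pointers:

// Hedged usage sketch: d_a, d_b, d_out are assumed to be device arrays, with
// d_a and d_b each sorted ascending. The call itself is the pre-existing
// DeviceMerge API; only the tuning policy selected underneath changes.
#include <cub/device/device_merge.cuh>
#include <cuda/std/functional>
#include <cuda_runtime.h>
#include <cstddef>

cudaError_t merge_sorted(const int* d_a, int num_a, const int* d_b, int num_b, int* d_out)
{
  void* d_temp_storage           = nullptr;
  std::size_t temp_storage_bytes = 0;
  // First call: query the required temporary storage size.
  cub::DeviceMerge::MergeKeys(d_temp_storage, temp_storage_bytes, d_a, num_a, d_b, num_b, d_out, ::cuda::std::less<int>{});
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Second call: run the merge; on SM 8.0+ this dispatches through policy800 above.
  const cudaError_t error =
    cub::DeviceMerge::MergeKeys(d_temp_storage, temp_storage_bytes, d_a, num_a, d_b, num_b, d_out, ::cuda::std::less<int>{});
  cudaFree(d_temp_storage);
  return error;
}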