#endif // no system header

#include <cub/agent/agent_merge_sort.cuh>
-#include <cub/block/block_load.cuh>
+#include <cub/block/block_load_to_shared.cuh>
#include <cub/block/block_merge_sort.cuh>
#include <cub/block/block_store.cuh>
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/util_namespace.cuh>
#include <cub/util_type.cuh>

+#include <thrust/type_traits/is_contiguous_iterator.h>
+#include <thrust/type_traits/is_trivially_relocatable.h>
+#include <thrust/type_traits/unwrap_contiguous_iterator.h>
+
+#include <cuda/__memory/ptr_rebind.h>
#include <cuda/std/__algorithm/max.h>
#include <cuda/std/__algorithm/min.h>
+#include <cuda/std/__type_traits/conditional.h>
+#include <cuda/std/__type_traits/is_same.h>
+#include <cuda/std/__type_traits/is_trivially_copyable.h>
+#include <cuda/std/cstddef>
+#include <cuda/std/span>

CUB_NAMESPACE_BEGIN
namespace detail::merge
{
-template <int ThreadsPerBlock, int ItemsPerThread, CacheLoadModifier LoadCacheModifier, BlockStoreAlgorithm StoreAlgorithm>
+template <int ThreadsPerBlock,
+          int ItemsPerThread,
+          CacheLoadModifier LoadCacheModifier,
+          BlockStoreAlgorithm StoreAlgorithm,
+          bool UseBlockLoadToShared = false>
struct agent_policy_t
{
  // do not change data member names, policy_wrapper_t depends on it
@@ -36,6 +50,7 @@ struct agent_policy_t
  static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
  static constexpr CacheLoadModifier LOAD_MODIFIER = LoadCacheModifier;
  static constexpr BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
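+  // Opt-in knob for staging tiles in shared memory via BlockLoadToShared (asynchronous bulk
+  // copies where the hardware supports them). Hypothetical example of an opted-in policy:
+  //   agent_policy_t<256, 11, LOAD_DEFAULT, BLOCK_STORE_WARP_TRANSPOSE, /*UseBlockLoadToShared*/ true>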
+  static constexpr bool use_block_load_to_shared = UseBlockLoadToShared;
};

// TODO(bgruber): can we unify this one with AgentMerge in agent_merge_sort.cuh?
@@ -50,29 +65,78 @@ template <typename Policy,
          typename CompareOp>
struct agent_t
{
-  using policy = Policy;
-
-  // key and value type are taken from the first input sequence (consistent with old Thrust behavior)
-  using key_type = it_value_t<KeysIt1>;
-  using item_type = it_value_t<ItemsIt1>;
-  using block_store_keys = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
-  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
-
+  using policy = Policy;
  static constexpr int items_per_thread = Policy::ITEMS_PER_THREAD;
  static constexpr int threads_per_block = Policy::BLOCK_THREADS;
  static constexpr int items_per_tile = Policy::ITEMS_PER_TILE;

-  union temp_storages
+  // key and value type are taken from the first input sequence (consistent with old Thrust behavior)
+  using key_type = it_value_t<KeysIt1>;
+  using item_type = it_value_t<ItemsIt1>;
+
+  using block_load_to_shared = cub::detail::BlockLoadToShared<threads_per_block>;
+  using block_store_keys = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
+  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
+
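+  // The bulk-copy path is only valid when both inputs are contiguous memory holding exactly ValueT
+  // and ValueT is trivially relocatable, so tiles can be staged with byte-wise asynchronous copies.
+  // Restricting to sizeof(ValueT) == alignof(ValueT) keeps the byte-size and alignment arithmetic of
+  // the shared buffers below simple.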
+  template <typename ValueT, typename Iter1, typename Iter2>
+  static constexpr bool use_block_load_to_shared =
+    Policy::use_block_load_to_shared && (sizeof(ValueT) == alignof(ValueT))
+    && THRUST_NS_QUALIFIER::is_trivially_relocatable_v<ValueT> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter1> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter2>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter1>>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter2>>;
+
+  static constexpr bool keys_use_block_load_to_shared = use_block_load_to_shared<key_type, KeysIt1, KeysIt2>;
+  static constexpr bool items_use_block_load_to_shared = use_block_load_to_shared<item_type, ItemsIt1, ItemsIt2>;
+  static constexpr bool need_block_load_to_shared = keys_use_block_load_to_shared || items_use_block_load_to_shared;
+  static constexpr int load2sh_minimum_align = block_load_to_shared::template SharedBufferAlignBytes<char>();
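+  // SharedBufferAlignBytes<char>() is the alignment BlockLoadToShared requires even for plain byte
+  // buffers; it bounds the per-buffer realignment padding accounted for in buffer_t below.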
+
+  struct empty_t
+  {
+    struct TempStorage
+    {};
+    _CCCL_DEVICE _CCCL_FORCEINLINE empty_t(TempStorage) {}
+  };
+
+  using optional_load2sh_t = ::cuda::std::conditional_t<need_block_load_to_shared, block_load_to_shared, empty_t>;
+
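+  // Each tile is split into two dynamically sized sub-buffers (one per input range). When ValueT is
+  // less strictly aligned than BlockLoadToShared requires, each sub-buffer may waste up to
+  // load2sh_minimum_align bytes to realign its start, hence 2 * load2sh_minimum_align extra bytes.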
+  template <typename ValueT, bool UseBlockLoadToShared>
+  struct alignas(UseBlockLoadToShared ? block_load_to_shared::template SharedBufferAlignBytes<ValueT>()
+                                      : alignof(ValueT)) buffer_t
+  {
+    // Need extra bytes of padding for TMA because this static buffer has to hold the two dynamically sized buffers.
+    char c_array[UseBlockLoadToShared
+                   ? (block_load_to_shared::template SharedBufferSizeBytes<ValueT>(items_per_tile + 1)
+                      + (alignof(ValueT) < load2sh_minimum_align ? 2 * load2sh_minimum_align : 0))
+                   : sizeof(ValueT) * (items_per_tile + 1)];
+  };
+
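+  // Unlike the fallback, this is a struct rather than a union: the load2sh state must outlive the
+  // keys phase and be reused for the items phase, so it cannot share storage with the union members.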
+  struct temp_storages_bl2sh
+  {
+    union
+    {
+      typename block_store_keys::TempStorage store_keys;
+      typename block_store_items::TempStorage store_items;
+      buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+      buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+    };
+    typename block_load_to_shared::TempStorage load2sh;
+  };
+
+  union temp_storages_fallback
  {
    typename block_store_keys::TempStorage store_keys;
    typename block_store_items::TempStorage store_items;

-    // We could change SerialMerge to avoid reading one item out of bounds and drop the + 1 here. But that would
-    // introduce more branches (about 10% slower on 2^16 problem sizes on RTX 5090 in a first attempt)
-    key_type keys_shared[items_per_tile + 1];
-    item_type items_shared[items_per_tile + 1];
+    buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+    buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+
+    typename empty_t::TempStorage load2sh;
  };

+  using temp_storages =
+    ::cuda::std::conditional_t<need_block_load_to_shared, temp_storages_bl2sh, temp_storages_fallback>;
+
  struct TempStorage : Uninitialized<temp_storages>
  {};

@@ -121,18 +185,50 @@ struct agent_t
      _CCCL_ASSERT(keys1_count_tile + keys2_count_tile == num_remaining, "");
    }

-    key_type keys_loc[items_per_thread];
+    optional_load2sh_t load2sh{storage.load2sh};
+
+    key_type* keys1_shared;
+    key_type* keys2_shared;
+    int keys2_offset;
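+    // keys2_offset is the element distance from keys1_shared to keys2_shared. With bulk copies the
+    // two buffers may be separated by alignment padding, so it is not simply keys1_count_tile.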
+    if constexpr (keys_use_block_load_to_shared)
+    {
+      ::cuda::std::span keys1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys1_in + keys1_beg),
+                                  static_cast<::cuda::std::size_t>(keys1_count_tile)};
+      ::cuda::std::span keys2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys2_in + keys2_beg),
+                                  static_cast<::cuda::std::size_t>(keys2_count_tile)};
+      ::cuda::std::span keys_buffers{storage.keys_shared.c_array};
+      auto keys1_buffer =
+        keys_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<key_type>(keys1_count_tile));
+      auto keys2_buffer =
+        keys_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<key_type>(keys2_count_tile));
+      _CCCL_ASSERT(keys1_buffer.end() <= keys2_buffer.begin(),
+                   "Keys buffer needs to be appropriately sized (internal)");
+      // The copies are issued asynchronously; load2sh.Wait() below must complete before the staged
+      // keys may be read from shared memory.
+      auto keys1_sh = load2sh.CopyAsync(keys1_buffer, keys1_src);
+      auto keys2_sh = load2sh.CopyAsync(keys2_buffer, keys2_src);
+      load2sh.Commit();
+      keys1_shared = data(keys1_sh);
+      keys2_shared = data(keys2_sh);
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = static_cast<int>(keys2_shared - keys1_shared);
+      load2sh.Wait();
+    }
+    else
    {
+      key_type keys_loc[items_per_thread];
      auto keys1_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(keys1_in);
      auto keys2_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(keys2_in);
      merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
        keys_loc, keys1_in_cm + keys1_beg, keys2_in_cm + keys2_beg, keys1_count_tile, keys2_count_tile);
-      merge_sort::reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
+      keys1_shared = &::cuda::ptr_rebind<key_type>(storage.keys_shared.c_array)[0];
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = keys1_count_tile;
+      keys2_shared = keys1_shared + keys2_offset;
+      merge_sort::reg_to_shared<threads_per_block>(keys1_shared, keys_loc);
      __syncthreads();
    }

-    // now find the merge path for each of thread.
-    // we can use int type here, because the number of items in shared memory is limited
+    // Now find the merge path for each of the threads.
+    // We can use int type here, because the number of items in shared memory is limited.
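+    // Thread t owns diagonal [t * items_per_thread, (t + 1) * items_per_thread) of the merged tile;
+    // MergePath returns how many of the first diag0_thread merged elements come from the first range.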
    int diag0_thread = items_per_thread * static_cast<int>(threadIdx.x);
    if constexpr (IsFullTile)
    {
@@ -144,24 +240,20 @@ struct agent_t
      diag0_thread = (::cuda::std::min)(diag0_thread, num_remaining);
    }

-    const int keys1_beg_thread = MergePath(
-      &storage.keys_shared[0],
-      &storage.keys_shared[keys1_count_tile],
-      keys1_count_tile,
-      keys2_count_tile,
-      diag0_thread,
-      compare_op);
+    const int keys1_beg_thread =
+      MergePath(keys1_shared, keys2_shared, keys1_count_tile, keys2_count_tile, diag0_thread, compare_op);
    const int keys2_beg_thread = diag0_thread - keys1_beg_thread;

    const int keys1_count_thread = keys1_count_tile - keys1_beg_thread;
    const int keys2_count_thread = keys2_count_tile - keys2_beg_thread;

    // perform serial merge
+    key_type keys_loc[items_per_thread];
    int indices[items_per_thread];
-    SerialMerge(
-      &storage.keys_shared[0],
+    cub::SerialMerge(
+      keys1_shared,
      keys1_beg_thread,
-      keys2_beg_thread + keys1_count_tile,
+      keys2_offset + keys2_beg_thread,
      keys1_count_thread,
      keys2_count_thread,
      keys_loc,
@@ -183,22 +275,73 @@ struct agent_t
    static constexpr bool have_items = !::cuda::std::is_same_v<item_type, NullType>;
    if constexpr (have_items)
    {
-      item_type items_loc[items_per_thread];
+      // Both of the following are only needed when keys and/or items use BlockLoadToShared, which
+      // introduces padding that can differ between keys and items (their element sizes may differ).
+      [[maybe_unused]] const auto translate_indices = [&](int items2_offset) -> void {
+        const int diff = items2_offset - keys2_offset;
+        _CCCL_PRAGMA_UNROLL_FULL()
+        for (int i = 0; i < items_per_thread; ++i)
+        {
+          if (indices[i] >= keys2_offset)
+          {
+            indices[i] += diff;
+          }
+        }
+      };
+      // WAR for MSVC erroring ("declared but never referenced") despite [[maybe_unused]]
+      (void) translate_indices;
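+      // Example: with keys2_offset == 130 (128 keys plus alignment padding) and items2_offset == 160,
+      // a merged index of 135 (element 5 of the second range) becomes 135 + (160 - 130) == 165.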
+
+      item_type* items1_shared;
+      if constexpr (items_use_block_load_to_shared)
      {
-        auto items1_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items1_in);
-        auto items2_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items2_in);
-        merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
-          items_loc, items1_in_cm + keys1_beg, items2_in_cm + keys2_beg, keys1_count_tile, keys2_count_tile);
-        __syncthreads(); // block_store_keys above uses SMEM, so make sure all threads are done before we write to it
-        merge_sort::reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
+        ::cuda::std::span items1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items1_in + keys1_beg),
+                                     static_cast<::cuda::std::size_t>(keys1_count_tile)};
+        ::cuda::std::span items2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items2_in + keys2_beg),
+                                     static_cast<::cuda::std::size_t>(keys2_count_tile)};
+        ::cuda::std::span items_buffers{storage.items_shared.c_array};
+        auto items1_buffer =
+          items_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<item_type>(keys1_count_tile));
+        auto items2_buffer =
+          items_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<item_type>(keys2_count_tile));
+        _CCCL_ASSERT(items1_buffer.end() <= items2_buffer.begin(),
+                     "Items buffer needs to be appropriately sized (internal)");
+        // block_store_keys above uses shared memory, so make sure all threads are done before we write
        __syncthreads();
+        auto items1_sh = load2sh.CopyAsync(items1_buffer, items1_src);
+        auto items2_sh = load2sh.CopyAsync(items2_buffer, items2_src);
+        load2sh.Commit();
+        items1_shared = data(items1_sh);
+        item_type* items2_shared = data(items2_sh);
+        const int items2_offset = static_cast<int>(items2_shared - items1_shared);
+        translate_indices(items2_offset);
+        load2sh.Wait();
+      }
+      else
+      {
+        item_type items_loc[items_per_thread];
+        {
+          auto items1_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items1_in);
+          auto items2_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items2_in);
+          merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
+            items_loc, items1_in_cm + keys1_beg, items2_in_cm + keys2_beg, keys1_count_tile, keys2_count_tile);
+          __syncthreads(); // block_store_keys above uses SMEM, so make sure all threads are done before we write to it
+          items1_shared = &::cuda::ptr_rebind<item_type>(storage.items_shared.c_array)[0];
+          if constexpr (keys_use_block_load_to_shared)
+          {
+            const int items2_offset = keys1_count_tile;
+            translate_indices(items2_offset);
+          }
+          merge_sort::reg_to_shared<threads_per_block>(items1_shared, items_loc);
+          __syncthreads();
+        }
      }

      // gather items from shared mem
+      item_type items_loc[items_per_thread];
      _CCCL_PRAGMA_UNROLL_FULL()
      for (int i = 0; i < items_per_thread; ++i)
      {
-        items_loc[i] = storage.items_shared[indices[i]];
+        items_loc[i] = items1_shared[indices[i]];
      }
      __syncthreads();

@@ -222,11 +365,11 @@ struct agent_t
      static_cast<int>((::cuda::std::min)(static_cast<Offset>(items_per_tile), keys1_count + keys2_count - tile_base));
    if (items_in_tile == items_per_tile)
    {
-      consume_tile</* IsFullTile */ true>(tile_idx, tile_base, items_per_tile);
+      consume_tile</* IsFullTile = */ true>(tile_idx, tile_base, items_per_tile);
    }
    else
    {
-      consume_tile</* IsFullTile */ false>(tile_idx, tile_base, items_in_tile);
+      consume_tile</* IsFullTile = */ false>(tile_idx, tile_base, items_in_tile);
    }
  }
};