 #endif // no system header
 
 #include <cub/agent/agent_merge_sort.cuh>
-#include <cub/block/block_load.cuh>
+#include <cub/block/block_load_to_shared.cuh>
 #include <cub/block/block_merge_sort.cuh>
 #include <cub/block/block_store.cuh>
 #include <cub/iterator/cache_modified_input_iterator.cuh>
 #include <cub/util_namespace.cuh>
 #include <cub/util_type.cuh>
 
+#include <thrust/type_traits/is_contiguous_iterator.h>
+#include <thrust/type_traits/is_trivially_relocatable.h>
+#include <thrust/type_traits/unwrap_contiguous_iterator.h>
+
+#include <cuda/__memory/ptr_rebind.h>
 #include <cuda/std/__algorithm/max.h>
 #include <cuda/std/__algorithm/min.h>
+#include <cuda/std/__type_traits/conditional.h>
+#include <cuda/std/__type_traits/is_same.h>
+#include <cuda/std/__type_traits/is_trivially_copyable.h>
+#include <cuda/std/cstddef>
+#include <cuda/std/span>
 
 CUB_NAMESPACE_BEGIN
 namespace detail
@@ -33,7 +43,8 @@ template <int ThreadsPerBlock,
           int ItemsPerThread,
           BlockLoadAlgorithm LoadAlgorithm,
           CacheLoadModifier LoadCacheModifier,
-          BlockStoreAlgorithm StoreAlgorithm>
+          BlockStoreAlgorithm StoreAlgorithm,
+          bool UseBlockLoadToShared = false>
 struct agent_policy_t
 {
   // do not change data member names, policy_wrapper_t depends on it
@@ -43,6 +54,7 @@ struct agent_policy_t
   static constexpr BlockLoadAlgorithm LOAD_ALGORITHM   = LoadAlgorithm;
   static constexpr CacheLoadModifier LOAD_MODIFIER     = LoadCacheModifier;
   static constexpr BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
+  static constexpr bool use_block_load_to_shared       = UseBlockLoadToShared;
 };
 
 // TODO(bgruber): can we unify this one with AgentMerge in agent_merge_sort.cuh?
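For orientation, a policy opts into the new path via the trailing flag; existing tunings compile unchanged because the parameter defaults to false. A hypothetical instantiation (the numeric values below are illustrative placeholders, not shipped CUB tunings):

// Hypothetical tuning values, for illustration only.
using example_policy_t =
  agent_policy_t<256, // ThreadsPerBlock
                 11, // ItemsPerThread
                 BLOCK_LOAD_WARP_TRANSPOSE, // LoadAlgorithm
                 LOAD_DEFAULT, // LoadCacheModifier
                 BLOCK_STORE_WARP_TRANSPOSE, // StoreAlgorithm
                 true>; // UseBlockLoadToShared
static_assert(example_policy_t::use_block_load_to_shared);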
@@ -54,48 +66,102 @@ template <typename Policy,
           typename KeysOutputIt,
           typename ItemsOutputIt,
           typename Offset,
-          typename CompareOp>
+          typename CompareOp,
+          bool AllowBlockLoadToShared>
 struct agent_t
 {
-  using policy = Policy;
+  using policy                           = Policy;
+  static constexpr int items_per_thread  = Policy::ITEMS_PER_THREAD;
+  static constexpr int threads_per_block = Policy::BLOCK_THREADS;
+  static constexpr Offset items_per_tile = Policy::ITEMS_PER_TILE;
 
   // key and value type are taken from the first input sequence (consistent with old Thrust behavior)
   using key_type  = it_value_t<KeysIt1>;
   using item_type = it_value_t<ItemsIt1>;
 
-  using keys_load_it1  = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt1>;
-  using keys_load_it2  = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt2>;
-  using items_load_it1 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt1>;
-  using items_load_it2 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt2>;
+  using block_load_to_shared = cub::detail::BlockLoadToShared<threads_per_block>;
+  using block_store_keys     = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
+  using block_store_items    = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
+
+  template <typename ValueT, typename Iter1, typename Iter2>
+  static constexpr bool use_block_load_to_shared =
+    Policy::use_block_load_to_shared && (sizeof(ValueT) == alignof(ValueT)) && AllowBlockLoadToShared
+    && THRUST_NS_QUALIFIER::is_trivially_relocatable_v<ValueT> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter1> //
+    && THRUST_NS_QUALIFIER::is_contiguous_iterator_v<Iter2>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter1>>
+    && ::cuda::std::is_same_v<ValueT, cub::detail::it_value_t<Iter2>>;
+
+  static constexpr bool keys_use_block_load_to_shared  = use_block_load_to_shared<key_type, KeysIt1, KeysIt2>;
+  static constexpr bool items_use_block_load_to_shared = use_block_load_to_shared<item_type, ItemsIt1, ItemsIt2>;
+  static constexpr bool need_block_load_to_shared = keys_use_block_load_to_shared || items_use_block_load_to_shared;
+  static constexpr int load2sh_minimum_align      = block_load_to_shared::template SharedBufferAlignBytes<char>();
+
+  struct empty_t
+  {
+    struct TempStorage
+    {};
+    _CCCL_DEVICE _CCCL_FORCEINLINE empty_t(TempStorage) {}
+  };
+
+  using optional_load2sh_t = ::cuda::std::conditional_t<need_block_load_to_shared, block_load_to_shared, empty_t>;
+
+  using keys_load_it1 =
+    ::cuda::std::conditional_t<keys_use_block_load_to_shared,
+                               KeysIt1,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt1>>;
+  using keys_load_it2 =
+    ::cuda::std::conditional_t<keys_use_block_load_to_shared,
+                               KeysIt2,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt2>>;
+  using items_load_it1 =
+    ::cuda::std::conditional_t<items_use_block_load_to_shared,
+                               ItemsIt1,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt1>>;
+  using items_load_it2 =
+    ::cuda::std::conditional_t<items_use_block_load_to_shared,
+                               ItemsIt2,
+                               try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt2>>;
 
-  using block_load_keys1  = typename BlockLoadType<Policy, keys_load_it1>::type;
-  using block_load_keys2  = typename BlockLoadType<Policy, keys_load_it2>::type;
-  using block_load_items1 = typename BlockLoadType<Policy, items_load_it1>::type;
-  using block_load_items2 = typename BlockLoadType<Policy, items_load_it2>::type;
+  template <typename ValueT, bool UseBlockLoadToShared>
+  struct alignas(UseBlockLoadToShared ? block_load_to_shared::template SharedBufferAlignBytes<ValueT>()
+                                      : alignof(ValueT)) buffer_t
+  {
+    // Need extra bytes of padding for TMA because this static buffer has to hold the two dynamically sized buffers.
+    char c_array[UseBlockLoadToShared
+                   ? (block_load_to_shared::template SharedBufferSizeBytes<ValueT>(items_per_tile + 1)
+                      + (alignof(ValueT) < load2sh_minimum_align ? 2 * load2sh_minimum_align : 0))
+                   : sizeof(ValueT) * (items_per_tile + 1)];
+  };
 
-  using block_store_keys  = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
-  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
+  struct temp_storages_bl2sh
+  {
+    union
+    {
+      typename block_store_keys::TempStorage store_keys;
+      typename block_store_items::TempStorage store_items;
+      buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+      buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+    };
+    typename block_load_to_shared::TempStorage load2sh;
+  };
 
-  union temp_storages
+  union temp_storages_fallback
   {
-    typename block_load_keys1::TempStorage load_keys1;
-    typename block_load_keys2::TempStorage load_keys2;
-    typename block_load_items1::TempStorage load_items1;
-    typename block_load_items2::TempStorage load_items2;
     typename block_store_keys::TempStorage store_keys;
     typename block_store_items::TempStorage store_items;
 
-    key_type keys_shared[Policy::ITEMS_PER_TILE + 1];
-    item_type items_shared[Policy::ITEMS_PER_TILE + 1];
+    buffer_t<key_type, keys_use_block_load_to_shared> keys_shared;
+    buffer_t<item_type, items_use_block_load_to_shared> items_shared;
+
+    typename empty_t::TempStorage load2sh;
   };
 
+  using temp_storages =
+    ::cuda::std::conditional_t<need_block_load_to_shared, temp_storages_bl2sh, temp_storages_fallback>;
+
   struct TempStorage : Uninitialized<temp_storages>
   {};
 
-  static constexpr int items_per_thread  = Policy::ITEMS_PER_THREAD;
-  static constexpr int threads_per_block = Policy::BLOCK_THREADS;
-  static constexpr Offset items_per_tile = Policy::ITEMS_PER_TILE;
-
   // Per thread data
   temp_storages& storage;
   keys_load_it1 keys1_in;
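The shared-memory buffers above are now raw char arrays rather than typed arrays: buffer_t only reserves suitably aligned bytes (with extra TMA padding when the load-to-shared path is enabled), and the bytes are viewed as keys or items only at the point of use. A minimal standalone sketch of that pattern (illustrative, not the CUB source), using the same ::cuda::ptr_rebind seen later in the patch:

#include <cuda/__memory/ptr_rebind.h>

// Sketch only: suitably aligned raw storage, reinterpreted on demand.
// No element constructors run while the storage sits unused in a union.
template <typename T, int N>
struct raw_buffer_t
{
  alignas(T) char c_array[sizeof(T) * N];
};

__device__ void fill_example(raw_buffer_t<int, 128>& buf)
{
  int* elems = ::cuda::ptr_rebind<int>(buf.c_array); // view the bytes as ints
  elems[threadIdx.x % 128] = threadIdx.x;
}

Keeping the storage untyped is what lets one union member back both the statically sized fallback layout and the two dynamically sized, TMA-aligned sub-buffers carved out at run time.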
@@ -128,18 +194,49 @@ struct agent_t
     const int num_keys1 = static_cast<int>(keys1_end - keys1_beg);
     const int num_keys2 = static_cast<int>(keys2_end - keys2_beg);
 
-    key_type keys_loc[items_per_thread];
-    merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
-      keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
-    merge_sort::reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
-    __syncthreads();
+    optional_load2sh_t load2sh{storage.load2sh};
+
+    key_type* keys1_shared;
+    key_type* keys2_shared;
+    int keys2_offset;
+    if constexpr (keys_use_block_load_to_shared)
+    {
+      ::cuda::std::span keys1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys1_in + keys1_beg),
+                                  static_cast<::cuda::std::size_t>(num_keys1)};
+      ::cuda::std::span keys2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(keys2_in + keys2_beg),
+                                  static_cast<::cuda::std::size_t>(num_keys2)};
+      ::cuda::std::span keys_buffers{storage.keys_shared.c_array};
+      auto keys1_buffer = keys_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<key_type>(num_keys1));
+      auto keys2_buffer = keys_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<key_type>(num_keys2));
+      _CCCL_ASSERT(keys1_buffer.end() <= keys2_buffer.begin(),
+                   "Keys buffer needs to be appropriately sized (internal)");
+      auto keys1_sh = load2sh.CopyAsync(keys1_buffer, keys1_src);
+      auto keys2_sh = load2sh.CopyAsync(keys2_buffer, keys2_src);
+      load2sh.Commit();
+      keys1_shared = data(keys1_sh);
+      keys2_shared = data(keys2_sh);
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = static_cast<int>(keys2_shared - keys1_shared);
+      load2sh.Wait();
+    }
+    else
+    {
+      key_type keys_loc[items_per_thread];
+      merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
+        keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
+      keys1_shared = &::cuda::ptr_rebind<key_type>(storage.keys_shared.c_array)[0];
+      // Needed for using keys1_shared as one big buffer including both ranges in SerialMerge
+      keys2_offset = num_keys1;
+      keys2_shared = keys1_shared + keys2_offset;
+      merge_sort::reg_to_shared<threads_per_block>(keys1_shared, keys_loc);
+      __syncthreads();
+    }
 
     // use binary search in shared memory to find the merge path for each thread.
     // we can use int type here, because the number of items in shared memory is limited
     const int diag0_loc = (::cuda::std::min)(num_keys1 + num_keys2, static_cast<int>(items_per_thread * threadIdx.x));
 
-    const int keys1_beg_loc =
-      MergePath(&storage.keys_shared[0], &storage.keys_shared[num_keys1], num_keys1, num_keys2, diag0_loc, compare_op);
+    const int keys1_beg_loc = MergePath(keys1_shared, keys2_shared, num_keys1, num_keys2, diag0_loc, compare_op);
     const int keys1_end_loc = num_keys1;
     const int keys2_beg_loc = diag0_loc - keys1_beg_loc;
     const int keys2_end_loc = num_keys2;
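MergePath performs the standard merge-path diagonal search: for a diagonal diag, it returns how many of the first diag merged elements come from keys1. A serial reference sketch of that search (illustrative only; cub::MergePath takes the same arguments but has its own implementation):

// Reference sketch of a merge-path partition, assuming compare_op is a
// strict weak ordering and ties prefer keys1 (stable merge).
template <typename KeyT, typename CompareOp>
int merge_path_reference(
  const KeyT* keys1, const KeyT* keys2, int num_keys1, int num_keys2, int diag, CompareOp compare_op)
{
  int lo = diag > num_keys2 ? diag - num_keys2 : 0; // fewest elements keys1 can contribute
  int hi = diag < num_keys1 ? diag : num_keys1; // most elements keys1 can contribute
  while (lo < hi)
  {
    const int mid = (lo + hi) / 2;
    // If keys2[diag - 1 - mid] sorts before keys1[mid], we took too many from keys1.
    if (compare_op(keys2[diag - 1 - mid], keys1[mid]))
    {
      hi = mid;
    }
    else
    {
      lo = mid + 1;
    }
  }
  return lo; // == keys1_beg_loc; keys2_beg_loc is then diag - lo
}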
@@ -148,11 +245,12 @@ struct agent_t
     const int num_keys2_loc = keys2_end_loc - keys2_beg_loc;
 
     // perform serial merge
+    key_type keys_loc[items_per_thread];
     int indices[items_per_thread];
     cub::SerialMerge(
-      &storage.keys_shared[0],
+      keys1_shared,
       keys1_beg_loc,
-      keys2_beg_loc + num_keys1,
+      keys2_offset + keys2_beg_loc,
       num_keys1_loc,
       num_keys2_loc,
       keys_loc,
@@ -174,19 +272,67 @@ struct agent_t
     static constexpr bool have_items = !::cuda::std::is_same_v<item_type, NullType>;
     if constexpr (have_items)
     {
-      item_type items_loc[items_per_thread];
-      merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
-        items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
-      __syncthreads(); // block_store_keys above uses shared memory, so make sure all threads are done before we write
-                       // to it
-      merge_sort::reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
-      __syncthreads();
+      [[maybe_unused]] const auto translate_indices = [&](int items2_offset) -> void {
+        const int diff = items2_offset - keys2_offset;
+        _CCCL_PRAGMA_UNROLL_FULL()
+        for (int i = 0; i < items_per_thread; ++i)
+        {
+          if (indices[i] >= keys2_offset)
+          {
+            indices[i] += diff;
+          }
+        }
+      };
+
+      item_type* items1_shared;
+      int items2_offset;
+      if constexpr (items_use_block_load_to_shared)
+      {
+        ::cuda::std::span items1_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items1_in + keys1_beg),
+                                     static_cast<::cuda::std::size_t>(num_keys1)};
+        ::cuda::std::span items2_src{THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(items2_in + keys2_beg),
+                                     static_cast<::cuda::std::size_t>(num_keys2)};
+        ::cuda::std::span items_buffers{storage.items_shared.c_array};
+        auto items1_buffer =
+          items_buffers.first(block_load_to_shared::template SharedBufferSizeBytes<item_type>(num_keys1));
+        auto items2_buffer =
+          items_buffers.last(block_load_to_shared::template SharedBufferSizeBytes<item_type>(num_keys2));
+        _CCCL_ASSERT(items1_buffer.end() <= items2_buffer.begin(),
+                     "Items buffer needs to be appropriately sized (internal)");
+        // block_store_keys above uses shared memory, so make sure all threads are done before we write to it
+        __syncthreads();
+        auto items1_sh = load2sh.CopyAsync(items1_buffer, items1_src);
+        auto items2_sh = load2sh.CopyAsync(items2_buffer, items2_src);
+        load2sh.Commit();
+        items1_shared            = data(items1_sh);
+        item_type* items2_shared = data(items2_sh);
+        items2_offset            = static_cast<int>(items2_shared - items1_shared);
+        translate_indices(items2_offset);
+        load2sh.Wait();
+      }
+      else
+      {
+        item_type items_loc[items_per_thread];
+        merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
+          items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
+        __syncthreads(); // block_store_keys above uses shared memory, so make sure all threads are done before we
+                         // write to it
+        items1_shared = &::cuda::ptr_rebind<item_type>(storage.items_shared.c_array)[0];
+        items2_offset = num_keys1;
+        if constexpr (keys_use_block_load_to_shared)
+        {
+          translate_indices(items2_offset);
+        }
+        merge_sort::reg_to_shared<threads_per_block>(items1_shared, items_loc);
+        __syncthreads();
+      }
 
       // gather items from shared mem
+      item_type items_loc[items_per_thread];
       _CCCL_PRAGMA_UNROLL_FULL()
       for (int i = 0; i < items_per_thread; ++i)
       {
-        items_loc[i] = storage.items_shared[indices[i]];
+        items_loc[i] = items1_shared[indices[i]];
       }
       __syncthreads();
 
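The translate_indices remapping exists because SerialMerge emits indices relative to the key layout, where the second range starts at keys2_offset, while the item buffers may be carved at a different offset when TMA alignment padding differs. A worked example with made-up offsets:

// Worked example (made-up sizes) of the translate_indices logic above:
// keys2 begins at offset 300 in the combined key buffer, but items2 begins
// at offset 320 in the combined item buffer, so every index addressing the
// second range shifts by 320 - 300 = 20.
void translate_indices_example()
{
  constexpr int keys2_offset  = 300;
  constexpr int items2_offset = 320;
  const int diff = items2_offset - keys2_offset;

  int indices[4] = {12, 299, 300, 450};
  for (int& idx : indices)
  {
    if (idx >= keys2_offset) // index points into the second range
    {
      idx += diff;
    }
  }
  // indices is now {12, 299, 320, 470}
}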