Skip to content

Commit a9fd762

Browse files
Translate indices when iterating
1 parent 6f3d1c6 commit a9fd762

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

cub/cub/agent/agent_merge.cuh

Lines changed: 22 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -269,23 +269,8 @@ struct agent_t
269269
static constexpr bool have_items = !::cuda::std::is_same_v<item_type, NullType>;
270270
if constexpr (have_items)
271271
{
272-
// Both of these are only needed when either keys or items or both use BlockLoadToShared introducing padding (that
273-
// can differ between the keys and items)
274-
[[maybe_unsused]] const auto translate_indices = [&](int items2_offset) -> void {
275-
const int diff = items2_offset - keys2_offset;
276-
_CCCL_PRAGMA_UNROLL_FULL()
277-
for (int i = 0; i < items_per_thread; ++i)
278-
{
279-
if (indices[i] >= keys2_offset)
280-
{
281-
indices[i] += diff;
282-
}
283-
}
284-
};
285-
// WAR for MSVC erroring ("declared but never referenced") despite [[maybe_unused]]
286-
(void) translate_indices;
287-
288272
item_type* items1_shared;
273+
int items2_offset;
289274
if constexpr (items_use_block_load_to_shared)
290275
{
291276
::cuda::std::span items1_src{THRUST_NS_QUALIFIER::unwrap_contiguous_iterator(items1_in + keys1_beg),
@@ -304,8 +289,7 @@ struct agent_t
304289
items1_shared = data(load2sh.CopyAsync(items1_buffer, items1_src));
305290
item_type* items2_shared = data(load2sh.CopyAsync(items2_buffer, items2_src));
306291
load2sh.Commit();
307-
const int items2_offset = static_cast<int>(items2_shared - items1_shared);
308-
translate_indices(items2_offset);
292+
items2_offset = static_cast<int>(items2_shared - items1_shared);
309293
load2sh.Wait();
310294
}
311295
else
@@ -320,8 +304,7 @@ struct agent_t
320304
items1_shared = &storage.items_shared[0];
321305
if constexpr (keys_use_block_load_to_shared)
322306
{
323-
const int items2_offset = keys1_count_tile;
324-
translate_indices(items2_offset);
307+
items2_offset = keys1_count_tile;
325308
}
326309
merge_sort::reg_to_shared<threads_per_block>(items1_shared, items_loc);
327310
__syncthreads();
@@ -330,10 +313,28 @@ struct agent_t
330313

331314
// gather items from shared mem
332315
item_type items_loc[items_per_thread];
316+
317+
// Both of these are only needed when either keys or items or both use BlockLoadToShared introducing padding (that
318+
// can differ between the keys and items)
319+
static constexpr bool must_translate_indices = items_use_block_load_to_shared || keys_use_block_load_to_shared;
320+
int diff;
321+
if constexpr (must_translate_indices)
322+
{
323+
diff = items2_offset - keys2_offset;
324+
}
325+
333326
_CCCL_PRAGMA_UNROLL_FULL()
334327
for (int i = 0; i < items_per_thread; ++i)
335328
{
336-
items_loc[i] = items1_shared[indices[i]];
329+
auto index = indices[i];
330+
if constexpr (must_translate_indices)
331+
{
332+
if (index >= keys2_offset)
333+
{
334+
index += diff;
335+
}
336+
}
337+
items_loc[i] = items1_shared[index];
337338
}
338339
__syncthreads();
339340

0 commit comments

Comments
 (0)