Skip to content

Commit

Permalink
Specify __serial_merge by unroll factor template param
Browse files Browse the repository at this point in the history
Signed-off-by: Sergey Kopienko <[email protected]>
  • Loading branch information
SergeyKopienko committed Dec 17, 2024
1 parent be021ac commit d3d863d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ __find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_el

// Serially merge the data from rng1 (starting from start1) and rng2 (starting from start2), writing the result
// to rng3 (starting from start3) in 'chunk' steps, without exceeding the total size of the sequences (n1 and n2)
template <typename _Rng1, typename _Rng2, typename _Rng3, typename _Index, typename _Compare>
template <unsigned int _UnrollFactor = 4, typename _Rng1, typename _Rng2, typename _Rng3, typename _Index, typename _Compare>
void
__serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, _Index __start1, _Index __start2,
const _Index __start3, const std::uint8_t __chunk, const _Index __n1, const _Index __n2, _Compare __comp)
Expand All @@ -99,7 +99,7 @@ __serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, _Index _

bool __rng1_idx_less__n1, __rng2_idx_less__n2;

#pragma unroll 4
#pragma unroll _UnrollFactor
for (_Index __rng3_idx = __start3; __rng3_idx < __rng3_idx_end; ++__rng3_idx)
{
__rng1_idx_less__n1 = __rng1_idx < __rng1_idx_end;
Expand All @@ -119,6 +119,8 @@ struct __parallel_merge_submitter;
template <typename _IdType, typename... _Name>
struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_Name...>>
{
static constexpr std::uint32_t __gpu_chunk = 4;

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
auto
operator()(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _Range3&& __rng3, _Compare __comp) const
Expand All @@ -132,7 +134,7 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N
_PRINT_INFO_IN_DEBUG_MODE(__exec);

// Empirical number of values to process per work-item
const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4;
const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : __gpu_chunk;

const _IdType __steps = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk);

Expand All @@ -141,8 +143,8 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N
__cgh.parallel_for<_Name...>(sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
const _IdType __i_elem = __item_id.get_linear_id() * __chunk;
const auto __start = __find_start_point(__rng1, __rng2, __i_elem, __n1, __n2, __comp);
__serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __chunk, __n1, __n2,
__comp);
__serial_merge<__gpu_chunk>(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __chunk,
__n1, __n2, __comp);
});
});
return __future(__event);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ struct __merge_sort_global_submitter;
template <typename _IndexT, typename... _GlobalSortName>
struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name<_GlobalSortName...>>
{
static constexpr std::uint32_t __gpu_chunk = 4;

template <typename _Range, typename _Compare, typename _TempBuf, typename _LeafSizeT>
std::pair<sycl::event, bool>
operator()(sycl::queue& __q, _Range& __rng, _Compare __comp, _LeafSizeT __leaf_size, _TempBuf& __temp_buf,
Expand All @@ -241,7 +243,7 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
const _IndexT __n = __rng.size();
_IndexT __n_sorted = __leaf_size;
const bool __is_cpu = __q.get_device().is_cpu();
const std::uint32_t __chunk = __is_cpu ? 32 : 4;
const std::uint32_t __chunk = __is_cpu ? 32 : __gpu_chunk;
const std::size_t __steps = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk);
bool __data_in_temp = false;

Expand Down Expand Up @@ -272,17 +274,17 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
const oneapi::dpl::__ranges::drop_view_simple __rng2(__dst, __offset + __n1);

const auto start = __find_start_point(__rng1, __rng2, __i_elem_local, __n1, __n2, __comp);
__serial_merge(__rng1, __rng2, __rng /*__rng3*/, start.first, start.second, __i_elem,
__chunk, __n1, __n2, __comp);
__serial_merge<__gpu_chunk>(__rng1, __rng2, __rng /*__rng3*/, start.first, start.second,
__i_elem, __chunk, __n1, __n2, __comp);
}
else
{
const oneapi::dpl::__ranges::drop_view_simple __rng1(__rng, __offset);
const oneapi::dpl::__ranges::drop_view_simple __rng2(__rng, __offset + __n1);

const auto start = __find_start_point(__rng1, __rng2, __i_elem_local, __n1, __n2, __comp);
__serial_merge(__rng1, __rng2, __dst /*__rng3*/, start.first, start.second, __i_elem,
__chunk, __n1, __n2, __comp);
__serial_merge<__gpu_chunk>(__rng1, __rng2, __dst /*__rng3*/, start.first, start.second,
__i_elem, __chunk, __n1, __n2, __comp);
}
});
});
Expand Down

0 comments on commit d3d863d

Please sign in to comment.