Skip to content

Commit

Permalink
include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h - …
Browse files Browse the repository at this point in the history
…using __parallel_merge_submitter_large for merge data equal or greater then 4M items

Signed-off-by: Sergey Kopienko <[email protected]>
  • Loading branch information
SergeyKopienko committed Nov 7, 2024
1 parent b33656a commit a6164fd
Showing 1 changed file with 44 additions and 17 deletions.
61 changes: 44 additions & 17 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N
}
};

template <typename _IdType, typename... _Name>
struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_name<_Name...>>
template <typename _IdType, typename _KernelName>
struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_name<_KernelName>>
{
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
auto
Expand All @@ -294,7 +294,7 @@ struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_n

using _FindSplitPointsKernelOnMidDiagonal =
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<
_find_split_points_kernel_on_mid_diagonal, _CustomName, _Range1, _Range2, _IdType, _Compare>;
_find_split_points_kernel_on_mid_diagonal, _KernelName, _CustomName, _Range1, _Range2, _IdType, _Compare>;

// Empirical number of values to process per work-item
const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4;
Expand Down Expand Up @@ -343,7 +343,7 @@ struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_n

__cgh.depends_on(__event);

__cgh.parallel_for<_Name...>(sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
__cgh.parallel_for<_KernelName>(sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
auto __global_idx = __item_id.get_linear_id();
const _IdType __i_elem = __global_idx * __chunk;

Expand Down Expand Up @@ -379,6 +379,9 @@ struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_n
template <typename... _Name>
class __merge_kernel_name;

template <typename... _Name>
class __merge_kernel_name_large;

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
auto
__parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range1&& __rng1,
Expand All @@ -387,23 +390,47 @@ __parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy
using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;

const auto __n = __rng1.size() + __rng2.size();
if (__n <= std::numeric_limits<std::uint32_t>::max())
if (__n < 4 * 1'048'576)
{
using _WiIndex = std::uint32_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
if (__n <= std::numeric_limits<std::uint32_t>::max())
{
using _WiIndex = std::uint32_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
else
{
using _WiIndex = std::uint64_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
}
else
{
using _WiIndex = std::uint64_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
if (__n <= std::numeric_limits<std::uint32_t>::max())
{
using _WiIndex = std::uint32_t;
using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name_large<_CustomName, _WiIndex>>;
return __parallel_merge_submitter_large<_WiIndex, _MergeKernelLarge>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
else
{
using _WiIndex = std::uint64_t;
using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name_large<_CustomName, _WiIndex>>;
return __parallel_merge_submitter_large<_WiIndex, _MergeKernelLarge>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
}
}

Expand Down

0 comments on commit a6164fd

Please sign in to comment.