diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h index 10ddf37fc50..9e5451c69a9 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h @@ -275,8 +275,8 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N } }; -template -struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_name<_Name...>> +template +struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_name<_KernelName>> { template auto @@ -294,7 +294,7 @@ struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_n using _FindSplitPointsKernelOnMidDiagonal = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator< - _find_split_points_kernel_on_mid_diagonal, _CustomName, _Range1, _Range2, _IdType, _Compare>; + _find_split_points_kernel_on_mid_diagonal, _KernelName, _CustomName, _Range1, _Range2, _IdType, _Compare>; // Empirical number of values to process per work-item const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4; @@ -343,7 +343,7 @@ struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_n __cgh.depends_on(__event); - __cgh.parallel_for<_Name...>(sycl::range(__steps), [=](sycl::item __item_id) { + __cgh.parallel_for<_KernelName>(sycl::range(__steps), [=](sycl::item __item_id) { auto __global_idx = __item_id.get_linear_id(); const _IdType __i_elem = __global_idx * __chunk; @@ -379,6 +379,9 @@ struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_n template class __merge_kernel_name; +template +class __merge_kernel_name_large; + template auto __parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range1&& __rng1, @@ -387,23 +390,47 @@ __parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>; const auto __n = __rng1.size() + __rng2.size(); - if (__n <= std::numeric_limits::max()) + if (__n < 4 * 1'048'576) { - using _WiIndex = std::uint32_t; - using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider< - __merge_kernel_name<_CustomName, _WiIndex>>; - return __parallel_merge_submitter<_WiIndex, _MergeKernel>()( - std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2), - std::forward<_Range3>(__rng3), __comp); + if (__n <= std::numeric_limits::max()) + { + using _WiIndex = std::uint32_t; + using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider< + __merge_kernel_name<_CustomName, _WiIndex>>; + return __parallel_merge_submitter<_WiIndex, _MergeKernel>()( + std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2), + std::forward<_Range3>(__rng3), __comp); + } + else + { + using _WiIndex = std::uint64_t; + using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider< + __merge_kernel_name<_CustomName, _WiIndex>>; + return __parallel_merge_submitter<_WiIndex, _MergeKernel>()( + std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2), + std::forward<_Range3>(__rng3), __comp); + } } else { - using _WiIndex = std::uint64_t; - using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider< - __merge_kernel_name<_CustomName, _WiIndex>>; - return __parallel_merge_submitter<_WiIndex, _MergeKernel>()( - std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2), - std::forward<_Range3>(__rng3), __comp); + if (__n <= std::numeric_limits::max()) + { + using _WiIndex = std::uint32_t; + using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider< + __merge_kernel_name_large<_CustomName, _WiIndex>>; + return __parallel_merge_submitter_large<_WiIndex, _MergeKernelLarge>()( + std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2), + std::forward<_Range3>(__rng3), __comp); + } + else + { + using _WiIndex = std::uint64_t; + using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider< + __merge_kernel_name_large<_CustomName, _WiIndex>>; + return __parallel_merge_submitter_large<_WiIndex, _MergeKernelLarge>()( + std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2), + std::forward<_Range3>(__rng3), __comp); + } } }