diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h index c706042b69b..42bbcfdde48 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h @@ -46,38 +46,6 @@ using _split_point_t = std::pair<_Index, _Index>; // 2 | 0 0 0 0 | 1 // | ----> // 3 | 0 0 0 0 0 | -template -auto -__find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_elem, const _Index __n1, - const _Index __n2, _Compare __comp) -{ - //searching for the first '1', a lower bound for a diagonal [0, 0,..., 0, 1, 1,.... 1, 1] - oneapi::dpl::counting_iterator<_Index> __diag_it(0); - - if (__i_elem < __n2) //a condition to specify upper or lower part of the merge matrix to be processed - { - const _Index __q = __i_elem; //diagonal index - const _Index __n_diag = std::min<_Index>(__q, __n1); //diagonal size - auto __res = - std::lower_bound(__diag_it, __diag_it + __n_diag, false /*value to find*/, - [&__rng2, &__rng1, __q, __comp](const auto& __i_diag, const bool __value) mutable { - return __value == __comp(__rng2[__q - __i_diag - 1], __rng1[__i_diag]); - }); - return std::make_pair(*__res, __q - *__res); - } - else - { - const _Index __q = __i_elem - __n2; //diagonal index - const _Index __n_diag = std::min<_Index>(__n1 - __q, __n2); //diagonal size - auto __res = - std::lower_bound(__diag_it, __diag_it + __n_diag, false /*value to find*/, - [&__rng2, &__rng1, __n2, __q, __comp](const auto& __i_diag, const bool __value) mutable { - return __value == __comp(__rng2[__n2 - __i_diag - 1], __rng1[__q + __i_diag]); - }); - return std::make_pair(__q + *__res, __n2 - *__res); - } -} - template _split_point_t<_Index> __find_start_point_in(const _Rng1& __rng1, const _Index __rng1_from, _Index __rng1_to, const _Rng2& __rng2, @@ -226,7 +194,8 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_M __cgh.parallel_for<_MergeKernelName...>( sycl::range(__steps), [=](sycl::item __item_id) { const _IdType __i_elem = __item_id.get_linear_id() * __chunk; - const auto __start = __find_start_point(__rng1, __rng2, __i_elem, __n1, __n2, __comp); + const auto __start = + __find_start_point_in(__rng1, _IdType{0}, __n1, __rng2, _IdType{0}, __n2, __i_elem, __comp); __serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __chunk, __n1, __n2, __comp); }); @@ -307,10 +276,10 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName, const _IdType __i_elem = __global_idx * __base_diag_chunk; __base_diagonals_sp_global_ptr[__global_idx] = - __i_elem == 0 - ? _split_point_t<_IdType>{0, 0} - : (__i_elem < __n ? __find_start_point(__rng1, __rng2, __i_elem, __n1, __n2, __comp) - : _split_point_t<_IdType>{__n1, __n2}); + __i_elem == 0 ? _split_point_t<_IdType>{0, 0} + : (__i_elem < __n ? __find_start_point_in(__rng1, _IdType{0}, __n1, __rng2, + _IdType{0}, __n2, __i_elem, __comp) + : _split_point_t<_IdType>{__n1, __n2}); }); }); } diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h index a9e60b81c71..e2f90de4abc 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h @@ -91,8 +91,8 @@ struct __group_merge_path_sorter auto __in_ptr1 = __in_ptr + __start1; auto __in_ptr2 = __in_ptr + __start2; - const std::pair __start = - __find_start_point(__in_ptr1, __in_ptr2, __id_local, __n1, __n2, __comp); + const std::pair __start = __find_start_point_in( + __in_ptr1, std::uint32_t{0}, __n1, __in_ptr2, std::uint32_t{0}, __n2, __id_local, __comp); // TODO: copy the data into registers before the merge to halve the required amount of SLM __serial_merge(__in_ptr1, __in_ptr2, __out_ptr, __start.first, __start.second, __id, __data_per_workitem, __n1, __n2, __comp); @@ -272,7 +272,8 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name const oneapi::dpl::__ranges::drop_view_simple __rng1(__dst, __offset); const oneapi::dpl::__ranges::drop_view_simple __rng2(__dst, __offset + __n1); - const auto start = __find_start_point(__rng1, __rng2, __i_elem_local, __n1, __n2, __comp); + const auto start = __find_start_point_in(__rng1, _IndexT{0}, __n1, __rng2, _IndexT{0}, __n2, + __i_elem_local, __comp); __serial_merge(__rng1, __rng2, __rng /*__rng3*/, start.first, start.second, __i_elem, __chunk, __n1, __n2, __comp); } @@ -281,7 +282,8 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name const oneapi::dpl::__ranges::drop_view_simple __rng1(__rng, __offset); const oneapi::dpl::__ranges::drop_view_simple __rng2(__rng, __offset + __n1); - const auto start = __find_start_point(__rng1, __rng2, __i_elem_local, __n1, __n2, __comp); + const auto start = __find_start_point_in(__rng1, _IndexT{0}, __n1, __rng2, _IndexT{0}, __n2, + __i_elem_local, __comp); __serial_merge(__rng1, __rng2, __dst /*__rng3*/, start.first, start.second, __i_elem, __chunk, __n1, __n2, __comp); }