diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h index 753e32816a0..cadff26a15d 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h @@ -82,50 +82,33 @@ __find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_el // to rng3 (starting from start3) in 'chunk' steps, but do not exceed the total size of the sequences (n1 and n2) template void -__serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, _Index __start1, _Index __start2, - const _Index __start3, const std::uint8_t __chunk, const _Index __n1, const _Index __n2, _Compare __comp) +__serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, const _Index __start1, const _Index __start2, + const _Index __start3, const _Index __chunk, const _Index __n1, const _Index __n2, _Compare __comp) { - if (__start1 >= __n1) - { - //copying a residual of the second seq - const _Index __n = std::min<_Index>(__n2 - __start2, __chunk); - for (std::uint8_t __i = 0; __i < __n; ++__i) - __rng3[__start3 + __i] = __rng2[__start2 + __i]; - } - else if (__start2 >= __n2) - { - //copying a residual of the first seq - const _Index __n = std::min<_Index>(__n1 - __start1, __chunk); - for (std::uint8_t __i = 0; __i < __n; ++__i) - __rng3[__start3 + __i] = __rng1[__start1 + __i]; - } - else + const _Index __rng1_size = std::min<_Index>(__n1 > __start1 ? __n1 - __start1 : _Index{0}, __chunk); + const _Index __rng2_size = std::min<_Index>(__n2 > __start2 ? __n2 - __start2 : _Index{0}, __chunk); + const _Index __rng3_size = std::min<_Index>(__rng1_size + __rng2_size, __chunk); + + const _Index __rng1_idx_end = __start1 + __rng1_size; + const _Index __rng2_idx_end = __start2 + __rng2_size; + const _Index __rng3_idx_end = __start3 + __rng3_size; + + _Index __rng1_idx = __start1; + _Index __rng2_idx = __start2; + + for (_Index __rng3_idx = __start3; __rng3_idx < __rng3_idx_end; ++__rng3_idx) { - for (std::uint8_t __i = 0; __i < __chunk && __start1 < __n1 && __start2 < __n2; ++__i) - { - const auto& __val1 = __rng1[__start1]; - const auto& __val2 = __rng2[__start2]; - if (__comp(__val2, __val1)) - { - __rng3[__start3 + __i] = __val2; - if (++__start2 == __n2) - { - //copying a residual of the first seq - for (++__i; __i < __chunk && __start1 < __n1; ++__i, ++__start1) - __rng3[__start3 + __i] = __rng1[__start1]; - } - } - else - { - __rng3[__start3 + __i] = __val1; - if (++__start1 == __n1) - { - //copying a residual of the second seq - for (++__i; __i < __chunk && __start2 < __n2; ++__i, ++__start2) - __rng3[__start3 + __i] = __rng2[__start2]; - } - } - } + const bool __rng1_idx_less_n1 = __rng1_idx < __rng1_idx_end; + const bool __rng2_idx_less_n2 = __rng2_idx < __rng2_idx_end; + + // One of __rng1_idx_less_n1 and __rng2_idx_less_n2 should be true here + // because 1) we should fill output data with elements from one of the input ranges + // 2) we calculate __rng3_idx_end as std::min<_Index>(__rng1_size + __rng2_size, __chunk). + __rng3[__rng3_idx] = + ((__rng1_idx_less_n1 && __rng2_idx_less_n2 && __comp(__rng2[__rng2_idx], __rng1[__rng1_idx])) || + !__rng1_idx_less_n1) + ? __rng2[__rng2_idx++] + : __rng1[__rng1_idx++]; } } @@ -149,7 +132,7 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N _PRINT_INFO_IN_DEBUG_MODE(__exec); // Empirical number of values to process per work-item - const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4; + const _IdType __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4; const _IdType __steps = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk); diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h index 0765f8ef7bc..a9e60b81c71 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h @@ -63,7 +63,7 @@ struct __group_merge_path_sorter template bool sort(const sycl::nd_item<1>& __item, const _StorageAcc& __storage_acc, _Compare __comp, std::uint32_t __start, - std::uint32_t __end, std::uint32_t __sorted, std::uint16_t __data_per_workitem, + std::uint32_t __end, std::uint32_t __sorted, std::uint32_t __data_per_workitem, std::uint32_t __workgroup_size) const { const std::uint32_t __sorted_final = __data_per_workitem * __workgroup_size; @@ -91,7 +91,8 @@ struct __group_merge_path_sorter auto __in_ptr1 = __in_ptr + __start1; auto __in_ptr2 = __in_ptr + __start2; - const auto __start = __find_start_point(__in_ptr1, __in_ptr2, __id_local, __n1, __n2, __comp); + const std::pair __start = + __find_start_point(__in_ptr1, __in_ptr2, __id_local, __n1, __n2, __comp); // TODO: copy the data into registers before the merge to halve the required amount of SLM __serial_merge(__in_ptr1, __in_ptr2, __out_ptr, __start.first, __start.second, __id, __data_per_workitem, __n1, __n2, __comp); @@ -241,7 +242,7 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name const _IndexT __n = __rng.size(); _IndexT __n_sorted = __leaf_size; const bool __is_cpu = __q.get_device().is_cpu(); - const std::uint32_t __chunk = __is_cpu ? 32 : 4; + const _IndexT __chunk = __is_cpu ? 32 : 4; const std::size_t __steps = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk); bool __data_in_temp = false;