Skip to content

Commit

Permalink
Remove __find_start_point implementation and usage
Browse files Browse the repository at this point in the history
Signed-off-by: Sergey Kopienko <[email protected]>
  • Loading branch information
SergeyKopienko committed Dec 20, 2024
1 parent 6f5ec48 commit 6dd8e51
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 42 deletions.
43 changes: 6 additions & 37 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,38 +46,6 @@ using _split_point_t = std::pair<_Index, _Index>;
// 2 | 0 0 0 0 | 1
// | ---->
// 3 | 0 0 0 0 0 |
template <typename _Rng1, typename _Rng2, typename _Index, typename _Compare>
auto
__find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_elem, const _Index __n1,
const _Index __n2, _Compare __comp)
{
//searching for the first '1', a lower bound for a diagonal [0, 0,..., 0, 1, 1,.... 1, 1]
oneapi::dpl::counting_iterator<_Index> __diag_it(0);

if (__i_elem < __n2) //a condition to specify upper or lower part of the merge matrix to be processed
{
const _Index __q = __i_elem; //diagonal index
const _Index __n_diag = std::min<_Index>(__q, __n1); //diagonal size
auto __res =
std::lower_bound(__diag_it, __diag_it + __n_diag, false /*value to find*/,
[&__rng2, &__rng1, __q, __comp](const auto& __i_diag, const bool __value) mutable {
return __value == __comp(__rng2[__q - __i_diag - 1], __rng1[__i_diag]);
});
return std::make_pair(*__res, __q - *__res);
}
else
{
const _Index __q = __i_elem - __n2; //diagonal index
const _Index __n_diag = std::min<_Index>(__n1 - __q, __n2); //diagonal size
auto __res =
std::lower_bound(__diag_it, __diag_it + __n_diag, false /*value to find*/,
[&__rng2, &__rng1, __n2, __q, __comp](const auto& __i_diag, const bool __value) mutable {
return __value == __comp(__rng2[__n2 - __i_diag - 1], __rng1[__q + __i_diag]);
});
return std::make_pair(__q + *__res, __n2 - *__res);
}
}

template <typename _Rng1, typename _Rng2, typename _Index, typename _Compare>
_split_point_t<_Index>
__find_start_point_in(const _Rng1& __rng1, const _Index __rng1_from, _Index __rng1_to, const _Rng2& __rng2,
Expand Down Expand Up @@ -226,7 +194,8 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_M
__cgh.parallel_for<_MergeKernelName...>(
sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
const _IdType __i_elem = __item_id.get_linear_id() * __chunk;
const auto __start = __find_start_point(__rng1, __rng2, __i_elem, __n1, __n2, __comp);
const auto __start =
__find_start_point_in(__rng1, _IdType{0}, __n1, __rng2, _IdType{0}, __n2, __i_elem, __comp);
__serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __chunk, __n1,
__n2, __comp);
});
Expand Down Expand Up @@ -307,10 +276,10 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,
const _IdType __i_elem = __global_idx * __base_diag_chunk;

__base_diagonals_sp_global_ptr[__global_idx] =
__i_elem == 0
? _split_point_t<_IdType>{0, 0}
: (__i_elem < __n ? __find_start_point(__rng1, __rng2, __i_elem, __n1, __n2, __comp)
: _split_point_t<_IdType>{__n1, __n2});
__i_elem == 0 ? _split_point_t<_IdType>{0, 0}
: (__i_elem < __n ? __find_start_point_in(__rng1, _IdType{0}, __n1, __rng2,
_IdType{0}, __n2, __i_elem, __comp)
: _split_point_t<_IdType>{__n1, __n2});
});
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include "sycl_traits.h" // SYCL traits specialization for some oneDPL types.
#include "../../utils.h" // __dpl_bit_floor, __dpl_bit_ceil
#include "../../utils_ranges.h" // __difference_t
#include "parallel_backend_sycl_merge.h" // __find_start_point, __serial_merge
#include "parallel_backend_sycl_merge.h" // __find_start_point_in, __serial_merge

namespace oneapi
{
Expand Down Expand Up @@ -91,8 +91,8 @@ struct __group_merge_path_sorter
auto __in_ptr1 = __in_ptr + __start1;
auto __in_ptr2 = __in_ptr + __start2;

const std::pair<std::uint32_t, std::uint32_t> __start =
__find_start_point(__in_ptr1, __in_ptr2, __id_local, __n1, __n2, __comp);
const std::pair<std::uint32_t, std::uint32_t> __start = __find_start_point_in(
__in_ptr1, std::uint32_t{0}, __n1, __in_ptr2, std::uint32_t{0}, __n2, __id_local, __comp);
// TODO: copy the data into registers before the merge to halve the required amount of SLM
__serial_merge(__in_ptr1, __in_ptr2, __out_ptr, __start.first, __start.second, __id, __data_per_workitem,
__n1, __n2, __comp);
Expand Down Expand Up @@ -272,7 +272,8 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
const oneapi::dpl::__ranges::drop_view_simple __rng1(__dst, __offset);
const oneapi::dpl::__ranges::drop_view_simple __rng2(__dst, __offset + __n1);

const auto start = __find_start_point(__rng1, __rng2, __i_elem_local, __n1, __n2, __comp);
const auto start = __find_start_point_in(__rng1, _IndexT{0}, __n1, __rng2, _IndexT{0}, __n2,
__i_elem_local, __comp);
__serial_merge(__rng1, __rng2, __rng /*__rng3*/, start.first, start.second, __i_elem,
__chunk, __n1, __n2, __comp);
}
Expand All @@ -281,7 +282,8 @@ struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name
const oneapi::dpl::__ranges::drop_view_simple __rng1(__rng, __offset);
const oneapi::dpl::__ranges::drop_view_simple __rng2(__rng, __offset + __n1);

const auto start = __find_start_point(__rng1, __rng2, __i_elem_local, __n1, __n2, __comp);
const auto start = __find_start_point_in(__rng1, _IndexT{0}, __n1, __rng2, _IndexT{0}, __n2,
__i_elem_local, __comp);
__serial_merge(__rng1, __rng2, __dst /*__rng3*/, start.first, start.second, __i_elem,
__chunk, __n1, __n2, __comp);
}
Expand Down

0 comments on commit 6dd8e51

Please sign in to comment.