[oneDPL][ranges] support size limit for output for merge algorithm #1942

Open
wants to merge 20 commits into base: main

Changes from all commits (20 commits):
5133e22  [oneDPL][ranges][merge] support size limit for output; serial patte  (MikeDvorskiy, Nov 20, 2024)
d92698d  [oneDPL][ranges][merge][test] + merge checker with size limit for ou…  (MikeDvorskiy, Nov 20, 2024)
866aef8  [oneDPL][ranges][merge] support size limit for output; parallel patt…  (MikeDvorskiy, Nov 22, 2024)
0111bbc  [oneDPL][ranges][merge] + fixes  (MikeDvorskiy, Nov 22, 2024)
9b10482  [oneDPL][ranges][merge] support size limit for output; + draft for m…  (MikeDvorskiy, Nov 27, 2024)
60420ca  [oneDPL][ranges][merge] support size limit for output; + __simd_merge  (MikeDvorskiy, Nov 27, 2024)
55791e6  [oneDPL][ranges][merge] support size limit for output; code cleanup  (MikeDvorskiy, Nov 27, 2024)
1ef887e  [oneDPL][ranges][merge] support size limit for output; fixes and raf…  (MikeDvorskiy, Nov 27, 2024)
2de3e8e  [oneDPL][ranges][merge] support size limit for output; fixes for __p…  (MikeDvorskiy, Nov 28, 2024)
f2fdfc5  [oneDPL][ranges][merge] support size limit for output; + oneapi::dpl…  (MikeDvorskiy, Nov 28, 2024)
b7bde74  [oneDPL][ranges][merge][test] changes in the checker  (MikeDvorskiy, Nov 28, 2024)
01c63ce  [oneDPL][ranges][merge] support size limit for output; + check n==0  (MikeDvorskiy, Nov 28, 2024)
75e133e  [oneDPL][ranges][merge] support size limit for output; + __parallel_…  (MikeDvorskiy, Nov 28, 2024)
5b078ad  [oneDPL][ranges][merge] support size limit for output; + fixes for s…  (MikeDvorskiy, Nov 29, 2024)
bc9c941  [oneDPL][ranges][merge] support size limit for output; + a fix  (MikeDvorskiy, Dec 5, 2024)
94800f8  [oneDPL][ranges][merge][test] support size limit for output; + fixes  (MikeDvorskiy, Dec 5, 2024)
f3efc69  [oneDPL][ranges][merge][dpcpp] support size limit for output range, …  (MikeDvorskiy, Dec 9, 2024)
324fe55  [oneDPL][ranges][merge] + compilation fix  (MikeDvorskiy, Dec 9, 2024)
2171386  [oneDPL][ranges][merge][dpcpp] + out_of_bounds fix  (MikeDvorskiy, Dec 10, 2024)
98a7acb  [oneDPL][make] + usage ONEAPI_DEVICE_SELECTOR variable  (MikeDvorskiy, Dec 10, 2024)
123 changes: 122 additions & 1 deletion include/oneapi/dpl/pstl/algorithm_impl.h
@@ -31,6 +31,7 @@
#include "parallel_backend.h"
#include "parallel_impl.h"
#include "iterator_impl.h"
#include "../functional"

#if _ONEDPL_HETERO_BACKEND
# include "hetero/algorithm_impl_hetero.h" // for __pattern_fill_n, __pattern_generate_n
@@ -2948,6 +2949,49 @@ __pattern_remove_if(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec,
// merge
//------------------------------------------------------------------------

template<typename It1, typename It2, typename ItOut, typename _Comp>
std::pair<It1, It2>
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
/* __is_vector = */ std::false_type)
{
while(__it_1 != __it_1_e && __it_2 != __it_2_e)
{
if (__comp(*__it_1, *__it_2))
{
*__it_out = *__it_1;
++__it_out, ++__it_1;
}
else
{
*__it_out = *__it_2;
++__it_out, ++__it_2;
}
if(__it_out == __it_out_e)
return {__it_1, __it_2};
}

if(__it_1 == __it_1_e)
{
for(; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out)
*__it_out = *__it_2;
}
else
{
//assert(__it_2 == __it_2_e);
for(; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
*__it_out = *__it_1;
}
return {__it_1, __it_2};
}

template<typename It1, typename It2, typename ItOut, typename _Comp>
std::pair<It1, It2>
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
/* __is_vector = */ std::true_type)
{
return __unseq_backend::__simd_merge(__it_1, __it_1_e, __it_2, __it_2_e, __it_out, __it_out_e, __comp);
}

template <class _ForwardIterator1, class _ForwardIterator2, class _OutputIterator, class _Compare>
_OutputIterator
__brick_merge(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2,
@@ -2980,10 +3024,87 @@ __pattern_merge(_Tag, _ExecutionPolicy&&, _ForwardIterator1 __first1, _ForwardIt
typename _Tag::__is_vector{});
}

template<class _Tag, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
std::pair<_It1, _It2>
__pattern_merge_2(_Tag, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
{
return __brick_merge_2(__it_1, __it_1 + __n_1, __it_2, __it_2 + __n_2, __it_out, __it_out + __n_out, __comp,
typename _Tag::__is_vector{});
}

template<typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
std::pair<_It1, _It2>
__pattern_merge_2(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
{
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;

_It1 __it_res_1;
_It2 __it_res_2;

__internal::__except_handler([&]() {
__par_backend::__parallel_for(__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
[=, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j)
{
//a starting point on the merge path for each thread's chunk
_Index1 __r = 0; //row index
_Index2 __c = 0; //column index

if(__i > 0)
{
//calc merge path intersection:
const _Index3 __d_size =
std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;

auto __get_row = [__i, __n_1](auto __d)
{ return std::min<_Index1>(__i, __n_1) - __d - 1; };
auto __get_column = [__i, __n_1](auto __d)
{ return std::max<_Index1>(0, __i - __n_1 - 1) + __d + (__i / (__n_1 + 1) > 0 ? 1 : 0); };

oneapi::dpl::counting_iterator<_Index3> __it_d(0);

auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1,
[&](auto __d, auto __val) {
auto __r = __get_row(__d);
auto __c = __get_column(__d);

oneapi::dpl::__internal::__compare<_Comp, oneapi::dpl::identity>
__cmp{__comp, oneapi::dpl::identity{}};
const auto __res = (__cmp(__it_1[__r], __it_2[__c]) ? 1 : 0);

return __res < __val;
}
);

//intersection point
__r = __get_row(__res_d);
__c = __get_column(__res_d);
++__r; //to get the merge matrix cell lying on the current diagonal
}

//serially merge elements, starting at __r and __c in the inputs, into the [__i, __j) output range
auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1,
__it_2 + __c, __it_2 + __n_2,
__it_out + __i, __it_out + __j, __comp, _IsVector{});

if(__j == __n_out)
{
__it_res_1 = __res.first;
__it_res_2 = __res.second;
}
}, _ONEDPL_MERGE_CUT_OFF); //grainsize
});

return {__it_res_1, __it_res_2};
}

template <class _IsVector, class _ExecutionPolicy, class _RandomAccessIterator1, class _RandomAccessIterator2,
class _RandomAccessIterator3, class _Compare>
_RandomAccessIterator3
__pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1,
__pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1,
_RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2, _RandomAccessIterator2 __last2,
_RandomAccessIterator3 __d_first, _Compare __comp)
{
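For reference, a minimal standalone sketch of the bounded-merge contract that __brick_merge_2 above follows: merge two sorted sequences, stop as soon as the output range is full, and report how far each input was consumed. This is plain standard C++; the name bounded_merge and the sample data are illustrative only, and tie handling is not guaranteed to match the oneDPL brick.

#include <algorithm>
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Merge sorted [first1, last1) and [first2, last2) into [out, out_end),
// stopping as soon as the output is full; report how far each input advanced.
template <typename It1, typename It2, typename Out, typename Comp>
std::pair<It1, It2>
bounded_merge(It1 first1, It1 last1, It2 first2, It2 last2, Out out, Out out_end, Comp comp)
{
    while (first1 != last1 && first2 != last2 && out != out_end)
        *out++ = comp(*first2, *first1) ? *first2++ : *first1++;
    while (first1 != last1 && out != out_end)
        *out++ = *first1++;
    while (first2 != last2 && out != out_end)
        *out++ = *first2++;
    return {first1, first2};
}

int main()
{
    std::vector<int> a{1, 3, 5, 7}, b{2, 4, 6}, out(5); // output shorter than a.size() + b.size()
    auto ends = bounded_merge(a.begin(), a.end(), b.begin(), b.end(), out.begin(), out.end(), std::less<>{});
    assert((out == std::vector<int>{1, 2, 3, 4, 5}));
    assert(ends.first - a.begin() == 3 && ends.second - b.begin() == 2);
}

The pair of input positions returned here is what the range-based pattern uses to build a std::ranges::merge_result when the output range is shorter than the combined input length.
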
36 changes: 18 additions & 18 deletions include/oneapi/dpl/pstl/algorithm_ranges_impl.h
@@ -448,31 +448,31 @@ auto
__pattern_merge(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
_Proj1 __proj1, _Proj2 __proj2)
{
static_assert(__is_parallel_tag_v<_Tag> || typename _Tag::__is_vector{});
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only

using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
std::ranges::borrowed_iterator_t<_OutRange>>;
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
std::forward<decltype(__val2)>(__val2)));};

auto __res = oneapi::dpl::__internal::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec),
std::ranges::begin(__r1), std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2),
std::ranges::begin(__r2) + std::ranges::size(__r2), std::ranges::begin(__out_r), __comp_2);
using _Index1 = std::ranges::range_difference_t<_R1>;
using _Index2 = std::ranges::range_difference_t<_R2>;
using _Index3 = std::ranges::range_difference_t<_OutRange>;

using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
std::ranges::borrowed_iterator_t<_OutRange>>;
_Index1 __n_1 = std::ranges::size(__r1);
_Index2 __n_2 = std::ranges::size(__r2);
_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));

return __return_type{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) + std::ranges::size(__r2), __res};
}
auto __it_1 = std::ranges::begin(__r1);
auto __it_2 = std::ranges::begin(__r2);
auto __it_out = std::ranges::begin(__out_r);

template<typename _ExecutionPolicy, typename _R1, typename _R2, typename _OutRange, typename _Comp,
typename _Proj1, typename _Proj2>
auto
__pattern_merge(__serial_tag</*IsVector*/std::false_type>, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
_Proj1 __proj1, _Proj2 __proj2)
{
return std::ranges::merge(std::forward<_R1>(__r1), std::forward<_R2>(__r2), std::ranges::begin(__out_r), __comp, __proj1,
__proj2);
if(__n_out == 0)
return __return_type{__it_1, __it_2, __it_out};

auto __res = __pattern_merge_2(__tag, std::forward<_ExecutionPolicy>(__exec), __it_2, __n_2, __it_1, __n_1, __it_out, __n_out, __comp_2);

return __return_type{__res.second, __res.first, __it_out + __n_out};
}

} // namespace __ranges
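As a small aside on the __comp_2 lambda above: it folds both projections into the user comparator, so the iterator-based pattern only ever sees a plain binary predicate. A standalone sketch of the same composition in plain standard C++ (the element types and values are illustrative):

#include <cassert>
#include <functional>
#include <string>
#include <utility>

int main()
{
    auto comp = std::less<int>{};
    auto proj1 = [](const std::pair<int, std::string>& p) { return p.first; }; // project the key out of a pair
    auto proj2 = [](int v) { return v; };                                      // identity projection

    // Same shape as __comp_2: apply each projection, then the user comparator.
    auto comp_2 = [comp, proj1, proj2](auto&& v1, auto&& v2) {
        return std::invoke(comp, std::invoke(proj1, std::forward<decltype(v1)>(v1)),
                           std::invoke(proj2, std::forward<decltype(v2)>(v2)));
    };

    std::pair<int, std::string> a{1, "one"};
    assert(comp_2(a, 2));  // compares proj1(a) == 1 against proj2(2) == 2
    assert(!comp_2(a, 1));
}
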
7 changes: 5 additions & 2 deletions include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h
@@ -1173,9 +1173,12 @@ merge(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _Range3&& _
{
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec, __rng1, __rng2, __rng3);

return oneapi::dpl::__internal::__ranges::__pattern_merge(
auto __view_res = views::all_write(::std::forward<_Range3>(__rng3));
oneapi::dpl::__internal::__ranges::__pattern_merge(
__dispatch_tag, ::std::forward<_ExecutionPolicy>(__exec), views::all_read(::std::forward<_Range1>(__rng1)),
views::all_read(::std::forward<_Range2>(__rng2)), views::all_write(::std::forward<_Range3>(__rng3)), __comp);
views::all_read(::std::forward<_Range2>(__rng2)), __view_res, __comp);

return __view_res.size();
}

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3>
49 changes: 28 additions & 21 deletions include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h
@@ -51,17 +51,19 @@ namespace __ranges
//------------------------------------------------------------------------

template <typename _BackendTag, typename _ExecutionPolicy, typename _Function, typename... _Ranges>
void
auto
__pattern_walk_n(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Function __f, _Ranges&&... __rngs)
{
auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
using _Size = std::make_unsigned_t<std::common_type_t<oneapi::dpl::__internal::__difference_t<_Ranges>...>>;
auto __n = std::min({_Size(__rngs.size())...});
if (__n > 0)
{
oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n,
::std::forward<_Ranges>(__rngs)...)
.__deferrable_wait();
}
return __n;
}

#if _ONEDPL_CPP20_RANGES_PRESENT
@@ -680,44 +682,43 @@ struct __copy2_wrapper;

template <typename _BackendTag, typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3,
typename _Compare>
oneapi::dpl::__internal::__difference_t<_Range3>
std::pair<oneapi::dpl::__internal::__difference_t<_Range1>, oneapi::dpl::__internal::__difference_t<_Range2>>
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2,
_Range3&& __rng3, _Compare __comp)
{
auto __n1 = __rng1.size();
auto __n2 = __rng2.size();
auto __n = __n1 + __n2;
if (__n == 0)
return 0;
if (__rng3.size() == 0)
return {0, 0};

//Fall back to the direct copy pattern when just one of the sequences is empty.
if (__n1 == 0)
{
oneapi::dpl::__internal::__ranges::__pattern_walk_n(
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n(
__tag,
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy1_wrapper>(
::std::forward<_ExecutionPolicy>(__exec)),
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3));
return {0, __res};
}
else if (__n2 == 0)

if (__n2 == 0)
{
oneapi::dpl::__internal::__ranges::__pattern_walk_n(
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n(
__tag,
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy2_wrapper>(
::std::forward<_ExecutionPolicy>(__exec)),
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
::std::forward<_Range1>(__rng1), ::std::forward<_Range3>(__rng3));
}
else
{
__par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2),
::std::forward<_Range3>(__rng3), __comp)
.__deferrable_wait();
return {__res, 0};
}

return __n;
auto __res = __par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2),
::std::forward<_Range3>(__rng3), __comp);

return __res.get();
}

#if _ONEDPL_CPP20_RANGES_PRESENT
@@ -727,21 +728,27 @@ auto
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r,
_Comp __comp, _Proj1 __proj1, _Proj2 __proj2)
{
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only

auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)),
std::invoke(__proj2, std::forward<decltype(__val2)>(__val2)));};

using _Index1 = std::ranges::range_difference_t<_R1>;
using _Index2 = std::ranges::range_difference_t<_R2>;
using _Index3 = std::ranges::range_difference_t<_OutRange>;

_Index1 __n_1 = std::ranges::size(__r1);
_Index2 __n_2 = std::ranges::size(__r2);
_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));

auto __res = oneapi::dpl::__internal::__ranges::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec),
oneapi::dpl::__ranges::views::all_read(__r1), oneapi::dpl::__ranges::views::all_read(__r2),
oneapi::dpl::__ranges::views::all_write(__out_r), __comp_2);

using __return_t = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
std::ranges::borrowed_iterator_t<_OutRange>>;

return __return_t{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) +
std::ranges::size(__r2), std::ranges::begin(__out_r) + __res};
return __return_t{std::ranges::begin(__r1) + __res.first, std::ranges::begin(__r2) + __res.second,
std::ranges::begin(__out_r) + __n_out};
}
#endif //_ONEDPL_CPP20_RANGES_PRESENT

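To illustrate the empty-input branches above: when one sequence is empty, the bounded merge degenerates into a bounded copy of the other sequence, and the returned pair reports how many elements of each input were consumed. A host-side sketch in plain standard C++ (merge_one_side_empty and the containers are illustrative, not the oneDPL patterns themselves):

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Bounded merge when one side is empty: copy at most out.size() elements of the
// non-empty side and report {consumed from rng1, consumed from rng2}.
template <typename T>
std::pair<std::size_t, std::size_t>
merge_one_side_empty(const std::vector<T>& rng1, const std::vector<T>& rng2, std::vector<T>& out)
{
    if (rng1.empty())
    {
        const std::size_t n = std::min(rng2.size(), out.size());
        std::copy_n(rng2.begin(), n, out.begin());
        return {0, n};
    }
    const std::size_t n = std::min(rng1.size(), out.size());
    std::copy_n(rng1.begin(), n, out.begin());
    return {n, 0};
}

int main()
{
    std::vector<int> empty, b{2, 4, 6}, out(2);
    auto consumed = merge_one_side_empty(empty, b, out); // copies only 2 of b's 3 elements
    return (consumed.first == 0 && consumed.second == 2 && out[1] == 4) ? 0 : 1;
}
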
23 changes: 19 additions & 4 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
@@ -142,7 +142,7 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N
{
const _IdType __n1 = __rng1.size();
const _IdType __n2 = __rng2.size();
const _IdType __n = __n1 + __n2;
const _IdType __n = std::min<_IdType>(__n1 + __n2, __rng3.size());

assert(__n1 > 0 || __n2 > 0);

@@ -153,16 +153,31 @@

const _IdType __steps = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk);

using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, std::pair<std::size_t, std::size_t>>;
__result_and_scratch_storage_t __result_storage{__exec, 2, 0};

auto __event = __exec.queue().submit([&](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2, __rng3);
auto __result_acc = __result_storage.template __get_result_acc<sycl::access_mode::write>(__cgh, __dpl_sycl::__no_init{});

__cgh.parallel_for<_Name...>(sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
const _IdType __i_elem = __item_id.get_linear_id() * __chunk;
auto __id = __item_id.get_linear_id();
const _IdType __i_elem = __id * __chunk;

const auto __n_merge = std::min<_IdType>(__chunk, __n - __i_elem);
const auto __start = __find_start_point(__rng1, __rng2, __i_elem, __n1, __n2, __comp);
__serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __chunk, __n1, __n2,
__serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __n_merge, __n1, __n2,
__comp);

if(__id == __steps - 1) //the last WI does additional work
{
const auto __ends = __find_start_point(__rng1, __rng2, __n, __n1, __n2, __comp);
auto __res_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__result_acc);
*__res_ptr = {__ends.first, __ends.second};
}
});
});
return __future(__event);
return __future(__event, __result_storage);
}
};

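For context on the last-work-item change above, where __find_start_point is called at diagonal __n to fill the result storage: the quantity it computes is how many elements of each input make up the first __n merged elements. A host-side sketch of that computation in plain standard C++ (split_point and the sample data are illustrative; the actual kernel reuses __find_start_point, and tie handling may differ):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// For k <= a.size() + b.size(), find how many elements each sorted sequence
// contributes to the first k elements of their merge, i.e. where the k-th
// merge-path diagonal crosses the merge path.
template <typename T, typename Comp = std::less<T>>
std::pair<std::size_t, std::size_t>
split_point(const std::vector<T>& a, const std::vector<T>& b, std::size_t k, Comp comp = {})
{
    std::size_t lo = k > b.size() ? k - b.size() : 0; // fewest elements a can contribute
    std::size_t hi = std::min(k, a.size());           // most elements a can contribute
    while (lo < hi)
    {
        const std::size_t i = lo + (hi - lo) / 2;     // candidate contribution from a
        const std::size_t j = k - i;                  // matching contribution from b
        if (j > 0 && i < a.size() && comp(a[i], b[j - 1]))
            lo = i + 1; // a[i] merges before b[j - 1], so a must contribute more
        else
            hi = i;
    }
    return {lo, k - lo};
}

int main()
{
    std::vector<int> a{1, 3, 5, 7}, b{2, 4, 6};
    // The first 5 merged elements {1, 2, 3, 4, 5} take 3 elements from a and 2 from b.
    assert((split_point(a, b, 5) == std::pair<std::size_t, std::size_t>{3, 2}));
}

Storing this pair through __result_and_scratch_storage is what lets the range-based merge report how far each input was consumed when the output range is the limiting factor.
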