diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h index d7c2a096e41..10b26cde64a 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h @@ -59,10 +59,9 @@ __find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_el const _Index __q = __i_elem; //diagonal index const _Index __n_diag = std::min<_Index>(__q, __n1); //diagonal size auto __res = - std::lower_bound(__diag_it, __diag_it + __n_diag, 1 /*value to find*/, - [&__rng2, &__rng1, __q, __comp](const auto& __i_diag, const auto& __value) mutable { - const auto __zero_or_one = __comp(__rng2[__q - __i_diag - 1], __rng1[__i_diag]); - return __zero_or_one < __value; + std::lower_bound(__diag_it, __diag_it + __n_diag, false /*value to find*/, + [&__rng2, &__rng1, __q, __comp](const auto& __i_diag, const bool __value) mutable { + return __value == __comp(__rng2[__q - __i_diag - 1], __rng1[__i_diag]); }); return std::make_pair(*__res, __q - *__res); } @@ -71,10 +70,9 @@ __find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_el const _Index __q = __i_elem - __n2; //diagonal index const _Index __n_diag = std::min<_Index>(__n1 - __q, __n2); //diagonal size auto __res = - std::lower_bound(__diag_it, __diag_it + __n_diag, 1 /*value to find*/, - [&__rng2, &__rng1, __n2, __q, __comp](const auto& __i_diag, const auto& __value) mutable { - const auto __zero_or_one = __comp(__rng2[__n2 - __i_diag - 1], __rng1[__q + __i_diag]); - return __zero_or_one < __value; + std::lower_bound(__diag_it, __diag_it + __n_diag, false /*value to find*/, + [&__rng2, &__rng1, __n2, __q, __comp](const auto& __i_diag, const bool __value) mutable { + return __value == __comp(__rng2[__n2 - __i_diag - 1], __rng1[__q + __i_diag]); }); return std::make_pair(__q + *__res, __n2 - *__res); } @@ -167,12 +165,10 @@ __find_start_point_in(const _Rng1& __rng1, const _Index __rng1_from, _Index __rn __it_t __diag_it_begin(idx1_from); __it_t __diag_it_end(idx1_to); - constexpr int kValue = 1; + constexpr bool kValue = false; const __it_t __res = std::lower_bound(__diag_it_begin, __diag_it_end, kValue, - [&__rng1, &__rng2, __index_sum, __comp](_Index __idx, const auto& __value) { - const auto __zero_or_one = - __comp(__rng2[__index_sum - __idx], __rng1[__idx]); - return __zero_or_one < kValue; + [&__rng1, &__rng2, __index_sum, __comp](_Index __idx, const bool __value) { + return __value == __comp(__rng2[__index_sum - __idx], __rng1[__idx]); }); return _split_point_t<_Index>{*__res, __index_sum - *__res + 1};