From 438ac2b21ebc20ac1b184fe43c5453363afff977 Mon Sep 17 00:00:00 2001 From: Hartmut Kaiser Date: Mon, 18 Sep 2023 18:42:39 -0500 Subject: [PATCH] Making set_intersection work with zip-like iterators --- .../algorithms/detail/set_operation.hpp | 46 ++++++++++++++++--- .../parallel/algorithms/set_intersection.hpp | 4 +- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/libs/core/algorithms/include/hpx/parallel/algorithms/detail/set_operation.hpp b/libs/core/algorithms/include/hpx/parallel/algorithms/detail/set_operation.hpp index 6ad845a27d70..2577bc281449 100644 --- a/libs/core/algorithms/include/hpx/parallel/algorithms/detail/set_operation.hpp +++ b/libs/core/algorithms/include/hpx/parallel/algorithms/detail/set_operation.hpp @@ -33,6 +33,24 @@ namespace hpx::parallel::detail { /// \cond NOINTERNAL + template + struct decay_tuple + { + using type = T; + }; + + template + struct decay_tuple> + { + using type = std::tuple>; + }; + + template + struct decay_tuple> + { + using type = std::tuple>; + }; + /////////////////////////////////////////////////////////////////////////// template struct set_operations_buffer @@ -43,17 +61,24 @@ namespace hpx::parallel::detail { public: rewritable_ref() = default; - explicit constexpr rewritable_ref(T const& item) noexcept + explicit constexpr rewritable_ref(T& item) noexcept : item_(&item) { } - rewritable_ref& operator=(T const& item) + rewritable_ref& operator=(T& item) { item_ = &item; return *this; } + template + rewritable_ref& operator=(U const& item) + { + *item_ = item; + return *this; + } + // different versions of clang-format produce different results // clang-format off operator T const&() const @@ -64,7 +89,7 @@ namespace hpx::parallel::detail { // clang-format on private: - T const* item_ = nullptr; + T* item_ = nullptr; }; using value_type = typename std::iterator_traits::value_type; @@ -149,8 +174,13 @@ namespace hpx::parallel::detail { bool const first_partition = start1 == 0; bool const last_partition = end1 == static_cast(len1); - auto start_value = HPX_INVOKE(proj1, first1[start1]); - auto end_value = HPX_INVOKE(proj1, first1[end1]); + using result_type = + std::invoke_result_t>; + using element_type = + typename decay_tuple>::type; + + element_type start_value = HPX_INVOKE(proj1, first1[start1]); + element_type end_value = HPX_INVOKE(proj1, first1[end1]); // all but the last chunk require special handling if (!last_partition) @@ -166,7 +196,8 @@ namespace hpx::parallel::detail { // last element of the current chunk if (end1 != 0) { - auto end_value1 = HPX_INVOKE(proj1, first1[end1 - 1]); + element_type end_value1 = + HPX_INVOKE(proj1, first1[end1 - 1]); while (!HPX_INVOKE(f, end_value1, end_value) && --end1 != 0) { @@ -180,7 +211,8 @@ namespace hpx::parallel::detail { // first element of the current chunk if (start1 != 0) { - auto start_value1 = HPX_INVOKE(proj1, first1[start1 - 1]); + element_type start_value1 = + HPX_INVOKE(proj1, first1[start1 - 1]); while ( !HPX_INVOKE(f, start_value1, start_value) && --start1 != 0) diff --git a/libs/core/algorithms/include/hpx/parallel/algorithms/set_intersection.hpp b/libs/core/algorithms/include/hpx/parallel/algorithms/set_intersection.hpp index 064cceea16f8..e5d234066859 100644 --- a/libs/core/algorithms/include/hpx/parallel/algorithms/set_intersection.hpp +++ b/libs/core/algorithms/include/hpx/parallel/algorithms/set_intersection.hpp @@ -286,8 +286,8 @@ namespace hpx::parallel { // perform required set operation for one chunk auto f2 = [proj1, proj2](Iter1 part_first1, Sent1 part_last1, - Iter2 part_first2, Sent2 part_last2, - buffer_type* d, func_type const& f) { + Iter2 part_first2, Sent2 part_last2, auto* d, + func_type const& f) { return sequential_set_intersection(part_first1, part_last1, part_first2, part_last2, d, f, proj1, proj2); };