Skip to content

Commit

Permalink
Making set_intersection work with zip-like iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
hkaiser committed Sep 18, 2023
1 parent c8ea2c7 commit a04babe
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@
namespace hpx::parallel::detail {
/// \cond NOINTERNAL

template <typename T>
struct decay_tuple
{
using type = T;
};

template <typename T>
struct decay_tuple<std::tuple<T>>
{
using type = std::tuple<std::remove_const_t<T>>;
};

template <typename T>
struct decay_tuple<std::tuple<T&>>
{
using type = std::tuple<std::remove_const_t<T>>;
};

///////////////////////////////////////////////////////////////////////////
template <typename FwdIter>
struct set_operations_buffer
Expand All @@ -43,17 +61,24 @@ namespace hpx::parallel::detail {
public:
rewritable_ref() = default;

explicit constexpr rewritable_ref(T const& item) noexcept
explicit constexpr rewritable_ref(T& item) noexcept
: item_(&item)
{
}

rewritable_ref& operator=(T const& item)
rewritable_ref& operator=(T& item)
{
item_ = &item;
return *this;
}

template <typename U>
rewritable_ref& operator=(U const& item)
{
*item_ = item;
return *this;
}

// different versions of clang-format produce different results
// clang-format off
operator T const&() const
Expand All @@ -64,7 +89,7 @@ namespace hpx::parallel::detail {
// clang-format on

private:
T const* item_ = nullptr;
T* item_ = nullptr;
};

using value_type = typename std::iterator_traits<FwdIter>::value_type;
Expand Down Expand Up @@ -149,8 +174,13 @@ namespace hpx::parallel::detail {
bool const first_partition = start1 == 0;
bool const last_partition = end1 == static_cast<std::size_t>(len1);

auto start_value = HPX_INVOKE(proj1, first1[start1]);
auto end_value = HPX_INVOKE(proj1, first1[end1]);
using result_type =
std::invoke_result_t<Proj1, hpx::traits::iter_value_t<Iter1>>;
using element_type =
typename decay_tuple<std::decay_t<result_type>>::type;

element_type start_value = HPX_INVOKE(proj1, first1[start1]);
element_type end_value = HPX_INVOKE(proj1, first1[end1]);

// all but the last chunk require special handling
if (!last_partition)
Expand All @@ -166,7 +196,8 @@ namespace hpx::parallel::detail {
// last element of the current chunk
if (end1 != 0)
{
auto end_value1 = HPX_INVOKE(proj1, first1[end1 - 1]);
element_type end_value1 =
HPX_INVOKE(proj1, first1[end1 - 1]);

while (!HPX_INVOKE(f, end_value1, end_value) && --end1 != 0)
{
Expand All @@ -180,7 +211,8 @@ namespace hpx::parallel::detail {
// first element of the current chunk
if (start1 != 0)
{
auto start_value1 = HPX_INVOKE(proj1, first1[start1 - 1]);
element_type start_value1 =
HPX_INVOKE(proj1, first1[start1 - 1]);

while (
!HPX_INVOKE(f, start_value1, start_value) && --start1 != 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ namespace hpx::parallel {
HPX_MOVE(first1), HPX_MOVE(first2), HPX_MOVE(dest)});
}

using buffer_type = typename set_operations_buffer<Iter3>::type;
using func_type = std::decay_t<F>;

// calculate approximate destination index
Expand All @@ -286,8 +285,8 @@ namespace hpx::parallel {

// perform required set operation for one chunk
auto f2 = [proj1, proj2](Iter1 part_first1, Sent1 part_last1,
Iter2 part_first2, Sent2 part_last2,
buffer_type* d, func_type const& f) {
Iter2 part_first2, Sent2 part_last2, auto* d,
func_type const& f) {
return sequential_set_intersection(part_first1, part_last1,
part_first2, part_last2, d, f, proj1, proj2);
};
Expand Down

0 comments on commit a04babe

Please sign in to comment.