Skip to content

Commit

Permalink
Merge pull request #6491 from STEllAR-GROUP/vector_bool_reduce
Browse files Browse the repository at this point in the history
More fixes to handling bool arguments for collective operations
  • Loading branch information
hkaiser authored May 16, 2024
2 parents 614d688 + ed7aa60 commit a945257
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 28 deletions.
39 changes: 28 additions & 11 deletions libs/full/collectives/include/hpx/collectives/all_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,36 @@ namespace hpx::traits {
[op = HPX_FORWARD(F, op)](
auto& data, bool& data_available, std::size_t) mutable {
HPX_ASSERT(!data.empty());
if (!data_available && data.size() > 1)

if constexpr (!std::is_same_v<std::decay_t<T>, bool>)
{
if (!data_available && data.size() > 1)
{
// compute reduction result only once
auto it = data.begin();
data[0] = hpx::reduce(
++it, data.end(), data[0], HPX_FORWARD(F, op));
data_available = true;
}
return data[0];
}
else
{
// compute reduction result only once
auto it = data.begin();
data[0] = Communicator::template handle_bool<
std::decay_t<T>>(hpx::reduce(++it, data.end(),
Communicator::template handle_bool<std::decay_t<T>>(
data[0]),
HPX_FORWARD(F, op)));
data_available = true;
if (!data_available && data.size() > 1)
{
// compute reduction result only once
auto it = data.begin();
data[0] = hpx::reduce(++it, data.end(),
static_cast<bool>(data[0]),
[&](auto lhs, auto rhs) {
return HPX_FORWARD(F, op)(
static_cast<bool>(lhs),
static_cast<bool>(rhs));
});
data_available = true;
}
return static_cast<bool>(data[0]);
}
return Communicator::template handle_bool<std::decay_t<T>>(
data[0]);
});
}
};
Expand Down
20 changes: 16 additions & 4 deletions libs/full/collectives/include/hpx/collectives/exclusive_scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,22 @@ namespace hpx::traits {

// first value is not taken into account
auto it = data.begin();
hpx::exclusive_scan(it, data.end(), dest.begin(),
Communicator::template handle_bool<std::decay_t<T>>(
*it),
HPX_FORWARD(F, op));

if constexpr (!std::is_same_v<std::decay_t<T>, bool>)
{
hpx::exclusive_scan(it, data.end(), dest.begin(),
*it, HPX_FORWARD(F, op));
}
else
{
hpx::exclusive_scan(it, data.end(), dest.begin(),
static_cast<bool>(*it),
[&](auto lhs, auto rhs) {
return HPX_FORWARD(F, op)(
static_cast<bool>(lhs),
static_cast<bool>(rhs));
});
}

std::swap(data, dest);
data_available = true;
Expand Down
16 changes: 14 additions & 2 deletions libs/full/collectives/include/hpx/collectives/inclusive_scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,20 @@ namespace hpx::traits {
std::vector<std::decay_t<T>> dest;
dest.resize(data.size());

hpx::inclusive_scan(data.begin(), data.end(),
dest.begin(), HPX_FORWARD(F, op));
if constexpr (!std::is_same_v<std::decay_t<T>, bool>)
{
hpx::inclusive_scan(data.begin(), data.end(),
dest.begin(), HPX_FORWARD(F, op));
}
else
{
hpx::inclusive_scan(data.begin(), data.end(),
dest.begin(), [&](auto lhs, auto rhs) {
return HPX_FORWARD(F, op)(
static_cast<bool>(lhs),
static_cast<bool>(rhs));
});
}

std::swap(data, dest);
data_available = true;
Expand Down
32 changes: 23 additions & 9 deletions libs/full/collectives/include/hpx/collectives/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,31 @@ namespace hpx::traits {
auto& data, bool&, std::size_t) mutable {
HPX_ASSERT(!data.empty());

if (data.size() > 1)
if constexpr (!std::is_same_v<std::decay_t<T>, bool>)
{
auto it = data.begin();
return Communicator::template handle_bool<
std::decay_t<T>>(hpx::reduce(++it, data.end(),
Communicator::template handle_bool<std::decay_t<T>>(
HPX_MOVE(data[0])),
HPX_FORWARD(F, op)));
if (data.size() > 1)
{
auto it = data.begin();
return hpx::reduce(++it, data.end(),
HPX_MOVE(data[0]), HPX_FORWARD(F, op));
}
return HPX_MOVE(data[0]);
}
else
{
if (data.size() > 1)
{
auto it = data.begin();
return static_cast<bool>(hpx::reduce(++it,
data.end(), static_cast<bool>(data[0]),
[&](auto lhs, auto rhs) {
return HPX_FORWARD(F, op)(
static_cast<bool>(lhs),
static_cast<bool>(rhs));
}));
}
return static_cast<bool>(data[0]);
}
return Communicator::template handle_bool<std::decay_t<T>>(
HPX_MOVE(data[0]));
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ constexpr int ITERATIONS = 100;
constexpr int ITERATIONS = 1000;
#endif

struct plus_bool
{
template <typename T>
decltype(auto) operator()(T lhs, T rhs) const
{
return lhs + rhs;
}
};

void test_multiple_use_with_generation()
{
std::uint32_t const this_locality = hpx::get_locality_id();
Expand All @@ -41,7 +50,7 @@ void test_multiple_use_with_generation()
{
bool value = ((this_locality + i) % 2) ? true : false;
hpx::future<bool> overall_result = all_reduce(all_reduce_direct_client,
value, std::plus<>{}, generation_arg(i + 1));
value, plus_bool{}, generation_arg(i + 1));

bool sum = false;
for (std::uint32_t j = 0; j != num_localities; ++j)
Expand Down
11 changes: 10 additions & 1 deletion libs/full/collectives/tests/regressions/reduce_vector_bool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ constexpr int ITERATIONS = 100;
constexpr int ITERATIONS = 1000;
#endif

struct plus_bool
{
template <typename T>
decltype(auto) operator()(T lhs, T rhs) const
{
return lhs + rhs;
}
};

void test_multiple_use_with_generation()
{
std::uint32_t const this_locality = hpx::get_locality_id();
Expand All @@ -43,7 +52,7 @@ void test_multiple_use_with_generation()
if (this_locality == 0)
{
hpx::future<bool> overall_result = reduce_here(reduce_direct_client,
std::move(value), std::plus<>{}, generation_arg(i + 1));
std::move(value), plus_bool{}, generation_arg(i + 1));

bool sum = false;
for (std::uint32_t j = 0; j != num_localities; ++j)
Expand Down

0 comments on commit a945257

Please sign in to comment.