Skip to content

Commit

Permalink
Fixing race in collective operations
Browse files Browse the repository at this point in the history
- adding test
  • Loading branch information
hkaiser committed Feb 12, 2024
1 parent b61b6dc commit 5f0a8ee
Show file tree
Hide file tree
Showing 5 changed files with 849 additions and 25 deletions.
2 changes: 1 addition & 1 deletion libs/core/futures/src/future_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ namespace hpx::lcos::detail {
#endif

bool const is_hpx_thread = nullptr != hpx::threads::get_self_ptr();
if (!is_hpx_thread || !recurse_asynchronously)
if (is_hpx_thread && !recurse_asynchronously)
{
// directly execute continuation on this thread
run_on_completed(HPX_FORWARD(Callback, on_completed));
Expand Down
2 changes: 1 addition & 1 deletion libs/core/lcos_local/include/hpx/lcos_local/and_gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ namespace hpx::lcos::local {
// Note: This type is not thread-safe. It has to be protected from
// concurrent access by different threads by the code using instances
// of this type.
struct and_gate : public base_and_gate<hpx::no_mutex>
struct and_gate : base_and_gate<hpx::no_mutex>
{
private:
using base_type = base_and_gate<hpx::no_mutex>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ namespace hpx::collectives::detail {
};

private:
std::size_t get_num_sites(std::size_t num_values) const noexcept
[[nodiscard]] constexpr std::size_t get_num_sites(
std::size_t num_values) const noexcept
{
return num_values == static_cast<std::size_t>(-1) ? num_sites_ :
num_values;
Expand Down Expand Up @@ -231,19 +232,22 @@ namespace hpx::collectives::detail {
std::size_t generation, std::size_t capacity, F&& f, Lock& l)
{
HPX_ASSERT_OWNS_LOCK(l);
auto sf = gate_.get_shared_future(l);

traits::detail::get_shared_state(sf)->reserve_callbacks(
get_num_sites(capacity));

auto fut = sf.then(hpx::launch::sync, HPX_FORWARD(F, f));

// Wait for the requested generation to be processed.
gate_.synchronize(generation == static_cast<std::size_t>(-1) ?
gate_.generation(l) :
generation,
l);

return fut;
// Get future from gate only after synchronization as otherwise we
// may get a future returned that does not belong to the requested
// generation.
auto sf = gate_.get_shared_future(l);

traits::detail::get_shared_state(sf)->reserve_callbacks(
get_num_sites(capacity));

return sf.then(hpx::launch::sync, HPX_FORWARD(F, f));
}

template <typename Lock>
Expand All @@ -262,9 +266,16 @@ namespace hpx::collectives::detail {
"collective operation {}, which {}, generation {}.",
basename_, operation, which, generation);
}
current_operation_ = operation;

if (generation == static_cast<std::size_t>(-1) ||
generation == gate_.generation(l))
{
current_operation_ = operation;
}

return true;
}

return false;
}

Expand All @@ -284,6 +295,11 @@ namespace hpx::collectives::detail {
// This callback will be invoked once for each participating
// site after all sites have checked in.

// On exit, keep track of number of invocations of this
// callback.
auto on_exit = hpx::experimental::scope_exit(
[this] { ++on_ready_count_; });

f.get(); // propagate any exceptions

// It does not matter whether the lock will be acquired here. It
Expand Down Expand Up @@ -315,19 +331,14 @@ namespace hpx::collectives::detail {
l.unlock();
HPX_THROW_EXCEPTION(hpx::error::invalid_status,
"communicator::handle_data::on_ready",
"communictor {}: sequencing error, an excessive "
"communicator {}: sequencing error, an excessive "
"number of on_ready callbacks have been invoked before "
"the end of the collective operation {}, which {}, "
"generation {}. Expected count {}, received count {}.",
basename_, operation, which, generation,
on_ready_count_, num_sites_);
}

// On exit, keep track of number of invocations of this
// callback.
auto on_exit = hpx::experimental::scope_exit(
[this] { ++on_ready_count_; });

if constexpr (!std::is_same_v<std::nullptr_t,
std::decay_t<Finalizer>>)
{
Expand All @@ -338,8 +349,6 @@ namespace hpx::collectives::detail {
else
{
HPX_UNUSED(this);
HPX_UNUSED(which);
HPX_UNUSED(generation);
HPX_UNUSED(num_values);
HPX_UNUSED(finalizer);
}
Expand Down Expand Up @@ -373,7 +382,7 @@ namespace hpx::collectives::detail {

if constexpr (!std::is_same_v<std::nullptr_t, std::decay_t<Step>>)
{
// call provided step function for each invocation site
// Call provided step function for each invocation site.
HPX_FORWARD(Step, step)(access_data<Data>(num_values), which);
}

Expand All @@ -399,7 +408,7 @@ namespace hpx::collectives::detail {
"been invoked at the end of the collective {} "
"operation. Expected count {}, received count {}, "
"which {}, generation {}.",
*operation, on_ready_count_, num_sites_, which,
operation, on_ready_count_, num_sites_, which,
generation);
return;
}
Expand All @@ -416,7 +425,7 @@ namespace hpx::collectives::detail {
return f;
}

// protect against vector<bool> idiosyncrasies
// Protect against vector<bool> idiosyncrasies.
template <typename ValueType, typename Data>
static constexpr decltype(auto) handle_bool(Data&& data) noexcept
{
Expand All @@ -433,15 +442,15 @@ namespace hpx::collectives::detail {
template <typename Communicator, typename Operation>
friend struct hpx::traits::communication_operation;

mutex_type mtx_;
hpx::unique_any_nonser data_;
hpx::lcos::local::and_gate gate_;
std::size_t const num_sites_;
std::size_t on_ready_count_ = 0;
char const* current_operation_ = nullptr;
char const* basename_ = nullptr;
mutex_type mtx_;
bool needs_initialization_ = true;
bool data_available_ = false;
char const* basename_ = nullptr;
};
} // namespace hpx::collectives::detail

Expand Down
3 changes: 2 additions & 1 deletion libs/full/collectives/tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023 Hartmut Kaiser
# Copyright (c) 2019-2024 Hartmut Kaiser
#
# SPDX-License-Identifier: BSL-1.0
# Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand All @@ -23,6 +23,7 @@ if(HPX_WITH_NETWORKING)
set(tests
${tests}
broadcast_direct
concurrent_collectives
exclusive_scan_
gather
inclusive_scan_
Expand Down
Loading

0 comments on commit 5f0a8ee

Please sign in to comment.