Skip to content

Commit

Permalink
Refine sequence checking, improve error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
hkaiser committed Feb 12, 2024
1 parent 47906ea commit b61b6dc
Show file tree
Hide file tree
Showing 20 changed files with 126 additions and 163 deletions.
9 changes: 2 additions & 7 deletions libs/full/collectives/include/hpx/collectives/all_gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,7 @@ namespace hpx::traits {
// Explicit specialization of communicator_data for all_gather_tag.
// NOTE(review): this is a diff rendering; going by the hunk header (12 old
// lines, 7 new lines) the inline name() definition and the id() declaration
// appear to be the removed lines and the exported name() declaration the
// added one - confirm against the actual header.
template <>
struct communicator_data<all_gather_tag>
{
// Inline definition yielding the literal "all_gather".
static constexpr char const* name() noexcept
{
return "all_gather";
}

// Declaration only; the definition is not part of this view.
HPX_EXPORT static operation_id_type id() noexcept;
// Declaration only; the definition is not part of this view.
HPX_EXPORT static char const* name() noexcept;
};
} // namespace communication

Expand All @@ -156,7 +151,7 @@ namespace hpx::traits {
{
return communicator.template handle_data<std::decay_t<T>>(
communication::communicator_data<
communication::all_gather_tag>::id(),
communication::all_gather_tag>::name(),
which, generation,
// step function (invoked for each get)
[&t](auto& data, std::size_t which) {
Expand Down
9 changes: 2 additions & 7 deletions libs/full/collectives/include/hpx/collectives/all_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,7 @@ namespace hpx::traits {
// Explicit specialization of communicator_data for all_reduce_tag.
// NOTE(review): this is a diff rendering; going by the hunk header (12 old
// lines, 7 new lines) the inline name() definition and the id() declaration
// appear to be the removed lines and the exported name() declaration the
// added one - confirm against the actual header.
template <>
struct communicator_data<all_reduce_tag>
{
// Inline definition yielding the literal "all_reduce".
static constexpr char const* name() noexcept
{
return "all_reduce";
}

// Declaration only; the definition is not part of this view.
HPX_EXPORT static operation_id_type id() noexcept;
// Declaration only; the definition is not part of this view.
HPX_EXPORT static char const* name() noexcept;
};
} // namespace communication

Expand All @@ -162,7 +157,7 @@ namespace hpx::traits {
{
return communicator.template handle_data<std::decay_t<T>>(
communication::communicator_data<
communication::all_reduce_tag>::id(),
communication::all_reduce_tag>::name(),
which, generation,
// step function (invoked for each get)
[&t](auto& data, std::size_t which) {
Expand Down
9 changes: 2 additions & 7 deletions libs/full/collectives/include/hpx/collectives/all_to_all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,7 @@ namespace hpx::traits {
// Explicit specialization of communicator_data for all_to_all_tag.
// NOTE(review): this is a diff rendering; going by the hunk header (12 old
// lines, 7 new lines) the inline name() definition and the id() declaration
// appear to be the removed lines and the exported name() declaration the
// added one - confirm against the actual header.
template <>
struct communicator_data<all_to_all_tag>
{
// Inline definition yielding the literal "all_to_all".
static constexpr char const* name() noexcept
{
return "all_to_all";
}

// Declaration only; the definition is not part of this view.
HPX_EXPORT static operation_id_type id() noexcept;
// Declaration only; the definition is not part of this view.
HPX_EXPORT static char const* name() noexcept;
};
} // namespace communication

Expand All @@ -157,7 +152,7 @@ namespace hpx::traits {
{
return communicator.template handle_data<std::vector<T>>(
communication::communicator_data<
communication::all_to_all_tag>::id(),
communication::all_to_all_tag>::name(),
which, generation,
// step function (invoked for each get)
[&t](auto& data, std::size_t which) {
Expand Down
11 changes: 3 additions & 8 deletions libs/full/collectives/include/hpx/collectives/broadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,7 @@ namespace hpx::traits {
// Explicit specialization of communicator_data for broadcast_tag.
// NOTE(review): this is a diff rendering; going by the hunk header (12 old
// lines, 7 new lines) the inline name() definition and the id() declaration
// appear to be the removed lines and the exported name() declaration the
// added one - confirm against the actual header.
template <>
struct communicator_data<broadcast_tag>
{
// Inline definition yielding the literal "broadcast".
static constexpr char const* name() noexcept
{
return "broadcast";
}

// Declaration only; the definition is not part of this view.
HPX_EXPORT static operation_id_type id() noexcept;
// Declaration only; the definition is not part of this view.
HPX_EXPORT static char const* name() noexcept;
};
} // namespace communication

Expand All @@ -239,7 +234,7 @@ namespace hpx::traits {

return communicator.template handle_data<data_type>(
communication::communicator_data<
communication::broadcast_tag>::id(),
communication::broadcast_tag>::name(),
which, generation,
// no step function
nullptr,
Expand All @@ -257,7 +252,7 @@ namespace hpx::traits {
{
return communicator.template handle_data<std::decay_t<T>>(
communication::communicator_data<
communication::broadcast_tag>::id(),
communication::broadcast_tag>::name(),
which, generation,
// step function (invoked once for set)
[&t](auto& data, std::size_t) { data[0] = HPX_FORWARD(T, t); },
Expand Down
118 changes: 70 additions & 48 deletions libs/full/collectives/include/hpx/collectives/detail/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ namespace hpx::traits {

namespace communication {

using operation_id_type = void const*;

// Retrieve name of the current communicator
template <typename Operation>
struct communicator_data
Expand All @@ -45,11 +43,6 @@ namespace hpx::traits {
{
return "<unknown>";
}

static constexpr operation_id_type id() noexcept
{
return nullptr;
}
};
} // namespace communication
} // namespace hpx::traits
Expand All @@ -65,7 +58,15 @@ namespace hpx::collectives::detail {
public:
// Default constructor. Declaration only; the definition is not part of this
// view.
HPX_EXPORT communicator_server() noexcept;

// NOTE(review): in this diff rendering the single-argument constructor
// directly below appears to be the removed line, and the two-argument
// overload that follows it the replacement - confirm against the header.
HPX_EXPORT explicit communicator_server(std::size_t num_sites) noexcept;
// Constructor taking the number of sites and a base name.
// NOTE(review): basename is presumably what ends up in basename_ and is
// printed in the sequencing error messages - verify in the definition.
HPX_EXPORT explicit communicator_server(
std::size_t num_sites, char const* basename) noexcept;

// Instances can be neither copied nor moved.
communicator_server(communicator_server const&) = delete;
communicator_server(communicator_server&&) = delete;
communicator_server& operator=(communicator_server const&) = delete;
communicator_server& operator=(communicator_server&&) = delete;

// Destructor. Declaration only; the definition is not part of this view.
HPX_EXPORT ~communicator_server();

private:
template <typename Operation>
Expand Down Expand Up @@ -245,18 +246,39 @@ namespace hpx::collectives::detail {
return fut;
}

// Record `operation` as the collective currently running on this
// communicator, provided no operation is recorded yet.
//
// Returns true if `operation` was stored as the current operation, and false
// if an operation was already recorded (nothing is modified in that case).
//
// If no operation is recorded but on_ready_count_ is non-zero, the lock `l`
// is released and an hpx::error::invalid_status error is raised through
// HPX_THROW_EXCEPTION. `which` and `generation` are used only to compose
// that error message.
template <typename Lock>
bool set_operation_and_check_sequencing(Lock& l, char const* operation,
    std::size_t which, std::size_t generation)
{
    // An operation is already recorded: leave it untouched and report that
    // nothing was set.
    if (current_operation_ != nullptr)
    {
        return false;
    }

    // Nothing is recorded yet, so no on_ready callback may have been counted
    // at this point.
    if (on_ready_count_ != 0)
    {
        l.unlock();
        HPX_THROW_EXCEPTION(hpx::error::invalid_status,
            "communicator::handle_data",
            "communicator: {}: sequencing error, on_ready callback "
            "was already invoked before the start of the "
            "collective operation {}, which {}, generation {}.",
            basename_, operation, which, generation);
    }

    current_operation_ = operation;
    return true;
}

// Step will be invoked under lock for each site that checks in (either
// set or get).
//
// Finalizer will be invoked under lock after all sites have checked in.
template <typename Data, typename Step, typename Finalizer>
auto handle_data(
hpx::traits::communication::operation_id_type operation,
std::size_t which, std::size_t generation,
[[maybe_unused]] Step&& step, Finalizer&& finalizer,
auto handle_data(char const* operation, std::size_t which,
std::size_t generation, [[maybe_unused]] Step&& step,
Finalizer&& finalizer,
std::size_t num_values = static_cast<std::size_t>(-1))
{
auto on_ready = [this, operation, which, num_values,
auto on_ready = [this, operation, which, generation, num_values,
finalizer = HPX_FORWARD(Finalizer, finalizer)](
shared_future<void>&& f) mutable {
// This callback will be invoked once for each participating
Expand All @@ -278,10 +300,12 @@ namespace hpx::collectives::detail {
l.unlock();
HPX_THROW_EXCEPTION(hpx::error::invalid_status,
"communicator::handle_data::on_ready",
"sequencing error, operation type mismatch: invoked "
"for {}, ongoing operation {}",
operation,
current_operation_ ? current_operation_ : "unknown");
"communicator {}: sequencing error, operation type "
"mismatch: invoked for {}, ongoing operation {}, which "
"{}, generation {}.",
basename_, operation,
current_operation_ ? current_operation_ : "unknown",
which, generation);
}

// Verify that the number of invocations of this callback is in
Expand All @@ -291,11 +315,12 @@ namespace hpx::collectives::detail {
l.unlock();
HPX_THROW_EXCEPTION(hpx::error::invalid_status,
"communicator::handle_data::on_ready",
"sequencing error, an excessive number of on_ready "
"callbacks have been invoked before the end of the "
"collective {} operation. Expected count {}, received "
"count {}.",
operation, on_ready_count_, num_sites_);
"communictor {}: sequencing error, an excessive "
"number of on_ready callbacks have been invoked before "
"the end of the collective operation {}, which {}, "
"generation {}. Expected count {}, received count {}.",
basename_, operation, which, generation,
on_ready_count_, num_sites_);
}

// On exit, keep track of number of invocations of this
Expand All @@ -314,6 +339,7 @@ namespace hpx::collectives::detail {
{
HPX_UNUSED(this);
HPX_UNUSED(which);
HPX_UNUSED(generation);
HPX_UNUSED(num_values);
HPX_UNUSED(finalizer);
}
Expand All @@ -324,33 +350,27 @@ namespace hpx::collectives::detail {

// Verify that there is no overlap between different types of
// operations on the same communicator.
if (current_operation_ == nullptr)
{
if (on_ready_count_ != 0)
{
l.unlock();
HPX_THROW_EXCEPTION(hpx::error::invalid_status,
"communicator::handle_data",
"sequencing error, on_ready callback was already "
"invoked before the start of the collective {} "
"operation",
operation);
}
current_operation_ = operation;
}
else if (current_operation_ != operation)
set_operation_and_check_sequencing(l, operation, which, generation);

auto f = get_future_and_synchronize(
generation, num_values, HPX_MOVE(on_ready), l);

// We may have just finished a different operation, thus we have to
// possibly reset the operation type stored in this communicator.
if (current_operation_ != operation &&
!set_operation_and_check_sequencing(
l, operation, which, generation))
{
l.unlock();
HPX_THROW_EXCEPTION(hpx::error::invalid_status,
"communicator::handle_data",
"sequencing error, operation type mismatch: invoked for "
"{}, ongoing operation {}",
operation, current_operation_);
"communicator {}: sequencing error, operation type "
"mismatch: invoked for {}, ongoing operation {}, which {}, "
"generation {}.",
basename_, operation, current_operation_, which,
generation);
}

auto f = get_future_and_synchronize(
generation, num_values, HPX_MOVE(on_ready), l);

if constexpr (!std::is_same_v<std::nullptr_t, std::decay_t<Step>>)
{
// call provided step function for each invocation site
Expand All @@ -360,7 +380,7 @@ namespace hpx::collectives::detail {
// Make sure next generation is enabled only after previous
// generation has finished executing.
gate_.set(which, l,
[this, operation, generation](
[this, operation, which, generation](
auto& l, auto& gate, error_code& ec) {
// This callback is invoked synchronously once for each
// collective operation after all data has been received and
Expand All @@ -377,8 +397,10 @@ namespace hpx::collectives::detail {
"communicator::handle_data",
"sequencing error, not all on_ready callbacks have "
"been invoked at the end of the collective {} "
"operation. Expected count {}, received count {}.",
operation, on_ready_count_, num_sites_);
"operation. Expected count {}, received count {}, "
"which {}, generation {}.",
*operation, on_ready_count_, num_sites_, which,
generation);
return;
}

Expand Down Expand Up @@ -416,10 +438,10 @@ namespace hpx::collectives::detail {
hpx::lcos::local::and_gate gate_;
std::size_t const num_sites_;
std::size_t on_ready_count_ = 0;
hpx::traits::communication::operation_id_type current_operation_ =
nullptr;
char const* current_operation_ = nullptr;
bool needs_initialization_ = true;
bool data_available_ = false;
char const* basename_ = nullptr;
};
} // namespace hpx::collectives::detail

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,7 @@ namespace hpx::traits {
// Explicit specialization of communicator_data for exclusive_scan_tag.
// NOTE(review): this is a diff rendering; going by the hunk header (12 old
// lines, 7 new lines) the inline name() definition and the id() declaration
// appear to be the removed lines and the exported name() declaration the
// added one - confirm against the actual header.
template <>
struct communicator_data<exclusive_scan_tag>
{
// Inline definition yielding the literal "exclusive_scan".
static constexpr char const* name() noexcept
{
return "exclusive_scan";
}

// Declaration only; the definition is not part of this view.
HPX_EXPORT static operation_id_type id() noexcept;
// Declaration only; the definition is not part of this view.
HPX_EXPORT static char const* name() noexcept;
};
} // namespace communication

Expand All @@ -175,7 +170,7 @@ namespace hpx::traits {
{
return communicator.template handle_data<std::decay_t<T>>(
communication::communicator_data<
communication::exclusive_scan_tag>::id(),
communication::exclusive_scan_tag>::name(),
which, generation,
// step function (invoked for each get)
[&t](auto& data, std::size_t which) {
Expand Down
11 changes: 3 additions & 8 deletions libs/full/collectives/include/hpx/collectives/gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,7 @@ namespace hpx::traits {
// Explicit specialization of communicator_data for gather_tag.
// NOTE(review): this is a diff rendering; going by the hunk header (12 old
// lines, 7 new lines) the inline name() definition and the id() declaration
// appear to be the removed lines and the exported name() declaration the
// added one - confirm against the actual header.
template <>
struct communicator_data<gather_tag>
{
// Inline definition yielding the literal "gather".
static constexpr char const* name() noexcept
{
return "gather";
}

// Declaration only; the definition is not part of this view.
HPX_EXPORT static operation_id_type id() noexcept;
// Declaration only; the definition is not part of this view.
HPX_EXPORT static char const* name() noexcept;
};
} // namespace communication

Expand All @@ -256,7 +251,7 @@ namespace hpx::traits {
{
return communicator.template handle_data<std::decay_t<T>>(
communication::communicator_data<
communication::gather_tag>::id(),
communication::gather_tag>::name(),
which, generation,
// step function (invoked once for get)
[&t](auto& data, std::size_t which) {
Expand All @@ -272,7 +267,7 @@ namespace hpx::traits {
{
return communicator.template handle_data<std::decay_t<T>>(
communication::communicator_data<
communication::gather_tag>::id(),
communication::gather_tag>::name(),
which, generation,
// step function (invoked for each set)
[&t](auto& data, std::size_t which) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,7 @@ namespace hpx::traits {
template <>
struct communicator_data<inclusive_scan_tag>
{
static constexpr char const* name() noexcept
{
return "inclusive_scan";
}

HPX_EXPORT static operation_id_type id() noexcept;
HPX_EXPORT static char const* name() noexcept;
};
} // namespace communication

Expand All @@ -163,7 +158,7 @@ namespace hpx::traits {
{
return communicator.template handle_data<std::decay_t<T>>(
communication::communicator_data<
communication::inclusive_scan_tag>::id(),
communication::inclusive_scan_tag>::name(),
which, generation,
// step function (invoked for each get)
[&t](auto& data, std::size_t which) {
Expand Down
Loading

0 comments on commit b61b6dc

Please sign in to comment.