Skip to content

Commit

Permalink
[xla:cpu] Migrate CollectivePermute to unified collectives API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711922306
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 6, 2025
1 parent f894287 commit 8f29f77
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 69 deletions.
2 changes: 2 additions & 0 deletions xla/backends/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/core/collectives:rank_id",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
Expand Down
25 changes: 18 additions & 7 deletions xla/backends/cpu/runtime/collective_permute_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,22 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/computation_placer.h"
Expand Down Expand Up @@ -83,12 +87,12 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) {
: logical_id.replica_id;

// Find replicas that we will communicate with.
std::optional<int32_t> source_replica_id;
std::vector<int32_t> copy_to;
std::optional<RankId> source_replica_id;
std::vector<RankId> copy_to;

for (auto& [from, to] : source_target_pairs_) {
if (from == logical_device_id) {
copy_to.push_back(to);
copy_to.push_back(RankId(to));
}
if (to == logical_device_id) {
TF_RET_CHECK(!source_replica_id.has_value())
Expand All @@ -98,6 +102,10 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) {
}
}

auto rank_fmt = [](std::string* out, RankId rank) {
absl::StrAppend(out, rank.value());
};

VLOG(3) << absl::StreamFormat(
"CollectivePermute: #source_buffers=%d, #destination_buffers=%d, "
"source_target_pairs=[%s], logical_device_id=%d (%s), "
Expand All @@ -106,7 +114,8 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) {
absl::StrJoin(source_target_pairs_, ", ", absl::PairFormatter("->")),
logical_device_id,
op_params().has_channel_id ? "computation id" : "replica id",
source_replica_id.value_or(-1), absl::StrJoin(copy_to, ","));
source_replica_id.value_or(RankId(-1)).value(),
absl::StrJoin(copy_to, ",", rank_fmt));

for (int i = 0; i < data.source.size(); ++i) {
VLOG(3) << absl::StreamFormat(
Expand All @@ -123,12 +132,14 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) {
return ExecuteWithCommunicator(
params.collective_params,
[&](const RendezvousKey& key, CollectivesCommunicator& comm) {
CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());

for (int32_t i = 0; i < data.source.size(); ++i) {
const Shape& shape = source_shape(i);
TF_RETURN_IF_ERROR(comm.CollectivePermute(
key, ShapeUtil::ByteSizeOf(shape), source_replica_id, copy_to,
data.source[i].opaque(), data.destination[i].opaque(),
DefaultCollectiveTimeout()));
data.source[i], data.destination[i], shape.element_type(),
ShapeUtil::ElementsIn(shape), source_replica_id, copy_to,
executor));
}
return absl::OkStatus();
});
Expand Down
2 changes: 2 additions & 0 deletions xla/pjrt/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ cc_library(
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
Expand All @@ -308,6 +309,7 @@ cc_library(
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down
34 changes: 20 additions & 14 deletions xla/pjrt/cpu/gloo_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand All @@ -48,6 +49,7 @@ limitations under the License.
#include "gloo/transport/unbound_buffer.h"
#include "gloo/types.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/primitive_util.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/cpu/collectives_interface.h"
Expand Down Expand Up @@ -193,37 +195,41 @@ absl::Status GlooCollectivesCommunicator::AllReduce(
static constexpr uint8_t kCollectivePermuteSlotPrefix = 0x40;

absl::Status GlooCollectivesCommunicator::CollectivePermute(
const RendezvousKey& key, size_t num_bytes, std::optional<int> source_rank,
absl::Span<int const> target_ranks, const void* input_buffer,
void* output_buffer, absl::Duration timeout) {
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, std::optional<RankId> source_rank,
absl::Span<const RankId> target_ranks, const Executor& executor) {
uint32_t tag = 0; // TODO(phawkins): come up with better tags.
const auto slot = gloo::Slot::build(kCollectivePermuteSlotPrefix, tag);

TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
size_t num_bytes = count * primitive_util::ByteWidth(dtype);

try {
std::unique_ptr<gloo::transport::UnboundBuffer> in;
std::unique_ptr<gloo::transport::UnboundBuffer> out;
for (int target : target_ranks) {
for (RankId target : target_ranks) {
if (target != context_->rank) {
VLOG(1) << "send from " << context_->rank << " to " << target;
VLOG(1) << "send from " << context_->rank << " to " << target.value();
if (!in) {
in = context_->createUnboundBuffer(const_cast<void*>(input_buffer),
num_bytes);
in = context_->createUnboundBuffer(send_buffer.opaque(), num_bytes);
}
in->send(target, slot);
in->send(target.value(), slot);
}
}
if (source_rank) {
if (*source_rank == context_->rank) {
std::memcpy(output_buffer, input_buffer, num_bytes);
std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes);
} else {
VLOG(1) << "recv at " << context_->rank << " from " << *source_rank;
out = context_->createUnboundBuffer(output_buffer, num_bytes);
out->recv(*source_rank, slot);
VLOG(1) << "recv at " << context_->rank << " from "
<< source_rank->value();
out = context_->createUnboundBuffer(recv_buffer.opaque(), num_bytes);
out->recv(source_rank->value(), slot);
}
} else {
std::memset(output_buffer, 0, num_bytes);
std::memset(recv_buffer.opaque(), 0, num_bytes);
}
VLOG(1) << "wait for send at " << context_->rank;
auto deadline = absl::ToChronoTime(absl::Now() + timeout);
auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout());
if (in) {
in->waitSend(deadline);
}
Expand Down
11 changes: 6 additions & 5 deletions xla/pjrt/cpu/gloo_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ class GlooCollectivesCommunicator : public CollectivesCommunicator {
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, ReductionKind reduction_kind,
const Executor& executor) override;
absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes,
std::optional<int> source_rank,
absl::Span<int const> target_ranks,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
std::optional<RankId> source_rank,
absl::Span<const RankId> target_ranks,
const Executor& executor) override;
absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes,
absl::Span<const void* const> input_buffers,
absl::Span<void* const> output_buffers,
Expand Down
28 changes: 15 additions & 13 deletions xla/pjrt/cpu/mpi_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,36 +146,38 @@ absl::Status MpiCollectivesCommunicator::AllReduce(
}

absl::Status MpiCollectivesCommunicator::CollectivePermute(
const RendezvousKey& key, size_t num_bytes, std::optional<int> source_rank,
absl::Span<int const> target_ranks, const void* input_buffer,
void* output_buffer, absl::Duration timeout) {
se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, std::optional<RankId> source_rank,
absl::Span<const RankId> target_ranks, const Executor& executor) {
int tag = 0; // TODO come up with better tags.

const int rank = mpi_rank_;

std::vector<MPI_Request> requests;

size_t num_bytes = count * primitive_util::ByteWidth(dtype);

if (source_rank) {
if (*source_rank == rank) {
std::memcpy(output_buffer, input_buffer, num_bytes);
if (source_rank->value() == rank) {
std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes);
} else {
VLOG(1) << "recv at " << rank << " from " << *source_rank;
VLOG(1) << "recv at " << rank << " from " << source_rank->value();
requests.emplace_back();
TF_RETURN_IF_ERROR(MpiErrorToAbslStatus(
MPI_Irecv(output_buffer, num_bytes, MPI_BYTE, *source_rank, tag,
comm_, &requests.back())));
MPI_Irecv(recv_buffer.opaque(), num_bytes, MPI_BYTE,
source_rank->value(), tag, comm_, &requests.back())));
}
} else {
std::memset(output_buffer, 0, num_bytes);
std::memset(recv_buffer.opaque(), 0, num_bytes);
}

for (int target : target_ranks) {
for (RankId target : target_ranks) {
if (target != rank) {
VLOG(1) << "send from " << rank << " to " << target;
VLOG(1) << "send from " << rank << " to " << target.value();
requests.emplace_back();
TF_RETURN_IF_ERROR(MpiErrorToAbslStatus(
MPI_Isend(input_buffer, num_bytes, MPI_BYTE, target, tag, comm_,
&requests.back())));
MPI_Isend(send_buffer.opaque(), num_bytes, MPI_BYTE, target.value(),
tag, comm_, &requests.back())));
}
}

Expand Down
11 changes: 6 additions & 5 deletions xla/pjrt/cpu/mpi_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ class MpiCollectivesCommunicator : public CollectivesCommunicator {
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, ReductionKind reduction_kind,
const Executor& executor) override;
absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes,
std::optional<int> source_rank,
absl::Span<int const> target_ranks,
const void* input_buffer, void* output_buffer,
absl::Duration timeout) override;
absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
std::optional<RankId> source_rank,
absl::Span<const RankId> target_ranks,
const Executor& executor) override;
absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes,
absl::Span<const void* const> input_buffers,
absl::Span<void* const> output_buffers,
Expand Down
3 changes: 3 additions & 0 deletions xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,7 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/core/collectives:rank_id",
"//xla/hlo/parser:hlo_parser",
"//xla/service:collective_ops_utils",
"//xla/service:computation_placer",
Expand Down Expand Up @@ -1985,6 +1986,7 @@ cc_library(
deps = [
"//xla:xla_data_proto_cc",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/stream_executor:device_memory",
Expand All @@ -2007,6 +2009,7 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/stream_executor:device_memory",
Expand Down
14 changes: 7 additions & 7 deletions xla/service/cpu/collectives_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/global_device_id.h"
#include "xla/stream_executor/device_memory.h"
Expand Down Expand Up @@ -52,13 +53,12 @@ class CollectivesCommunicator {
// source_rank: the rank from which this rank should receive its data.
// Optional; if absent, then the output is filled with zeros.
// target_rank: the ranks to which this rank should send its data.
virtual absl::Status CollectivePermute(const RendezvousKey& key,
size_t num_bytes,
std::optional<int> source_rank,
absl::Span<int const> target_ranks,
const void* input_buffer,
void* output_buffer,
absl::Duration timeout) = 0;
virtual absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
std::optional<RankId> source_rank,
absl::Span<const RankId> target_ranks,
const Executor& executor) = 0;

// Performs an all-to-all.
// The all-to-all chunks are passed separately and do not have to be
Expand Down
19 changes: 13 additions & 6 deletions xla/service/cpu/cpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ limitations under the License.
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/layout_util.h"
Expand Down Expand Up @@ -537,19 +538,19 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options,
int32_t logical_device_id =
channel_id_present ? logical_id.computation_id : logical_id.replica_id;

std::optional<int> source_replica_id;
std::vector<int> copy_to;
std::optional<RankId> source_replica_id;
std::vector<RankId> copy_to;
for (auto& p : pairs) {
std::vector<std::string> mapping = absl::StrSplit(p, '=');
CHECK_EQ(mapping.size(), 2);
int from = std::stoi(mapping[0]);
int to = std::stoi(mapping[1]);
if (from == logical_device_id) {
copy_to.push_back(to);
copy_to.push_back(RankId(to));
}
if (to == logical_device_id) {
CHECK(!source_replica_id.has_value());
source_replica_id = from;
source_replica_id = RankId(from);
}
}
RendezvousKey rendezvous_key =
Expand All @@ -562,9 +563,15 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options,

auto communicator =
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();

CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout());

se::DeviceMemoryBase input_buffer_data(input_buffer, byte_size);
se::DeviceMemoryBase output_buffer_data(output_buffer, byte_size);

TF_CHECK_OK(communicator->CollectivePermute(
rendezvous_key, byte_size, source_replica_id, copy_to, input_buffer,
output_buffer, DefaultCollectiveTimeout()));
input_buffer_data, output_buffer_data, U8, byte_size, source_replica_id,
copy_to, executor));
}
} // namespace
} // namespace runtime
Expand Down
Loading

0 comments on commit 8f29f77

Please sign in to comment.