Skip to content

Commit

Permalink
[XLA:GPU] add initial runtime support and working execution test for NCCL group

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 695810515
  • Loading branch information
Google-ML-Automation committed Nov 12, 2024
1 parent bee8a4f commit f173440
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 64 deletions.
77 changes: 56 additions & 21 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ limitations under the License.
#include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h"
#include "xla/service/gpu/runtime/nccl_collective_permute_thunk.h"
#include "xla/service/gpu/runtime/nccl_collective_thunk.h"
#include "xla/service/gpu/runtime/nccl_group_thunk.h"
#include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h"
#include "xla/service/gpu/runtime/nccl_recv_thunk.h"
#include "xla/service/gpu/runtime/nccl_send_thunk.h"
Expand Down Expand Up @@ -1980,6 +1981,11 @@ static const HloInstruction* FindCanonicalSendRecvStartOp(
inst->opcode() == HloOpcode::kRecv ||
inst->opcode() == HloOpcode::kSendDone ||
inst->opcode() == HloOpcode::kRecvDone);
// If the instruction is wrapped in an async computation, return the
// instruction itself.
if (inst->parent()->IsAsyncComputation()) {
return inst;
}

// Find container while loop and index for the send/recv case or return
// canonical start op directly.
Expand Down Expand Up @@ -2053,9 +2059,36 @@ static const HloInstruction* FindCanonicalSendRecvStartOp(
return canonical_start_op;
}

absl::Status IrEmitterUnnested::EmitNcclGroupDoneThunk(
const HloInstruction* instr) {
return absl::UnimplementedError("EmitNcclGroupDoneThunk not implemented");
absl::Status IrEmitterUnnested::EmitNcclGroupThunk(const HloInstruction* instr,
                                                   Thunk::Kind kind) {
  // Emits one thunk per instruction of the async-wrapped computation and
  // collects them into `scoped_thunk_sequence_` (AddThunkToThunkSequence
  // redirects there while `emit_group_thunks_` is set), then wraps the whole
  // sequence in a single NcclGroupThunk so the runtime dispatches them as one
  // NCCL group.
  emit_group_thunks_ = true;
  absl::Status status = absl::OkStatus();
  // Note: the loop variable is deliberately named differently from the
  // `instr` parameter (the original shadowed it), so the group thunk below
  // is unambiguously built from the async start/done instruction itself.
  for (const HloInstruction* wrapped :
       instr->async_wrapped_computation()->instructions()) {
    if (kind == Thunk::Kind::kNcclGroupStart) {
      status = EmitHloInstruction(wrapped);
    } else {
      // For kNcclGroupDone, we only need to emit the corresponding async done
      // instructions. For now, only send/recv is supported.
      switch (wrapped->opcode()) {
        case HloOpcode::kSend:
          status = EmitNcclAsyncDone(Thunk::Kind::kNcclSendDone, wrapped);
          break;
        case HloOpcode::kRecv:
          status = EmitNcclAsyncDone(Thunk::Kind::kNcclRecvDone, wrapped);
          break;
        default:
          break;
      }
    }
    if (!status.ok()) break;
  }
  // Reset the flag on every path (including emission failures) so an error
  // does not leave the emitter redirecting subsequent thunks into the scoped
  // sequence.
  emit_group_thunks_ = false;
  TF_RETURN_IF_ERROR(status);
  auto thunk = std::make_unique<NcclGroupThunk>(
      instr, kind, std::move(scoped_thunk_sequence_));
  AddThunkToThunkSequence(std::move(thunk));
  return absl::OkStatus();
}

absl::Status IrEmitterUnnested::EmitNcclAsyncDone(Thunk::Kind kind,
Expand Down Expand Up @@ -2366,15 +2399,19 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
CollectivesAsyncEvents& collectives_async_events =
GetCollectivesAsyncEvents();

// Wire up async events.
const HloInstruction* canonical_send_instr =
FindCanonicalSendRecvStartOp(instr);
if (collectives_async_events.contains(canonical_send_instr)) {
thunk->set_async_events(collectives_async_events[canonical_send_instr]);
// Wire up async events if the send thunk isn't emitted as a part of a
// group thunk.
if (!emit_group_thunks_) {
const HloInstruction* canonical_send_instr =
FindCanonicalSendRecvStartOp(instr);
if (collectives_async_events.contains(canonical_send_instr)) {
thunk->set_async_events(collectives_async_events[canonical_send_instr]);
} else {
collectives_async_events.try_emplace(instr, thunk->async_events());
}
} else {
collectives_async_events.try_emplace(instr, thunk->async_events());
}

AddThunkToThunkSequence(std::move(thunk));
return absl::OkStatus();
}
Expand Down Expand Up @@ -2439,14 +2476,17 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
GetCollectivesAsyncEvents();

// Wire up async events.
const HloInstruction* canonical_recv_instr =
FindCanonicalSendRecvStartOp(instr);
if (collectives_async_events.contains(canonical_recv_instr)) {
thunk->set_async_events(collectives_async_events[canonical_recv_instr]);
if (!emit_group_thunks_) {
const HloInstruction* canonical_recv_instr =
FindCanonicalSendRecvStartOp(instr);
if (collectives_async_events.contains(canonical_recv_instr)) {
thunk->set_async_events(collectives_async_events[canonical_recv_instr]);
} else {
collectives_async_events.try_emplace(instr, thunk->async_events());
}
} else {
collectives_async_events.try_emplace(instr, thunk->async_events());
}

AddThunkToThunkSequence(std::move(thunk));
return absl::OkStatus();
}
Expand Down Expand Up @@ -2479,11 +2519,6 @@ absl::Status IrEmitterUnnested::EmitRecvDoneThunk(
return absl::OkStatus();
}

absl::Status IrEmitterUnnested::EmitNcclGroupStartThunk(
const HloInstruction* instruction) {
return absl::UnimplementedError("EmittNcclGroupStartThunk not implemented");
}

absl::Status IrEmitterUnnested::EmitHloInstruction(
const HloInstruction* instr) {
switch (instr->opcode()) {
Expand All @@ -2507,7 +2542,7 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
case HloOpcode::kAsyncDone: {
if (!instr->async_wrapped_computation()
->CanExpandIntoSingleInstruction()) {
return EmitNcclGroupDoneThunk(instr);
return EmitNcclGroupThunk(instr, Thunk::kNcclGroupDone);
}
const HloInstruction* wrapped = instr->async_wrapped_instruction();
switch (wrapped->opcode()) {
Expand Down Expand Up @@ -2540,7 +2575,7 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
// Multi-op async start will emit a NCCL group thunk.
if (!instr->async_wrapped_computation()
->CanExpandIntoSingleInstruction()) {
return EmitNcclGroupStartThunk(instr);
return EmitNcclGroupThunk(instr, Thunk::kNcclGroupStart);
}
const HloInstruction* wrapped = instr->async_wrapped_instruction();
switch (wrapped->opcode()) {
Expand Down
11 changes: 8 additions & 3 deletions xla/service/gpu/ir_emitter_unnested.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,19 @@ class IrEmitterUnnested : public IrEmitter {

absl::Status EmitHloInstruction(const HloInstruction* instr);

absl::Status EmitNcclGroupStartThunk(const HloInstruction* instr);

absl::Status EmitNcclGroupDoneThunk(const HloInstruction* instr);
absl::Status EmitNcclGroupThunk(const HloInstruction* instr,
Thunk::Kind kind);

absl::Status EmitTargetElementLoop(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& body_emitter) override;

// Add a owning Thunk object to the thunk sequence.
void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
if (emit_group_thunks_) {
scoped_thunk_sequence_.emplace_back(std::move(thunk));
return;
}
thunk_sequence_.emplace_back(std::move(thunk));
}

Expand Down Expand Up @@ -340,6 +343,8 @@ class IrEmitterUnnested : public IrEmitter {

// The thunk sequence this IrEmitter generates for the input computation.
ThunkSequence thunk_sequence_;
ThunkSequence scoped_thunk_sequence_;
bool emit_group_thunks_ = false;

// Container for async send/recv events shared by send/recv thunks.
std::shared_ptr<SendRecvAsyncEvents> send_recv_events_;
Expand Down
52 changes: 32 additions & 20 deletions xla/service/gpu/runtime/nccl_group_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,51 @@ limitations under the License.

#include "xla/service/gpu/runtime/nccl_group_thunk.h"

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/service/gpu/runtime/nccl_collective_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/stream_executor/stream.h"
#include "tsl/platform/errors.h"

namespace xla {
namespace gpu {

NcclGroupThunk::NcclGroupThunk(Thunk::Kind kind, ThunkInfo thunk_info,
NcclApi* nccl_api,
const HloInstruction* instruction,
int64_t replica_count, int64_t partition_count)
: NcclCollectiveThunk(kind, thunk_info, nccl_api,
/*is_sync=*/false),
config_() {
// Builds a group thunk that owns `thunks` and will dispatch them inside a
// single ncclGroupStart/ncclGroupEnd bracket at execution time.
NcclGroupThunk::NcclGroupThunk(const HloInstruction* instruction,
                               Thunk::Kind kind,
                               std::vector<std::unique_ptr<Thunk>> thunks)
    : Thunk(kind, ThunkInfo::WithProfileAnnotation(instruction)),
      // Initialize in the member-init list rather than the body so the raw
      // pointer member is never left uninitialized.
      nccl_api_(NcclApi::Default()) {
  thunks_.reserve(thunks.size());
  for (std::unique_ptr<Thunk>& thunk : thunks) {
    thunks_.push_back(std::move(thunk));
  }
}

absl::Status NcclGroupThunk::RunNcclCollective(
const ExecuteParams& params, se::Stream& stream,
NcclCommHandleWrapper comm_wrapper) {
return absl::UnimplementedError(
"RunNcclCollective not implemented for NcclGroupThunk");
// Forwards resource preparation to every nested thunk in the group, stopping
// at the first error.
absl::Status NcclGroupThunk::Prepare(const PrepareParams& params,
                                     ResourceRequests& resource_requests) {
  for (const std::unique_ptr<Thunk>& thunk : thunks_) {
    TF_RETURN_IF_ERROR(thunk->Prepare(params, resource_requests));
  }
  return absl::OkStatus();
}
// Forwards one-time initialization to every nested thunk in the group,
// stopping at the first error.
absl::Status NcclGroupThunk::Initialize(const InitializeParams& params) {
  for (const std::unique_ptr<Thunk>& thunk : thunks_) {
    TF_RETURN_IF_ERROR(thunk->Initialize(params));
  }
  return absl::OkStatus();
}

AsyncStreamKind NcclGroupThunk::GetAsyncStreamKind() const {
return AsyncStreamKind::kCollective;
// Executes the nested thunks bracketed by GroupStart/GroupEnd so NCCL can
// dispatch their communication calls as a single group.
absl::Status NcclGroupThunk::ExecuteOnStream(
    const Thunk::ExecuteParams& params) {
  TF_RETURN_IF_ERROR(nccl_api_->GroupStart());
  for (const std::unique_ptr<Thunk>& thunk : thunks_) {
    // NOTE(review): an error here returns without calling GroupEnd(), leaving
    // the NCCL group open on the error path — confirm whether the group must
    // be closed (or the communicator aborted) before propagating the error.
    TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params));
  }
  TF_RETURN_IF_ERROR(nccl_api_->GroupEnd());
  return absl::OkStatus();
}

} // namespace gpu
} // namespace xla
26 changes: 12 additions & 14 deletions xla/service/gpu/runtime/nccl_group_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ limitations under the License.
#define XLA_SERVICE_GPU_RUNTIME_NCCL_GROUP_THUNK_H_

#include <cstdint>
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/service/gpu/runtime/nccl_clique_key.h"
#include "xla/service/gpu/runtime/nccl_collective_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"

namespace xla {
namespace gpu {
Expand All @@ -30,21 +32,17 @@ namespace gpu {
// NCCL group call in order for them to be dispatched to NCCL as a NCCL group.
// NCCL may or may not execute them in parallel.

class NcclGroupThunk : public NcclCollectiveThunk {
class NcclGroupThunk : public Thunk {
public:
NcclGroupThunk(Thunk::Kind kind, ThunkInfo thunk_info, NcclApi* nccl_api,
const HloInstruction* instruction, int64_t replica_count,
int64_t partition_count);

protected:
const NcclCollectiveConfig& config() const override { return config_; }
absl::Status RunNcclCollective(const ExecuteParams& params,
se::Stream& stream,
NcclCommHandleWrapper comm_wrapper) override;
AsyncStreamKind GetAsyncStreamKind() const override;
NcclGroupThunk(const HloInstruction* instruction, Thunk::Kind kind,
std::vector<std::unique_ptr<Thunk>> thunks);
absl::Status Prepare(const PrepareParams& params,
ResourceRequests& resource_requests) override;
absl::Status ExecuteOnStream(const Thunk::ExecuteParams& params) override;
absl::Status Initialize(const InitializeParams& params) override;

private:
const NcclCollectiveConfig config_;
ThunkSequence thunks_;
NcclApi* nccl_api_;
};

Expand Down
3 changes: 3 additions & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2470,13 +2470,16 @@ xla_test(
],
deps = [
":hlo_test_base",
":literal_test_util",
":xla_internal_test_main",
"//xla:literal",
"//xla:literal_util",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/service:hlo_module_config",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:statusor",
],
)
Expand Down
25 changes: 19 additions & 6 deletions xla/tests/nccl_group_execution_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
Expand All @@ -41,29 +42,34 @@ class NcclGroupExecutionTest : public HloTestBase {
};

XLA_TEST_F(NcclGroupExecutionTest, NcclGroupSendRecvNoWhileLoop) {
// TODO (rosiezou): remove the channel_id=0 workaround once it is optional.
const absl::string_view kModuleStr = R"(
HloModule module_main, entry_computation_layout={()->(f32[], f32[])}
wrapped_send_recv {
param0 = f32[] parameter(0)
param1 = token[] parameter(1)
send1 = (f32[], u32[], token[]) send(param0, param1), channel_id=1
send1 = (f32[], u32[], token[]) send(param0, param1), channel_id=0,
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2}}}
param2 = f32[] parameter(2)
param3 = token[] parameter(3)
send2 = (f32[], u32[], token[]) send(param2, param3), channel_id=2
send2 = (f32[], u32[], token[]) send(param2, param3), channel_id=0,
frontend_attributes={_xla_send_recv_source_target_pairs={{2,3}}}
param4 = token[] parameter(4)
recv1 = (f32[], u32[], token[]) recv(param4), channel_id=1
recv1 = (f32[], u32[], token[]) recv(param4), channel_id=0,
frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2}}}
param5 = token[] parameter(5)
recv2 = (f32[], u32[], token[]) recv(param5), channel_id=2
recv2 = (f32[], u32[], token[]) recv(param5), channel_id=0,
frontend_attributes={_xla_send_recv_source_target_pairs={{2,3}}}
ROOT out = ((f32[], u32[], token[]), (f32[], u32[], token[]),
(f32[], u32[], token[]), (f32[], u32[], token[]))
tuple(send1, send2, recv1, recv2)
}
ENTRY main {
data1 = f32[] constant(1)
data1 = f32[] constant(10)
after-all1 = token[] after-all()
data2 = f32[] constant(2)
data2 = f32[] constant(20)
after-all2 = token[] after-all()
async-comp-start = ((f32[], token[], f32[], token[], token[], token[]),
((f32[], u32[], token[]), (f32[], u32[], token[]), (f32[], u32[], token[]),
Expand Down Expand Up @@ -104,6 +110,13 @@ XLA_TEST_F(NcclGroupExecutionTest, NcclGroupSendRecvNoWhileLoop) {
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
// TODO (rosiezou): remove the string comparison once a tuple comparison
// function is available in LiteralTestUtil.
EXPECT_EQ(results[0].ToStringWithoutShapeOneline(), "( 0, 0 )");
EXPECT_EQ(results[1].ToStringWithoutShapeOneline(), "( 10, 0 )");
EXPECT_EQ(results[2].ToStringWithoutShapeOneline(), "( 10, 0 )");
EXPECT_EQ(results[3].ToStringWithoutShapeOneline(), "( 0, 20 )");
}

} // namespace
Expand Down

0 comments on commit f173440

Please sign in to comment.