
Commit

Added more fixes on top of base branch, needs one last rebase
Tixxx committed Jan 28, 2025
1 parent d9766ba commit 11f965b
Showing 10 changed files with 324 additions and 119 deletions.
149 changes: 141 additions & 8 deletions xla/service/gpu/ir_emitter_unnested.cc
@@ -2005,6 +2005,112 @@ absl::Status IrEmitterUnnested::EmitNcclThunk(
return absl::OkStatus();
}

template <typename NvshmemThunkType, typename HloInstType>
absl::Status IrEmitterUnnested::EmitNvshmemThunk(
Thunk::Kind kind, const HloInstruction* async_start,
const HloInstType* inst, std::optional<bool> use_global_device_ids) {
CHECK(kind == Thunk::Kind::kNvshmemAllReduceStart);
const auto& hlo_config = ir_emitter_context_->hlo_module().config();
int64_t replica_count = hlo_config.replica_count();
int64_t partition_count = hlo_config.num_partitions();
VLOG(2) << NvshmemThunkType::GetHloOpName()
<< "; replica count: " << replica_count
<< "; partition count: " << partition_count
<< "; operand count: " << inst->operand_count();

// A given collective op can be degenerate if all of the groups formed
// by it are singletons. In such a case, we don't need to do any communication
// and we can just copy the input to the output.
//
// The only exception is RaggedAllToAll, which is not degenerate even if
// all groups are singletons. In the singleton-group case, RaggedAllToAll
// becomes a generic equivalent of DynamicUpdateSlice, except that the update
// size is not statically known. This operation cannot be expressed in terms of
// standard HLO instructions, so the best solution we have is to use the NCCL
// thunk even for degenerate cases.
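// (Illustrative example, not part of this change: with replica_count == 1 and
// partition_count == 1, every group formed by the op is a singleton, so an
// all-reduce has nothing to combine and reduces to a copy of its operand.)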
bool is_degenerate = GetNvshmemCollectiveConfig(inst, use_global_device_ids)
.IsDegenerate(replica_count, partition_count);
absl::Status implementable_status = NvshmemThunkType::CheckImplementable(
inst, replica_count, partition_count);
bool should_use_nvshmem_thunk = !is_degenerate && implementable_status.ok();

// Stash relevant information in NvshmemCollectiveThunk::Buffer even if we may
// not generate an NvshmemCollectiveThunk.
std::vector<NvshmemCollectiveThunk::Buffer> buffers;

int64_t operand_count = inst->operand_count();
buffers.reserve(operand_count);

// Adds a source and destination buffer pair to `buffers`.
auto add_buffer = [&](int64_t element_count, BufferAllocation::Slice src,
int64_t src_memory_space, BufferAllocation::Slice dst,
int64_t dst_memory_space) {
buffers.push_back(NvshmemCollectiveThunk::Buffer{
/*element_count=*/element_count,
/*source_buffer=*/src,
/*destination_buffer=*/dst,
/*source_memory_space=*/src_memory_space,
/*destination_memory_space=*/dst_memory_space,
/*source_value=*/nullptr,
/*destination_value=*/nullptr});
};

if (kind == Thunk::Kind::kNvshmemAllReduceStart) {
// For all-reduce, simply zip operands with results.
for (int64_t i = 0; i < operand_count; i++) {
ShapeIndex idx = operand_count > 1 ? ShapeIndex({i}) : ShapeIndex({});
const Shape& src_shape = inst->operand(i)->shape();
const Shape& dst_shape = ShapeUtil::GetSubshape(inst->shape(), idx);
TF_ASSIGN_OR_RETURN(auto src, GetAllocationSliceForHlo(inst->operand(i)));
TF_ASSIGN_OR_RETURN(auto dst, GetAllocationSliceForHlo(inst, idx));
add_buffer(ShapeUtil::ElementsIn(src_shape), src,
src_shape.layout().memory_space(), dst,
dst_shape.layout().memory_space());
}
}

if (should_use_nvshmem_thunk) {
auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(inst);
// The wrapper name is used when syntactic sugar is turned on.
if (ir_emitter_context_->debug_options().xla_syntax_sugar_async_ops()) {
thunk_info.profile_annotation = async_start->name();
}
auto thunk = std::make_unique<NvshmemThunkType>(
thunk_info, inst, /*buffers=*/std::move(buffers),
ir_emitter_context_->debug_options().xla_gpu_use_memcpy_local_p2p());
GetCollectivesAsyncEvents().insert({async_start, thunk->async_events()});
AddThunkToThunkSequence(std::move(thunk));
return absl::OkStatus();
}

if (!is_degenerate) {
return implementable_status;
}

// Signal with a nullptr entry that no start thunk was created.
GetCollectivesAsyncEvents().insert({async_start, nullptr});

VLOG(1) << "Collective call is degenerate, not doing NCCL call";

// A degenerate collective is simply the identity function. Buffer
// assignment expects a copy, so that's what we emit.
ThunkSequence thunks;
for (int64_t i = 0; i < buffers.size(); i++) {
const Shape shape = inst->operand(i)->shape();
thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(inst),
/*source_buffer=*/buffers[i].source_buffer,
/*destination_buffer=*/buffers[i].destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
}
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(std::make_unique<SequentialThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(inst), std::move(thunks)));
}
return absl::OkStatus();
}
// Find the canonical send/recv start op for one of send, recv, send-done, or
// recv-done. For trivial cases send/recv and send-done/recv-done come in pairs
// and the canonical start op is the send/recv op of the pair. If send/recv is
@@ -2161,6 +2267,31 @@ absl::Status IrEmitterUnnested::EmitNcclAsyncDone(Thunk::Kind kind,
return absl::OkStatus();
}

absl::Status IrEmitterUnnested::EmitNvshmemAsyncDone(
Thunk::Kind kind, const HloInstruction* instr) {
CHECK(kind == Thunk::Kind::kNvshmemAllReduceDone);

// Only all-reduce-done reaches this emitter (see the CHECK above), so the
// canonical start op is simply the first operand of the done instruction.
const HloInstruction* start = instr->operand(0);

// Find canonical async event.
CollectivesAsyncEvents& collectives_async_events =
GetCollectivesAsyncEvents();
auto async_events_it = collectives_async_events.find(start);
TF_RET_CHECK(async_events_it != collectives_async_events.end())
<< "couldn't find async events for start operation";

// Can be null if no start thunk was created (e.g. if the start op is
// degenerate), in which case there's nothing to do here.
if (!async_events_it->second) return absl::OkStatus();

AsyncStreamKind stream_kind = AsyncStreamKind::kCollective;
AddThunkToThunkSequence(std::make_unique<NvshmemCollectiveDoneThunk>(
kind, Thunk::ThunkInfo::WithProfileAnnotation(instr),
async_events_it->second, stream_kind));
return absl::OkStatus();
}
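
// Editorial sketch (not part of this commit): how the two emitters above pair
// a start op with its done op through GetCollectivesAsyncEvents().
//
//   all-reduce-start --EmitNvshmemThunk--> NvshmemAllReduceStartThunk
//     records {start, thunk->async_events()} in the map, or {start, nullptr}
//     when the collective is degenerate and only a copy thunk is emitted.
//
//   all-reduce-done --EmitNvshmemAsyncDone--> NvshmemCollectiveDoneThunk
//     looks up instr->operand(0) in the same map, reuses the recorded async
//     events, and emits nothing when the entry is nullptr.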

absl::Status IrEmitterUnnested::EmitInfeed(const HloInfeedInstruction* instr) {
// Infeed instruction returns a tuple containing the result data and a token.
// We only need the result data to construct the infeed thunk.
@@ -2516,11 +2647,12 @@ absl::Status IrEmitterUnnested::EmitRecvDoneThunk(

inline bool IsNvshmemCollective(const HloInstruction* instr) {
bool is_nvshmem_collective = false;
if(instr->has_backend_config()) {
if (instr->has_backend_config()) {
auto gpu_config = instr->backend_config<GpuBackendConfig>();
const CollectiveBackendConfig& backend_config =
gpu_config.value().collective_backend_config();
is_nvshmem_collective = backend_config.backend() == CollectiveBackendConfig::NVSHMEM;
is_nvshmem_collective =
backend_config.backend() == CollectiveBackendConfig::NVSHMEM;
}
return is_nvshmem_collective;
}
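
// Editorial sketch, not part of this commit: IsNvshmemCollective() above only
// returns true once some pass has tagged the collective's GpuBackendConfig.
// The hypothetical helper below shows that tagging; it assumes the generated
// proto mutators implied by the accessors used above
// (mutable_collective_backend_config(), set_backend()) and an
// HloInstruction::set_backend_config overload taking the proto.
absl::Status MarkAsNvshmemCollective(HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
                      instr->backend_config<GpuBackendConfig>());
  gpu_config.mutable_collective_backend_config()->set_backend(
      CollectiveBackendConfig::NVSHMEM);
  return instr->set_backend_config(gpu_config);
}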
@@ -2537,18 +2669,19 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
all_gather->use_global_device_ids());
}

case HloOpcode::kAllReduceDone:{
if(IsNvshmemCollective(instr)) {
case HloOpcode::kAllReduceDone: {
if (IsNvshmemCollective(instr)) {
return EmitNvshmemAsyncDone(Thunk::kNvshmemAllReduceDone, instr);
}
return EmitNcclAsyncDone(Thunk::kNcclAllReduceDone, instr);
}
case HloOpcode::kAllReduceStart: {
auto* all_reduce = Cast<HloAllReduceInstruction>(instr);
if(IsNvshmemCollective(instr)) {
return EmitNvshmemThunk<NvshmemAllReduceStartThunk, HloAllReduceInstruction>(
Thunk::kNvshmemAllReduceStart, all_reduce, all_reduce,
all_reduce->use_global_device_ids());
if (IsNvshmemCollective(instr)) {
return EmitNvshmemThunk<NvshmemAllReduceStartThunk,
HloAllReduceInstruction>(
Thunk::kNvshmemAllReduceStart, all_reduce, all_reduce,
all_reduce->use_global_device_ids());
}
return EmitNcclThunk<NcclAllReduceStartThunk, HloAllReduceInstruction>(
Thunk::kNcclAllReduceStart, all_reduce, all_reduce,
1 change: 1 addition & 0 deletions xla/service/gpu/runtime/BUILD
@@ -203,6 +203,7 @@ cc_library(
"@tsl//tsl/platform:numbers",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
"@com_googlesource_code_re2//:re2",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"@nvshmem//:nvshmem",
82 changes: 8 additions & 74 deletions xla/service/gpu/runtime/nvshmem_all_reduce_thunk.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/service/gpu/runtime/nvshmem_api.h"

namespace xla {
namespace gpu {
@@ -46,17 +47,14 @@ absl::Status RunAllReduce(GpuCollectives* collectives,
se::Stream& stream, Communicator* comm) {
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(collectives, stream.parent(), buffers, comm));

TF_RETURN_IF_ERROR(collectives->GroupStart());
for (DeviceBufferPair& buffer : buffers) {
TF_RETURN_IF_ERROR(comm->AllReduce(
buffer.source_buffer, buffer.destination_buffer, buffer.element_type,
buffer.element_count, reduction_kind, GpuCollectives::On(stream)));
TF_RETURN_IF_ERROR(xla::gpu::NvshmemApi::Default().DoAllreduce(
xla::gpu::NvshmemApi::TEAMSKIND::kNODE, buffer.element_type,
buffer.destination_buffer, buffer.source_buffer, buffer.element_count,
reduction_kind));
}

return collectives->GroupEnd();
return absl::OkStatus();
}

namespace impl {
@@ -82,7 +80,8 @@ NvshmemAllReduceConfig GetNvshmemAllReduceConfigInst(HloInstType* inst) {
CHECK(reduction_kind.has_value());

NvshmemAllReduceConfig config;
config.config = GetNvshmemCollectiveConfig(inst, inst->use_global_device_ids());
config.config =
GetNvshmemCollectiveConfig(inst, inst->use_global_device_ids());
config.reduction_kind = *reduction_kind;
return config;
}
@@ -136,70 +135,5 @@ absl::Status NvshmemAllReduceStartThunk::RunNvshmemCollective(
device_buffers, stream, comm_handle.comm);
}

NvshmemReduceScatterStartThunk::NvshmemReduceScatterStartThunk(
ThunkInfo thunk_info, const HloReduceScatterInstruction* inst,
std::vector<Buffer> buffers)
: NvshmemAllReduceReduceScatterThunkBase(
Thunk::kNvshmemReduceScatterStart, thunk_info,
impl::GetNvshmemAllReduceConfigInst(inst), std::move(buffers),
inst->backend_config<GpuBackendConfig>()
->collective_backend_config()
.is_sync()) {}

/*static*/ absl::Status NvshmemReduceScatterStartThunk::CheckImplementable(
const HloReduceScatterInstruction* inst, int64_t replica_count,
int64_t partition_count) {
return AddOpDescription<NvshmemReduceScatterStartThunk>(
impl::CheckImplementableInst(inst, Thunk::kNvshmemReduceScatterStart), inst,
replica_count, partition_count);
}

/*static*/ CollectiveOpGroupMode NvshmemReduceScatterStartThunk::GetGroupMode(
const HloReduceScatterInstruction* inst) {
return impl::GetGroupModeInst(inst);
}

absl::Status NvshmemReduceScatterStartThunk::RunNvshmemCollective(
const ExecuteParams& params, se::Stream& stream,
CommunicatorHandle comm_handle) {
TF_ASSIGN_OR_RETURN(
std::vector<DeviceBufferPair> device_buffers,
ConvertToDeviceBuffers(params, buffers_,
config_.config.operand_element_type));
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
return ::xla::gpu::RunReduceScatter(collectives, config_.reduction_kind,
device_buffers, stream, comm_handle.comm);
}

absl::Status RunReduceScatter(GpuCollectives* collectives,
ReductionKind reduction_kind,
std::vector<DeviceBufferPair>& buffers,
se::Stream& stream, Communicator* comm) {
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing reduce-scatter from device ordinal: "
<< device_ordinal;
TF_RETURN_IF_ERROR(
MaybeRegisterBuffers(collectives, stream.parent(), buffers, comm));

TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks());

TF_RETURN_IF_ERROR(collectives->GroupStart());

for (DeviceBufferPair& buffer : buffers) {
// buffer.element_count is the source buffers element count. For
// ReduceScatter, we need the destination buffers element count.
TF_RET_CHECK(buffer.element_count % num_ranks == 0)
<< "Source buffer was not an exact multiple of the number of "
"participants.";

TF_RETURN_IF_ERROR(comm->ReduceScatter(
buffer.source_buffer, buffer.destination_buffer, buffer.element_type,
buffer.element_count / num_ranks, reduction_kind,
GpuCollectives::On(stream)));
}

return collectives->GroupEnd();
}

} // namespace gpu
} // namespace xla
34 changes: 19 additions & 15 deletions xla/service/gpu/runtime/nvshmem_all_reduce_thunk.h
@@ -41,11 +41,13 @@ struct NvshmemAllReduceConfig {
class NvshmemAllReduceReduceScatterThunkBase : public NvshmemCollectiveThunk {
public:
NvshmemAllReduceReduceScatterThunkBase(Kind kind, ThunkInfo thunk_info,
NvshmemAllReduceConfig config,
std::vector<Buffer> buffers,
bool is_sync);
NvshmemAllReduceConfig config,
std::vector<Buffer> buffers,
bool is_sync);

const NvshmemCollectiveConfig& config() const override { return config_.config; }
const NvshmemCollectiveConfig& config() const override {
return config_.config;
}
ReductionKind reduction_kind() const { return config_.reduction_kind; }

absl::Span<const Buffer> buffers() const { return buffers_; }
@@ -59,12 +61,13 @@ class NvshmemAllReduceReduceScatterThunkBase : public NvshmemCollectiveThunk {
// AllReduce thunk.
// -----------------------------------------------------------------------------

class NvshmemAllReduceStartThunk : public NvshmemAllReduceReduceScatterThunkBase {
class NvshmemAllReduceStartThunk
: public NvshmemAllReduceReduceScatterThunkBase {
public:
NvshmemAllReduceStartThunk(ThunkInfo thunk_info,
const HloAllReduceInstruction* inst,
std::vector<Buffer> buffers,
bool p2p_memcpy_enabled = false);
const HloAllReduceInstruction* inst,
std::vector<Buffer> buffers,
bool p2p_memcpy_enabled = false);

static const char* GetHloOpName() { return "all-reduce-start"; }

@@ -77,18 +80,19 @@ class NvshmemAllReduceStartThunk : public NvshmemAllReduceReduceScatterThunkBase

protected:
absl::Status RunNvshmemCollective(const ExecuteParams& params,
se::Stream& stream,
CommunicatorHandle comm_handle) override;
se::Stream& stream,
CommunicatorHandle comm_handle) override;
};

// -----------------------------------------------------------------------------
// ReduceScatter thunk
// -----------------------------------------------------------------------------
class NvshmemReduceScatterStartThunk : public NvshmemAllReduceReduceScatterThunkBase {
class NvshmemReduceScatterStartThunk
: public NvshmemAllReduceReduceScatterThunkBase {
public:
NvshmemReduceScatterStartThunk(ThunkInfo thunk_info,
const HloReduceScatterInstruction* inst,
std::vector<Buffer> buffers);
const HloReduceScatterInstruction* inst,
std::vector<Buffer> buffers);

static const char* GetHloOpName() { return "reduce-scatter-start"; }

@@ -101,8 +105,8 @@ class NvshmemReduceScatterStartThunk : public NvshmemAllReduceReduceScatterThunk

protected:
absl::Status RunNvshmemCollective(const ExecuteParams& params,
se::Stream& stream,
CommunicatorHandle comm_handle) override;
se::Stream& stream,
CommunicatorHandle comm_handle) override;
};

// -----------------------------------------------------------------------------