[Feature] Support cross mesh nccl allreduce (#131)

ZYHowell authored Aug 26, 2022
1 parent 7e698a6 commit b80a87c

Showing 15 changed files with 194 additions and 6 deletions.
6 changes: 6 additions & 0 deletions tensorflow/compiler/xla/python/xla.cc
@@ -62,6 +62,7 @@ limitations under the License.
#include "tensorflow/python/lib/core/bfloat16.h"
#include "third_party/nccl/nccl.h"
#include "tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

// TODO(phawkins): remove host_id properties after JAX is updated to avoid them.

@@ -551,10 +552,15 @@ PYBIND11_MODULE(xla_extension, m) {
"nccl broadcast with only a subset of gpus in the host are involved");
m.def("nccl_create_communicators", &gpu::NcclCreateCommunicators,
"nccl create communicators for multiple threads case");
m.def("nccl_create_communicators_no_stream",
&gpu::NcclCreateCommunicatorsNoStream,
"nccl create pure communicators");
m.def("get_buffer_device_id", &gpu::GetBufferDeviceId,
"get the local device id for one pybuffer");
m.def("nccl_recv", &gpu::NcclRecv, "nccl recv data");
m.def("nccl_send", &gpu::NcclSend, "nccl send data");
m.def("set_cross_mesh_communicator", &gpu::SetCrossMeshCommunicators,
"set nccl communicators for cross mesh collective communication");

} // NOLINT(readability/fn_size)

25 changes: 25 additions & 0 deletions tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.cc
@@ -311,6 +311,31 @@ StatusOr<int> GetBufferDeviceId(PyBuffer::object buffer) {
return buffer.buf()->device()->local_hardware_id();
}

StatusOr<std::vector<std::uintptr_t>> NcclCreateCommunicatorsNoStream(
int world_size, std::vector<int> devices_global_rank,
std::vector<int> devices_ids, std::vector<int8_t> nccl_uid_vec) {
#if XLA_ENABLE_XCCL
int n_devices = devices_global_rank.size();
CHECK_EQ(n_devices, devices_ids.size());
ncclUniqueId nccl_uid;
CHECK_EQ(sizeof(nccl_uid.internal), nccl_uid_vec.size());
memcpy(&nccl_uid.internal, nccl_uid_vec.data(), sizeof(nccl_uid.internal));

std::vector<std::uintptr_t> comms;
comms.resize(n_devices);
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
for (int i = 0; i < n_devices; i++) {
cudaSetDevice(devices_ids[i]);
ncclComm_t* comm_ref = reinterpret_cast<ncclComm_t*>(comms.data() + i);
XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(
comm_ref, world_size, nccl_uid, devices_global_rank[i]));
}
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
return comms;
#else // XLA_ENABLE_XCCL
return Unimplemented("NCCL support is not available.");
#endif // XLA_ENABLE_XCCL
}
} // namespace gpu

} // namespace xla
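For orientation, here is a minimal sketch (not part of the diff) of how a caller might drive NcclCreateCommunicatorsNoStream and hand the resulting handles to SetCrossMeshCommunicators. The exchange of the ncclUniqueId between hosts is assumed to happen elsewhere, and the PackNcclUniqueId/SetupCrossMeshComms helper names are illustrative only.

// Sketch only: assumes the ncclUniqueId created on one host has already been
// shared with every participating host; that exchange is outside this commit.
#include <cstdint>
#include <cstring>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "third_party/nccl/nccl.h"

// Pack a ncclUniqueId into the byte vector NcclCreateCommunicatorsNoStream expects.
std::vector<int8_t> PackNcclUniqueId(const ncclUniqueId& uid) {
  std::vector<int8_t> bytes(sizeof(uid.internal));
  std::memcpy(bytes.data(), uid.internal, sizeof(uid.internal));
  return bytes;
}

// Per host: world_size counts all devices across the participating meshes,
// devices_global_rank gives each local device's rank in that world, and
// device_ids are the local CUDA ordinals.
xla::Status SetupCrossMeshComms(int world_size,
                                const std::vector<int>& devices_global_rank,
                                const std::vector<int>& device_ids,
                                const ncclUniqueId& uid) {
  auto comms_or = xla::gpu::NcclCreateCommunicatorsNoStream(
      world_size, devices_global_rank, device_ids, PackNcclUniqueId(uid));
  if (!comms_or.ok()) return comms_or.status();
  // Register the communicators so CrossMeshNcclAllReduceThunk can look them
  // up by local device ordinal at execution time.
  xla::gpu::SetCrossMeshCommunicators(comms_or.ValueOrDie(), /*group_keys=*/"");
  return xla::Status::OK();
}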
4 changes: 4 additions & 0 deletions tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h
@@ -116,6 +116,10 @@ StatusOr<std::shared_ptr<NcclCommStorage>> NcclCreateCommunicators(int world_siz
std::vector<char> nccl_uid,
bool nccl_use_multistream);

StatusOr<std::vector<std::uintptr_t>> NcclCreateCommunicatorsNoStream(
int world_size, std::vector<int> devices_global_rank,
std::vector<int> devices_ids, std::vector<int8_t> nccl_uid_vec);

StatusOr<int> GetBufferDeviceId(PyBuffer::object buffer);


2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -153,6 +153,8 @@ const char* const kBuiltinSwapOutTarget = "__builtin$SwapOut";
const char* const kBuiltinSwapInTarget = "__builtin$SwapIn";
const char* const kBuiltinSwapDoneTarget = "__builtin$SwapDone";
const char* const kBuiltinMemZeroTarget = "__builtin$MemZero";
const char* const kBuiltinCrossMeshAllReduceTarget =
"__builtin$CrossMeshAllReduce";

static ReductionDimensions GetReductionKindAndContiguousComponentsImpl(
const Shape& input_shape, absl::Span<const int64_t> dims_to_reduce) {
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -68,6 +68,7 @@ extern const char* const kBuiltinSwapOutTarget;
extern const char* const kBuiltinSwapInTarget;
extern const char* const kBuiltinSwapDoneTarget;
extern const char* const kBuiltinMemZeroTarget;
extern const char* const kBuiltinCrossMeshAllReduceTarget;

// Returns true if either the dimensions being reduced or the dimensions being
// kept are contiguous in the input of the reduce instruction.
52 changes: 49 additions & 3 deletions tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1343,10 +1343,10 @@ Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

Status IrEmitterUnnested::EmitSwapThunk(mlir::Operation* op) {
using Slices = std::vector<BufferAllocation::Slice>;
StatusOr<std::pair<Slices, Slices>> IrEmitterUnnested::CustomCallParseBuffers(
mlir::Operation* op) {
auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
const std::string call_target_name = custom_call.call_target_name().str();

std::vector<BufferAllocation::Slice> operands;
std::vector<BufferAllocation::Slice> results;

@@ -1393,6 +1393,16 @@ Status IrEmitterUnnested::EmitSwapThunk(mlir::Operation* op) {
TF_ASSIGN_OR_RETURN(operands, values_to_slices(custom_call.args()));
TF_ASSIGN_OR_RETURN(results, values_to_slices(custom_call.output()));
}
return std::make_pair(operands, results);
}

Status IrEmitterUnnested::EmitSwapThunk(mlir::Operation* op) {
auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
const std::string call_target_name = custom_call.call_target_name().str();

TF_ASSIGN_OR_RETURN(auto all_buffers, CustomCallParseBuffers(op));
std::vector<BufferAllocation::Slice>& operands = all_buffers.first;
std::vector<BufferAllocation::Slice>& results = all_buffers.second;

std::vector<int64_t> byte_sizes;
std::vector<std::string> keys =
@@ -1470,8 +1480,41 @@ Status IrEmitterUnnested::EmitMemZeroThunk(mlir::Operation* op) {
AddThunkToThunkSequence(
absl::make_unique<MemzeroThunk>(GetThunkInfo(op), operand));
}
return Status::OK();
}

Status IrEmitterUnnested::EmitCrossMeshAllReduceTarget(mlir::Operation* op) {
auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
TF_ASSIGN_OR_RETURN(auto all_buffers, CustomCallParseBuffers(op));
std::vector<BufferAllocation::Slice>& operands = all_buffers.first;
std::vector<BufferAllocation::Slice>& results = all_buffers.second;
std::vector<NcclCollectiveThunk::Buffer> buffers;
CHECK_EQ(operands.size(), results.size());
for (auto buf : llvm::zip(operands, results, custom_call.args())) {
BufferAllocation::Slice src = std::get<0>(buf);
BufferAllocation::Slice dst = std::get<1>(buf);
mlir::Value src_value = std::get<2>(buf);
buffers.push_back(NcclCollectiveThunk::Buffer{
ShapeUtil::ElementsIn(GetShape(src_value)), src, dst});
}
ReductionKind reduction_kind;
// TODO(yonghao): the opaque should also carry participant mesh group info.
const std::string op_str = custom_call.backend_config().str();
if (op_str == "SUM") {
reduction_kind = ReductionKind::SUM;
} else if (op_str == "AND" || op_str == "MIN") {
// Logical AND over predicates reduces the same way as MIN.
reduction_kind = ReductionKind::MIN;
} else if (op_str == "OR" || op_str == "MAX") {
// Logical OR over predicates reduces the same way as MAX.
reduction_kind = ReductionKind::MAX;
} else {
return InternalError("cross mesh allreduce op %s is unsupported",
op_str.c_str());
}
auto op_type = GetShape(custom_call.args()[0]).element_type();
AddThunkToThunkSequence(absl::make_unique<CrossMeshNcclAllReduceThunk>(
GetThunkInfo(op), buffers, reduction_kind, op_type));
return Status::OK();
}

Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) {
auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
@@ -5738,6 +5781,9 @@ Status IrEmitterUnnested::EmitOp(mlir::Operation* op) {
if (call.call_target_name() == kBuiltinMemZeroTarget) {
return EmitMemZeroThunk(op);
}
if (call.call_target_name() == kBuiltinCrossMeshAllReduceTarget) {
return EmitCrossMeshAllReduceTarget(op);
}
return EmitCustomCallThunk(op);
}

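For context (illustrative, not code from this commit): EmitCrossMeshAllReduceTarget above consumes a one-operand custom call whose target is __builtin$CrossMeshAllReduce and whose opaque/backend_config string names the reduction. A hedged sketch of how such a call could be built with the XLA client builder follows; whether Alpa actually constructs it this way (it may instead insert the custom call at the HLO/Python level) is an assumption.

// Illustrative only: emit the custom call that EmitCrossMeshAllReduceTarget
// lowers. The reduction kind ("SUM", "MIN", "MAX", ...) travels in the opaque
// string, which the GPU backend reads back as backend_config() on the lmhlo op.
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp CrossMeshAllReduceSum(xla::XlaBuilder* builder, xla::XlaOp operand,
                                 const xla::Shape& shape) {
  return xla::CustomCall(builder, "__builtin$CrossMeshAllReduce", {operand},
                         shape, /*opaque=*/"SUM");
}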
4 changes: 4 additions & 0 deletions tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -181,6 +181,10 @@ class IrEmitterUnnested : public IrEmitter {
Status EmitSwapThunk(mlir::Operation* op);
Status EmitSwapDoneThunk(mlir::Operation* op);
Status EmitMemZeroThunk(mlir::Operation* op);
Status EmitCrossMeshAllReduceTarget(mlir::Operation* op);
using Slices = std::vector<BufferAllocation::Slice>;
StatusOr<std::pair<Slices, Slices>> CustomCallParseBuffers(
mlir::Operation* op);
Status EmitCustomCallThunk(mlir::Operation* op);
Status EmitFftThunk(mlir::Operation* op);
Status EmitFusion(mlir::Operation* op);
69 changes: 69 additions & 0 deletions tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -458,5 +458,74 @@ Status NcclReduceScatterThunk::RunNcclCollective(const ExecuteParams& params,
#endif // XLA_ENABLE_XCCL
}

std::vector<std::unique_ptr<NcclComm>> nccl_comms;

// FIXME(yonghao): support multiple groups of cross mesh nccl comms with keys
void SetCrossMeshCommunicators(const std::vector<std::uintptr_t>& comms,
const std::string& group_keys) {
nccl_comms.clear();
nccl_comms.reserve(comms.size());
for (std::uintptr_t comm : comms) {
nccl_comms.emplace_back(
std::make_unique<NcclComm>(reinterpret_cast<ncclComm_t>(comm)));
}
}

NcclAllReduceConfig GetCrossMeshNcclAllReduceConfig(
ReductionKind reduction_kind, xla::PrimitiveType op_type) {
NcclAllReduceConfig config;
NcclCollectiveConfig& collective_config = config.config;
collective_config.operand_count = 1;
collective_config.operand_element_type.push_back(op_type);
// The replica_groups, collective_op_kind and group_mode fields are only
// needed to identify the nccl comm in XLA's original collective thunks;
// this thunk obtains its communicator via SetCrossMeshCommunicators, so
// they are left unset here.
config.reduction_kind = reduction_kind;
return config;
}

CrossMeshNcclAllReduceThunk::CrossMeshNcclAllReduceThunk(
ThunkInfo thunk_info, std::vector<Buffer> buffers,
ReductionKind reduction_kind, xla::PrimitiveType op_type)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
config_(GetCrossMeshNcclAllReduceConfig(reduction_kind, op_type)),
buffers_(std::move(buffers)) {}

Status CrossMeshNcclAllReduceThunk::ExecuteOnStream(
const ExecuteParams& params) {
#if XLA_ENABLE_XCCL
VLOG(1) << absl::StreamFormat("Starting %s.", Thunk::KindToString(kind()));

se::StreamExecutor* executor = params.stream->parent();
se::gpu::ScopedActivateExecutorContext scoped_context(executor);

// TF_ASSIGN_OR_RETURN(
// NcclComm::Lock comm,
// AcquireNcclComm(params.run_id, op_id, std::move(participants),
// num_local_participants, *unique_id_callback, rank));
// TODO(yonghao): support CrossMeshNcclAllReduce across different mesh groups as above,
// using participant info created at compile time.
int device_ordinal = params.stream->parent()->device_ordinal();
NcclComm::Lock comm = nccl_comms[device_ordinal]->Acquire();

se::Stream& stream = *params.stream;
TF_RETURN_IF_ERROR(RunAllReduce(config_, buffers_, *params.buffer_allocations,
stream, *comm, ""));

// Block host on the first call to ensure that all devices have allocated the
// required buffers for their communicators before allowing any device to
// continue enqueuing operations. Otherwise, the allocations can cause
// deadlock in the CUDA driver (b/215649390).
if (first_call_to_execute_) {
TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone());
first_call_to_execute_ = false;
}
return Status::OK();
#else // XLA_ENABLE_XCCL
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");
#endif // XLA_ENABLE_XCCL
}
} // namespace gpu
} // namespace xla
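A note on the communicator lookup above (a reading of the code, not text from the commit): ExecuteOnStream indexes nccl_comms by the stream's device ordinal, so the vector passed to SetCrossMeshCommunicators is expected to hold exactly one communicator per local device, placed at that device's ordinal. A minimal sketch of that contract, with a hypothetical helper name:

#include <cstdint>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical helper: makes explicit the contract implied by
// CrossMeshNcclAllReduceThunk::ExecuteOnStream, namely that comms[i] is the
// communicator created for local device ordinal i.
void RegisterCrossMeshComms(const std::vector<std::uintptr_t>& comms,
                            size_t num_local_devices) {
  CHECK_EQ(comms.size(), num_local_devices);
  xla::gpu::SetCrossMeshCommunicators(comms, /*group_keys=*/"");
}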
20 changes: 20 additions & 0 deletions tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
@@ -130,6 +130,26 @@ class NcclReduceScatterThunk : public NcclAllReduceThunkBase {
ncclComm_t comm) override;
};

class CrossMeshNcclAllReduceThunk : public Thunk {
public:
using Buffer = NcclCollectiveThunk::Buffer;

explicit CrossMeshNcclAllReduceThunk(ThunkInfo thunk_info,
std::vector<Buffer> buffers,
ReductionKind reduction_kind,
xla::PrimitiveType op_type);

Status ExecuteOnStream(const ExecuteParams& params) override;

private:
const NcclAllReduceConfig config_;
const std::vector<Buffer> buffers_;
bool first_call_to_execute_ = true;
};

void SetCrossMeshCommunicators(const std::vector<std::uintptr_t>& comms,
const std::string& group_keys);

} // namespace gpu
} // namespace xla

1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/gpu/nccl_utils.h
@@ -108,6 +108,7 @@ TF_LIB_GTL_DEFINE_INT_TYPE(OpId, int64_t);

struct NcclComm : public Lockable<ncclComm_t> {
NcclComm() : Lockable(nullptr) {}
NcclComm(ncclComm_t comm) : Lockable(comm) {}
};

StatusOr<NcclComm::Lock> AcquireNcclComm(
6 changes: 5 additions & 1 deletion tensorflow/compiler/xla/service/spmd/auto_sharding.cc
@@ -1415,7 +1415,11 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
break;
}
case HloOpcode::kCustomCall: {
if (IsCustomCallMarker(ins)) {
if (ins->IsCustomCall(kCrossMeshAllReduce)) {
strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies, 0);
} else if (IsCustomCallMarker(ins)) {
const HloInstruction* operand = ins->operand(0);
const StrategyVector* src_strategies = strategy_map.at(operand).get();
CHECK(src_strategies->is_tuple);
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc
@@ -21,6 +21,7 @@ NullStream& NullStream::Global() {

const char* const kPipelineMarker = "pipeline_marker";
const char* const kIdentityMarker = "identity";
const char* const kCrossMeshAllReduce = "__builtin$CrossMeshAllReduce";

// Return whether a reshape instruction is a special reshape that switches
// the batch dim of a dot.
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/spmd/auto_sharding_util.h
@@ -29,6 +29,7 @@ using ReshardingCache =

extern const char* const kPipelineMarker;
extern const char* const kIdentityMarker;
extern const char* const kCrossMeshAllReduce;
constexpr absl::string_view kPipelineMarkerStartType = "start";
constexpr absl::string_view kPipelineMarkerEndType = "end";

3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -1504,7 +1504,8 @@ std::vector<ReplicaGroup> SpmdPartitioningVisitor::CreateReplicaGroups(
}

Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
if (hlo->IsCustomCall("identity") || hlo->IsCustomCall("pipeline_marker")) {
if (hlo->IsCustomCall("identity") || hlo->IsCustomCall("pipeline_marker") ||
hlo->IsCustomCall("__builtin$CrossMeshAllReduce")) {
return HandleElementwise(hlo);
}

5 changes: 4 additions & 1 deletion tensorflow/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.cc
@@ -79,7 +79,10 @@ bool StatefulRngSpmdPartitioner::CanSideEffectingHaveReplicatedSharding(
// Alpa-specific changes for profiling
if (hlo->opcode() == HloOpcode::kAllReduce &&
Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids()) return true;
if (hlo->IsCustomCall(kPipelineMarker) || hlo->IsCustomCall(kIdentityMarker)) return true;
if (hlo->IsCustomCall(kPipelineMarker) ||
hlo->IsCustomCall(kIdentityMarker) ||
hlo->IsCustomCall(kCrossMeshAllReduce))
return true;
return spmd::SpmdPartitioner::CanSideEffectingHaveReplicatedSharding(hlo);
}

