diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index c3fc42161ba..ae3b560b00e 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -62,6 +62,7 @@ limitations under the License.
 #include "tensorflow/python/lib/core/bfloat16.h"
 #include "third_party/nccl/nccl.h"
 #include "tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
 
 // TODO(phawkins): remove host_id properties after JAX is update to avoid them.
@@ -551,10 +552,15 @@ PYBIND11_MODULE(xla_extension, m) {
         "nccl broadcast with only a subset of gpus in the host are involved");
   m.def("nccl_create_communicators", &gpu::NcclCreateCommunicators,
         "nccl create communicators for multiple threads case");
+  m.def("nccl_create_communicators_no_stream",
+        &gpu::NcclCreateCommunicatorsNoStream,
+        "nccl create pure communicators");
   m.def("get_buffer_device_id", &gpu::GetBufferDeviceId,
         "get the local device id for one pybuffer");
   m.def("nccl_recv", &gpu::NcclRecv, "nccl recv data");
   m.def("nccl_send", &gpu::NcclSend, "nccl send data");
+  m.def("set_cross_mesh_communicator", &gpu::SetCrossMeshCommunicators,
+        "set nccl communicators for cross mesh collective communication");
 }  // NOLINT(readability/fn_size)
diff --git a/tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.cc b/tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.cc
index 13c39c5edf2..7a03dc874e5 100644
--- a/tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.cc
+++ b/tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.cc
@@ -311,6 +311,31 @@ StatusOr<int> GetBufferDeviceId(PyBuffer::object buffer) {
   return buffer.buf()->device()->local_hardware_id();
 }
 
+StatusOr<std::vector<std::uintptr_t>> NcclCreateCommunicatorsNoStream(
+    int world_size, std::vector<int> devices_global_rank,
+    std::vector<int> devices_ids, std::vector<int8_t> nccl_uid_vec) {
+#if XLA_ENABLE_XCCL
+  int n_devices = devices_global_rank.size();
+  CHECK_EQ(n_devices, devices_ids.size());
+  ncclUniqueId nccl_uid;
+  CHECK_EQ(sizeof(nccl_uid.internal), nccl_uid_vec.size());
+  memcpy(&nccl_uid.internal, nccl_uid_vec.data(), sizeof(nccl_uid.internal));
+
+  std::vector<std::uintptr_t> comms;
+  comms.resize(n_devices);
+  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
+  for (int i = 0; i < n_devices; i++) {
+    cudaSetDevice(devices_ids[i]);
+    ncclComm_t* comm_ref = reinterpret_cast<ncclComm_t*>(comms.data() + i);
+    XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(
+        comm_ref, world_size, nccl_uid, devices_global_rank[i]));
+  }
+  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
+  return comms;
+#else  // XLA_ENABLE_XCCL
+  return Unimplemented("NCCL support is not available.");
+#endif  // XLA_ENABLE_XCCL
+}
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h b/tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h
index 6d2acd28cba..224551a10b3 100644
--- a/tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h
+++ b/tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h
@@ -116,6 +116,10 @@ StatusOr<std::vector<std::uintptr_t>> NcclCreateCommunicators(int world_size,
     std::vector<int8_t> nccl_uid, bool nccl_use_multistream);
 
+StatusOr<std::vector<std::uintptr_t>> NcclCreateCommunicatorsNoStream(
+    int world_size, std::vector<int> devices_global_rank,
+    std::vector<int> devices_ids, std::vector<int8_t> nccl_uid_vec);
+
 StatusOr<int> GetBufferDeviceId(PyBuffer::object buffer);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 5e7aa1167c7..bda16a0c09e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -153,6 +153,8 @@ const char* const kBuiltinSwapOutTarget = "__builtin$SwapOut";
 const char* const kBuiltinSwapInTarget = "__builtin$SwapIn";
 const char* const kBuiltinSwapDoneTarget = "__builtin$SwapDone";
 const char* const kBuiltinMemZeroTarget = "__builtin$MemZero";
+const char* const kBuiltinCrossMeshAllReduceTarget =
+    "__builtin$CrossMeshAllReduce";
 
 static ReductionDimensions GetReductionKindAndContiguousComponentsImpl(
     const Shape& input_shape, absl::Span<const int64_t> dims_to_reduce) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 503b0a2796a..985b19bd653 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -68,6 +68,7 @@ extern const char* const kBuiltinSwapOutTarget;
 extern const char* const kBuiltinSwapInTarget;
 extern const char* const kBuiltinSwapDoneTarget;
 extern const char* const kBuiltinMemZeroTarget;
+extern const char* const kBuiltinCrossMeshAllReduceTarget;
 
 // Returns true if either the dimensions being reduced or the dimensions being
 // kept are contiguous in the input of the reduce instruction.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index bb20f7825a3..44b3a3454cd 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1343,10 +1343,10 @@ Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) {
 }
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-Status IrEmitterUnnested::EmitSwapThunk(mlir::Operation* op) {
+using Slices = std::vector<BufferAllocation::Slice>;
+StatusOr<std::pair<Slices, Slices>> IrEmitterUnnested::CustomCallParseBuffers(
+    mlir::Operation* op) {
   auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
-  const std::string call_target_name = custom_call.call_target_name().str();
-
   std::vector<BufferAllocation::Slice> operands;
   std::vector<BufferAllocation::Slice> results;
@@ -1393,6 +1393,16 @@
     TF_ASSIGN_OR_RETURN(operands, values_to_slices(custom_call.args()));
     TF_ASSIGN_OR_RETURN(results, values_to_slices(custom_call.output()));
   }
+  return std::make_pair(operands, results);
+}
+
+Status IrEmitterUnnested::EmitSwapThunk(mlir::Operation* op) {
+  auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
+  const std::string call_target_name = custom_call.call_target_name().str();
+
+  TF_ASSIGN_OR_RETURN(auto all_buffers, CustomCallParseBuffers(op));
+  std::vector<BufferAllocation::Slice>& operands = all_buffers.first;
+  std::vector<BufferAllocation::Slice>& results = all_buffers.second;
 
   std::vector<int64_t> byte_sizes;
   std::vector<int64_t> keys =
@@ -1470,8 +1480,41 @@ Status IrEmitterUnnested::EmitMemZeroThunk(mlir::Operation* op) {
     AddThunkToThunkSequence(
         absl::make_unique<MemzeroThunk>(GetThunkInfo(op), operand));
   }
+  return Status::OK();
 }
 
+Status IrEmitterUnnested::EmitCrossMeshAllReduceTarget(mlir::Operation* op) {
+  auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
+  TF_ASSIGN_OR_RETURN(auto all_buffers, CustomCallParseBuffers(op));
+  std::vector<BufferAllocation::Slice>& operands = all_buffers.first;
+  std::vector<BufferAllocation::Slice>& results = all_buffers.second;
+  std::vector<NcclCollectiveThunk::Buffer> buffers;
+  CHECK_EQ(operands.size(), results.size());
+  for (auto buf : llvm::zip(operands, results, custom_call.args())) {
+    BufferAllocation::Slice src = std::get<0>(buf);
+    BufferAllocation::Slice dst = std::get<1>(buf);
+    mlir::Value src_value = std::get<2>(buf);
+    buffers.push_back(NcclCollectiveThunk::Buffer{
+        ShapeUtil::ElementsIn(GetShape(src_value)), src, dst});
+  }
+  ReductionKind reduction_kind;
+  // TODO(yonghao): the opaque should also add participant mesh group info.
+  const std::string op_str = custom_call.backend_config().str();
+  if (op_str == "SUM") {
+    reduction_kind = ReductionKind::SUM;
+  } else if (op_str == "AND" || op_str == "MIN") {
+    reduction_kind = ReductionKind::MIN;
+  } else if (op_str == "OR" || op_str == "MAX") {
+    reduction_kind = ReductionKind::MAX;
+  } else {
+    return InternalError("cross mesh allreduce op %s is unsupported",
+                         op_str.c_str());
+  }
+  auto op_type = GetShape(custom_call.args()[0]).element_type();
+  AddThunkToThunkSequence(absl::make_unique<CrossMeshNcclAllReduceThunk>(
+      GetThunkInfo(op), buffers, reduction_kind, op_type));
+  return Status::OK();
+}
 Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) {
   auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
@@ -5738,6 +5781,9 @@ Status IrEmitterUnnested::EmitOp(mlir::Operation* op) {
     if (call.call_target_name() == kBuiltinMemZeroTarget) {
       return EmitMemZeroThunk(op);
     }
+    if (call.call_target_name() == kBuiltinCrossMeshAllReduceTarget) {
+      return EmitCrossMeshAllReduceTarget(op);
+    }
     return EmitCustomCallThunk(op);
   }
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index c08ccc5dec3..4e237e627d8 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -181,6 +181,10 @@ class IrEmitterUnnested : public IrEmitter {
   Status EmitSwapThunk(mlir::Operation* op);
   Status EmitSwapDoneThunk(mlir::Operation* op);
   Status EmitMemZeroThunk(mlir::Operation* op);
+  Status EmitCrossMeshAllReduceTarget(mlir::Operation* op);
+  using Slices = std::vector<BufferAllocation::Slice>;
+  StatusOr<std::pair<Slices, Slices>> CustomCallParseBuffers(
+      mlir::Operation* op);
   Status EmitCustomCallThunk(mlir::Operation* op);
   Status EmitFftThunk(mlir::Operation* op);
   Status EmitFusion(mlir::Operation* op);
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index ae9c2450cf1..eb7175d2e1e 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -458,5 +458,74 @@ Status NcclReduceScatterThunk::RunNcclCollective(const ExecuteParams& params,
 #endif  // XLA_ENABLE_XCCL
 }
 
+std::vector<std::unique_ptr<NcclComm>> nccl_comms;
+
+// FIXME(yonghao): support multiple groups of cross mesh nccl comms with keys
+void SetCrossMeshCommunicators(const std::vector<std::uintptr_t>& comms,
+                               const std::string& group_keys) {
+  nccl_comms.clear();
+  nccl_comms.reserve(comms.size());
+  for (std::uintptr_t comm : comms) {
+    nccl_comms.emplace_back(
+        std::make_unique<NcclComm>(reinterpret_cast<ncclComm_t>(comm)));
+  }
+}
+
+NcclAllReduceConfig GetCrossMeshNcclAllReduceConfig(
+    ReductionKind reduction_kind, xla::PrimitiveType op_type) {
+  NcclAllReduceConfig config;
+  NcclCollectiveConfig& collective_config = config.config;
+  collective_config.operand_count = 1;
+  collective_config.operand_element_type.push_back(op_type);
+  // The replica_groups, collective_op_kind and group_mode are used to
+  // identify nccl comm in XLA's original collective thunks, so they are
+  // not used in this thunk.
+  config.reduction_kind = reduction_kind;
+  return config;
+}
+
+CrossMeshNcclAllReduceThunk::CrossMeshNcclAllReduceThunk(
+    ThunkInfo thunk_info, std::vector<Buffer> buffers,
+    ReductionKind reduction_kind, xla::PrimitiveType op_type)
+    : Thunk(Thunk::kNcclAllReduce, thunk_info),
+      buffers_(buffers),
+      config_(GetCrossMeshNcclAllReduceConfig(reduction_kind, op_type)) {}
+
+Status CrossMeshNcclAllReduceThunk::ExecuteOnStream(
+    const ExecuteParams& params) {
+#if XLA_ENABLE_XCCL
+  VLOG(1) << absl::StreamFormat("Starting %s.", Thunk::KindToString(kind()));
+
+  se::StreamExecutor* executor = params.stream->parent();
+  se::gpu::ScopedActivateExecutorContext scoped_context(executor);
+
+  // TF_ASSIGN_OR_RETURN(
+  //     NcclComm::Lock comm,
+  //     AcquireNcclComm(params.run_id, op_id, std::move(participants),
+  //                     num_local_participants, *unique_id_callback, rank));
+  // TODO(yonghao): support CrossMeshNcclAllReduce for different mesh groups
+  // as above, using participant info created at compile time.
+  int device_ordinal = params.stream->parent()->device_ordinal();
+  NcclComm::Lock comm = nccl_comms[device_ordinal]->Acquire();
+
+  se::Stream& stream = *params.stream;
+  TF_RETURN_IF_ERROR(RunAllReduce(config_, buffers_, *params.buffer_allocations,
+                                  stream, *comm, ""));
+
+  // Block host on the first call to ensure that all devices have allocated the
+  // required buffers for their communicators before allowing any device to
+  // continue enqueuing operations. Otherwise, the allocations can cause
+  // deadlock in the CUDA driver (b/215649390).
+  if (first_call_to_execute_) {
+    TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone());
+    first_call_to_execute_ = false;
+  }
+  return Status::OK();
+#else  // XLA_ENABLE_XCCL
+  return Unimplemented(
+      "NCCL support is not available: this binary was not built with a CUDA "
+      "compiler, which is necessary to build the NCCL source library.");
+#endif  // XLA_ENABLE_XCCL
+}
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
index 7b72912d8d4..2bdcd18b62b 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
@@ -130,6 +130,26 @@ class NcclReduceScatterThunk : public NcclAllReduceThunkBase {
                           ncclComm_t comm) override;
 };
 
+class CrossMeshNcclAllReduceThunk : public Thunk {
+ public:
+  using Buffer = NcclCollectiveThunk::Buffer;
+
+  explicit CrossMeshNcclAllReduceThunk(ThunkInfo thunk_info,
+                                       std::vector<Buffer> buffers,
+                                       ReductionKind reduction_kind,
+                                       xla::PrimitiveType op_type);
+
+  Status ExecuteOnStream(const ExecuteParams& params) override;
+
+ private:
+  const NcclAllReduceConfig config_;
+  const std::vector<Buffer> buffers_;
+  bool first_call_to_execute_ = true;
+};
+
+void SetCrossMeshCommunicators(const std::vector<std::uintptr_t>& comms,
+                               const std::string& group_keys);
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
index 39d7fbabefd..ac557af68ca 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
@@ -108,6 +108,7 @@ TF_LIB_GTL_DEFINE_INT_TYPE(OpId, int64_t);
 
 struct NcclComm : public Lockable<ncclComm_t> {
   NcclComm() : Lockable(nullptr) {}
+  NcclComm(ncclComm_t comm) : Lockable(comm) {}
 };
 
 StatusOr<NcclComm::Lock> AcquireNcclComm(
diff --git a/tensorflow/compiler/xla/service/spmd/auto_sharding.cc b/tensorflow/compiler/xla/service/spmd/auto_sharding.cc
index 3db20b14db0..6e0a408a229 100644
--- a/tensorflow/compiler/xla/service/spmd/auto_sharding.cc
+++ b/tensorflow/compiler/xla/service/spmd/auto_sharding.cc
@@ -1415,7 +1415,11 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
         break;
       }
       case HloOpcode::kCustomCall: {
-        if (IsCustomCallMarker(ins)) {
+        if (ins->IsCustomCall(kCrossMeshAllReduce)) {
+          strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map,
+                                                leaf_strategies);
+          AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies, 0);
+        } else if (IsCustomCallMarker(ins)) {
           const HloInstruction* operand = ins->operand(0);
           const StrategyVector* src_strategies = strategy_map.at(operand).get();
           CHECK(src_strategies->is_tuple);
diff --git a/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc b/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc
index fc982a2691a..2461225197a 100644
--- a/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc
+++ b/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc
@@ -21,6 +21,7 @@ NullStream& NullStream::Global() {
 
 const char* const kPipelineMarker = "pipeline_marker";
 const char* const kIdentityMarker = "identity";
+const char* const kCrossMeshAllReduce = "__builtin$CrossMeshAllReduce";
 
 // Return whether a reshape instruction is a special reshape that switches
 // the batch dim of a dot.
diff --git a/tensorflow/compiler/xla/service/spmd/auto_sharding_util.h b/tensorflow/compiler/xla/service/spmd/auto_sharding_util.h
index 35e746c7fac..e4b577bb977 100644
--- a/tensorflow/compiler/xla/service/spmd/auto_sharding_util.h
+++ b/tensorflow/compiler/xla/service/spmd/auto_sharding_util.h
@@ -29,6 +29,7 @@ using ReshardingCache =
 
 extern const char* const kPipelineMarker;
 extern const char* const kIdentityMarker;
+extern const char* const kCrossMeshAllReduce;
 
 constexpr absl::string_view kPipelineMarkerStartType = "start";
 constexpr absl::string_view kPipelineMarkerEndType = "end";
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index 080ec921465..049a1be1fd7 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -1504,7 +1504,8 @@ std::vector<ReplicaGroup> SpmdPartitioningVisitor::CreateReplicaGroups(
 }
 
 Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
-  if (hlo->IsCustomCall("identity") || hlo->IsCustomCall("pipeline_marker")) {
+  if (hlo->IsCustomCall("identity") || hlo->IsCustomCall("pipeline_marker") ||
+      hlo->IsCustomCall("__builtin$CrossMeshAllReduce")) {
     return HandleElementwise(hlo);
   }
diff --git a/tensorflow/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.cc
index 84e786bbbcb..9df98e1b8b8 100644
--- a/tensorflow/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.cc
@@ -79,7 +79,10 @@ bool StatefulRngSpmdPartitioner::CanSideEffectingHaveReplicatedSharding(
   // Alpa-specific changes for profling
   if (hlo->opcode() == HloOpcode::kAllReduce &&
       Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids())
     return true;
-  if (hlo->IsCustomCall(kPipelineMarker) || hlo->IsCustomCall(kIdentityMarker)) return true;
+  if (hlo->IsCustomCall(kPipelineMarker) ||
+      hlo->IsCustomCall(kIdentityMarker) ||
+      hlo->IsCustomCall(kCrossMeshAllReduce))
+    return true;
   return spmd::SpmdPartitioner::CanSideEffectingHaveReplicatedSharding(hlo);
 }
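Usage note (not part of the patch): the sketch below shows one way the two new entry points could be wired together on the host side, assuming the signatures reconstructed above. The helper name SetUpCrossMeshAllReduce, its arguments, and the "default" group key are hypothetical; only NcclCreateCommunicatorsNoStream and SetCrossMeshCommunicators come from this diff.

// Hypothetical host-side setup, not part of this patch.
#include <cstdint>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/alpa_nccl_wrapper.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

namespace xla {
namespace gpu {

// world_size: total number of GPUs across the participating meshes.
// global_ranks[i]: NCCL rank of the i-th local GPU.
// local_device_ids[i]: CUDA device id of the i-th local GPU.
// nccl_uid_bytes: the 128-byte ncclUniqueId, distributed out of band.
Status SetUpCrossMeshAllReduce(int world_size,
                               const std::vector<int>& global_ranks,
                               const std::vector<int>& local_device_ids,
                               const std::vector<int8_t>& nccl_uid_bytes) {
  // Create one communicator per local device, not bound to any XLA stream.
  StatusOr<std::vector<std::uintptr_t>> comms = NcclCreateCommunicatorsNoStream(
      world_size, global_ranks, local_device_ids, nccl_uid_bytes);
  if (!comms.ok()) return comms.status();
  // Register the communicators in the process-wide table that
  // CrossMeshNcclAllReduceThunk::ExecuteOnStream indexes by device ordinal.
  // The group key is currently unused (see the FIXME in the .cc above).
  SetCrossMeshCommunicators(comms.ValueOrDie(), /*group_keys=*/"default");
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla

Because ExecuteOnStream picks the communicator with nccl_comms[device_ordinal], the comms vector handed to SetCrossMeshCommunicators has to be ordered by local device ordinal starting at 0, and it must be re-registered whenever the cross-mesh group changes.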