diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index cf05c3fd5..439e50899 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -290,14 +290,26 @@ const std::string& ProcessGroupXCCL::logPrefix() const { return logPrefix_; } +const std::vector& ProcessGroupXCCL::groupRanks() const { + if (options_->global_ranks_in_group.empty() && local_id_ == 0) { + static std::vector globalRanks(size_); + std::iota(globalRanks.begin(), globalRanks.end(), 0); + return globalRanks; + } + return options_->global_ranks_in_group; +} + ProcessGroupXCCL::ProcessGroupXCCL( - const c10::intrusive_ptr& store, + c10::intrusive_ptr store, int rank, - int size) + int size, + c10::intrusive_ptr options) : Backend(rank, size), - store_(store), + store_(std::move(store)), + options_(std::move(options)), xcclCommCounter_(0), local_id_(process_group_id++) { + this->setGroupUid(options_->group_name); logPrefix_ = createLogPrefix(); blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); init(); @@ -306,7 +318,10 @@ ProcessGroupXCCL::ProcessGroupXCCL( getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str()); const auto XcclVersion = getXcclVersion(); LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: " - << "size: " << size << ", global rank: " << rank_; + << "size: " << size << ", global rank: " << rank_ + << ", USE_HIGH_PRIORITY_STREAM: " + << options_->is_high_priority_stream + << ", PG Name: " << options_->group_name; LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: " << "XCCL version: " << XcclVersion @@ -388,9 +403,9 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( rank = p2pRank; } - c10::impl::VirtualGuardImpl impl(device.type()); - c10::Stream stream = - impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); + bool force_high = getCvarBool(TORCH_XCCL_HIGH_PRIORITY, false); + c10::Stream stream = at::xpu::getStreamFromPool( + options_->is_high_priority_stream || force_high); sycl::queue& q = c10::xpu::XPUStream(stream).queue(); auto ctx = ccl::create_context(q.get_context()); @@ -451,6 +466,10 @@ void ProcessGroupXCCL::groupEnd() { --xcclActiveGroupCounter_; } +ProcessGroupXCCL::Options::Options(bool is_high_priority_stream) + : Backend::Options(XCCL_BACKEND_NAME), + is_high_priority_stream(is_high_priority_stream) {} + static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; void ProcessGroupXCCL::startCoalescing() { coalescedDevice_.set_index(-1); diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 8fc765f69..526f7fa21 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -22,6 +22,9 @@ #include namespace c10d { +static std::vector TORCH_XCCL_HIGH_PRIORITY = { + "TORCH_XCCL_HIGH_PRIORITY"}; + static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"}; @@ -105,17 +108,38 @@ class TORCH_API ProcessGroupXCCL : public Backend { friend class ProcessGroupXCCL; }; - ProcessGroupXCCL(const c10::intrusive_ptr& store, int rank, int size); + struct Options : Backend::Options { + explicit Options(bool is_high_priority_stream = false); + + static c10::intrusive_ptr create( + bool is_high_priority_stream = false) { + return c10::make_intrusive(is_high_priority_stream); + } + bool is_high_priority_stream; + std::vector global_ranks_in_group; + std::string group_name; + }; + + ProcessGroupXCCL( + c10::intrusive_ptr store, + int rank, + int size, + c10::intrusive_ptr options = Options::create()); C10_DEPRECATED ProcessGroupXCCL( const c10::intrusive_ptr& store, int rank, int size, - const std::string& groupName) - : ProcessGroupXCCL(store, rank, size) {} + const std::string& groupName, + c10::intrusive_ptr options = Options::create()) + : ProcessGroupXCCL(store, rank, size, std::move(options)) {} ~ProcessGroupXCCL() override; + c10::intrusive_ptr getOptions() { + return options_; + } + const std::string getBackendName() const override { return std::string(XCCL_BACKEND_NAME); } @@ -367,12 +391,15 @@ class TORCH_API ProcessGroupXCCL : public Backend { const std::string& logPrefix() const; + const std::vector& groupRanks() const; + protected: std::unordered_map> xcclStreamsMap_; std::unordered_map xcclEventsMap_; std::unordered_map> devXCCLCommMap_; c10::intrusive_ptr store_; + const c10::intrusive_ptr options_; uint64_t xcclCommCounter_{0}; std::mutex mutex_; std::set usedDeviceIdxs_; diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py index 0625a6993..faae24f99 100644 --- a/test/xpu/distributed/test_c10d_xccl.py +++ b/test/xpu/distributed/test_c10d_xccl.py @@ -340,6 +340,32 @@ def _test_broadcast_coalesced(self, process_group, device, root_rank): if self.rank != root_rank: self.assertEqual(tensors, target) + def _test_pass_xccl_options(self, pg_opts): + store = c10d.FileStore(self.file_name, self.world_size) + # Test init_process_group accepts options + dist.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + pg_options=pg_opts, + ) + + # Test with new_group + pg = c10d.new_group([0, 1], pg_options=pg_opts) + # test the process group works as expected + t = torch.tensor([self.rank + 1] * 10).xpu(self.rank) + pg.allreduce(t).wait() + expected_tensor = torch.tensor([3] * 10).xpu(self.rank) + self.assertEqual(expected_tensor, t) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_pass_xccl_options_high_priority_stream(self): + pg_opts = c10d.ProcessGroupXCCL.Options() + pg_opts.is_high_priority_stream = True + self._test_pass_xccl_options(pg_opts) + @requires_xccl() @skip_if_lt_x_gpu(2) def test_broadcast_coalesced_xccl(self):