Skip to content

support high priority stream #1715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions src/xccl/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,14 +290,26 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
return logPrefix_;
}

const std::vector<uint64_t>& ProcessGroupXCCL::groupRanks() const {
Copy link
Preview

Copilot AI Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning a reference to a static vector that is only populated when options_->global_ranks_in_group is empty and local_id_ == 0 may lead to inconsistent behavior and potential thread-safety issues. Consider moving the default global ranks logic into the Options struct or ensuring proper synchronization.

Copilot uses AI. Check for mistakes.

if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
static std::vector<uint64_t> globalRanks(size_);
std::iota(globalRanks.begin(), globalRanks.end(), 0);
return globalRanks;
}
return options_->global_ranks_in_group;
}

ProcessGroupXCCL::ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
c10::intrusive_ptr<Store> store,
int rank,
int size)
int size,
c10::intrusive_ptr<Options> options)
: Backend(rank, size),
store_(store),
store_(std::move(store)),
options_(std::move(options)),
xcclCommCounter_(0),
local_id_(process_group_id++) {
this->setGroupUid(options_->group_name);
logPrefix_ = createLogPrefix();
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
init();
Expand All @@ -306,7 +318,10 @@ ProcessGroupXCCL::ProcessGroupXCCL(
getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
const auto XcclVersion = getXcclVersion();
LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: "
<< "size: " << size << ", global rank: " << rank_;
<< "size: " << size << ", global rank: " << rank_
<< ", USE_HIGH_PRIORITY_STREAM: "
<< options_->is_high_priority_stream
<< ", PG Name: " << options_->group_name;

LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: "
<< "XCCL version: " << XcclVersion
Expand Down Expand Up @@ -388,9 +403,9 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
rank = p2pRank;
}

c10::impl::VirtualGuardImpl impl(device.type());
c10::Stream stream =
impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
bool force_high = getCvarBool(TORCH_XCCL_HIGH_PRIORITY, false);
c10::Stream stream = at::xpu::getStreamFromPool(
options_->is_high_priority_stream || force_high);
sycl::queue& q = c10::xpu::XPUStream(stream).queue();

auto ctx = ccl::create_context(q.get_context());
Expand Down Expand Up @@ -451,6 +466,10 @@ void ProcessGroupXCCL::groupEnd() {
--xcclActiveGroupCounter_;
}

// Options ctor: tags the base Backend::Options with the XCCL backend name
// and records whether collectives should run on a high-priority XPU stream
// (consulted by getXCCLComm when picking a stream from the pool).
ProcessGroupXCCL::Options::Options(bool is_high_priority_stream)
    : Backend::Options(XCCL_BACKEND_NAME),
      is_high_priority_stream(is_high_priority_stream) {}

// Coalescing-state bit flags: a group is active / contains a collective /
// contains a point-to-point op. Combined into a bitmask during coalescing.
static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
void ProcessGroupXCCL::startCoalescing() {
coalescedDevice_.set_index(-1);
Expand Down
33 changes: 30 additions & 3 deletions src/xccl/ProcessGroupXCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
#include <torch/csrc/distributed/c10d/logger.hpp>
namespace c10d {

// Env var read via getCvarBool(TORCH_XCCL_HIGH_PRIORITY, false) in
// getXCCLComm: when set, forces XCCL to use a high-priority XPU stream
// even if Options::is_high_priority_stream is false.
static std::vector<std::string> TORCH_XCCL_HIGH_PRIORITY = {
    "TORCH_XCCL_HIGH_PRIORITY"};

// Env vars (checked in listed order) read via getCvarBool(...) to set
// blockingWait_ in the ProcessGroupXCCL constructor.
static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
    "TORCH_XCCL_BLOCKING_WAIT",
    "XCCL_BLOCKING_WAIT"};
Expand Down Expand Up @@ -105,17 +108,38 @@ class TORCH_API ProcessGroupXCCL : public Backend {
friend class ProcessGroupXCCL;
};

ProcessGroupXCCL(const c10::intrusive_ptr<Store>& store, int rank, int size);
// Backend-specific options for ProcessGroupXCCL, mirroring the shape of
// other c10d backends' Options structs.
struct Options : Backend::Options {
  // is_high_priority_stream: run communicator work on a high-priority
  // XPU stream when true. Defaults to false.
  explicit Options(bool is_high_priority_stream = false);

  // Convenience factory returning an intrusive_ptr-managed Options.
  static c10::intrusive_ptr<Options> create(
      bool is_high_priority_stream = false) {
    return c10::make_intrusive<Options>(is_high_priority_stream);
  }
  // Whether collectives use a high-priority XPU stream (see getXCCLComm).
  bool is_high_priority_stream;
  // Global ranks composing this group; empty means "use the default
  // identity mapping" for the default PG (see groupRanks()).
  std::vector<uint64_t> global_ranks_in_group;
  // Human-readable group name, propagated to the group UID / log prefix.
  std::string group_name;
};

// Primary constructor: takes ownership of the store and per-group options.
// `options` defaults to a freshly created Options (low-priority stream).
ProcessGroupXCCL(
    c10::intrusive_ptr<Store> store,
    int rank,
    int size,
    c10::intrusive_ptr<Options> options = Options::create());

// Legacy overload kept for backward compatibility.
// NOTE(review): `groupName` is accepted but never used — it is not copied
// into options->group_name before delegating; confirm callers do not rely
// on it taking effect.
C10_DEPRECATED ProcessGroupXCCL(
    const c10::intrusive_ptr<Store>& store,
    int rank,
    int size,
    const std::string& groupName,
    c10::intrusive_ptr<Options> options = Options::create())
    : ProcessGroupXCCL(store, rank, size, std::move(options)) {}

~ProcessGroupXCCL() override;

// Returns the (shared, ref-counted) options this group was built with.
c10::intrusive_ptr<Options> getOptions() {
  return options_;
}

// Backend identifier reported to c10d ("xccl").
const std::string getBackendName() const override {
  return std::string(XCCL_BACKEND_NAME);
}
Expand Down Expand Up @@ -367,12 +391,15 @@ class TORCH_API ProcessGroupXCCL : public Backend {

// Prefix prepended to this group's log lines (built once in the ctor).
const std::string& logPrefix() const;

// Global ranks of this group; identity mapping for the default PG when
// options_->global_ranks_in_group is empty (see .cpp for details).
const std::vector<uint64_t>& groupRanks() const;

protected:
// Per-device-key XPU stream + ccl::stream used for communicator work.
std::unordered_map<std::string, std::pair<at::xpu::XPUStream, ccl::stream>>
    xcclStreamsMap_;
// Per-device-key XPU events for stream synchronization.
std::unordered_map<std::string, at::xpu::XPUEvent> xcclEventsMap_;
// Cached communicators, keyed by device string.
std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
c10::intrusive_ptr<Store> store_;
// Immutable after construction; shared with getOptions() callers.
const c10::intrusive_ptr<Options> options_;
uint64_t xcclCommCounter_{0};
std::mutex mutex_;
std::set<int> usedDeviceIdxs_;
Expand Down
26 changes: 26 additions & 0 deletions test/xpu/distributed/test_c10d_xccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,32 @@ def _test_broadcast_coalesced(self, process_group, device, root_rank):
if self.rank != root_rank:
self.assertEqual(tensors, target)

def _test_pass_xccl_options(self, pg_opts):
    """Initialize the default PG and a subgroup with the given XCCL
    options, then run an allreduce to confirm the subgroup works."""
    store = c10d.FileStore(self.file_name, self.world_size)
    # init_process_group must accept backend-specific options.
    dist.init_process_group(
        "xccl",
        world_size=self.world_size,
        rank=self.rank,
        store=store,
        pg_options=pg_opts,
    )

    # new_group must accept the options as well.
    pg = c10d.new_group([0, 1], pg_options=pg_opts)
    # Sanity-check the group: allreduce of (rank + 1) over ranks {0, 1}
    # yields 1 + 2 = 3 in every slot.
    inp = torch.tensor([self.rank + 1] * 10).xpu(self.rank)
    pg.allreduce(inp).wait()
    expected = torch.tensor([3] * 10).xpu(self.rank)
    self.assertEqual(expected, inp)

@requires_xccl()
@skip_if_lt_x_gpu(2)
def test_pass_xccl_options_high_priority_stream(self):
    """Options with is_high_priority_stream=True must be accepted
    end-to-end by init_process_group and new_group."""
    opts = c10d.ProcessGroupXCCL.Options()
    opts.is_high_priority_stream = True
    self._test_pass_xccl_options(opts)

@requires_xccl()
@skip_if_lt_x_gpu(2)
def test_broadcast_coalesced_xccl(self):
Expand Down
Loading