Skip to content

[wip] Xccl/nan #1756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/xccl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

file(GLOB xccl_h "*.hpp")
file(GLOB xccl_cpp "*.cpp")
list(REMOVE_ITEM xccl_cpp "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp")

list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp})
list(APPEND ATen_XPU_SYCL_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp")

set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE)
set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE)

# Why copy the header file to the build directory?
# We want register XCCL backend to PyTorch c10d in torch/csrc/distributed/c10d/init.cpp#L27-L29.
Expand Down
220 changes: 220 additions & 0 deletions src/xccl/NanCheck_XPU.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/xpu/sycl/MemoryAccessUtils.h>
#include <ATen/xpu/XPUContext.h>
#include <comm/SYCLContext.h>
#include <stdint.h>
#include <torch/torch.h>
#include <xccl/NanCheck_XPU.hpp>
#include <algorithm>

namespace c10d {

using BytePack = at::native::memory::aligned_vector<uint64_t, 2>;

template <typename T, int EltPerPack>
struct CheckBytePack {
static void check(BytePack* tmp) {
T* data = (T*)tmp;
#pragma unroll 8
for (int i = 0; i < EltPerPack; i++) {
if (at::_isnan(data[i]))
assert(0);
}
}
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 2> {
static void check(BytePack* tmp) {
T* data = (T*)tmp;
if (at::_isnan(data[0]) || at::_isnan(data[1]))
assert(0);
}
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 4> {
static void check(BytePack* tmp) {
T* data = (T*)tmp;
if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) ||
at::_isnan(data[3]))
assert(0);
}
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 8> {
static void check(BytePack* tmp) {
T* data = (T*)tmp;
if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) ||
at::_isnan(data[3]) || at::_isnan(data[4]) || at::_isnan(data[5]) ||
at::_isnan(data[6]) || at::_isnan(data[7])) {
assert(0);
}
}
};

template <typename T>
struct HasNanFP8x8 {
static bool check(uint64_t fp8x8) = delete;
/*
{
// `static_assert` in template definition requires c++23 onwards.
// But the error message still applies if you find yourself here.
static_assert(
false,
"You should never call this template definition because it is empty. You "
"can follow the example of Float8_e4m3fn below to implement the check for
" "your new datatype."
);
}
*/
};

template <>
struct HasNanFP8x8<c10::Float8_e4m3fn> {
static bool check(uint64_t fp8x8) {
auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
auto incremented = t + 0x0101010101010101ULL;
auto overflow = incremented & 0x8080808080808080ULL;
return overflow != 0;
}
};

template <>
struct HasNanFP8x8<c10::Float8_e5m2> {
static bool check(uint64_t fp8x8) {
auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
auto incremented = t + 0x0303030303030303ULL;
auto overflow = incremented & 0x8080808080808080ULL;
return overflow != 0;
}
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 16> {
static void check(BytePack* tmp) {
if (HasNanFP8x8<T>::check(tmp->val[0]) ||
HasNanFP8x8<T>::check(tmp->val[1]))
assert(0);
}
};

#define UNROLL 8

template <typename T>
void checkChunk(BytePack* ptr, int nWorkers) {
BytePack tmp[UNROLL];

#pragma unroll 8
for (int j = 0; j < UNROLL; j++) {
tmp[j] = ptr[nWorkers * j];
}
// Then check each BytePack in the tmp buffer
#pragma unroll 8
for (int j = 0; j < UNROLL; j++) {
CheckBytePack<T, sizeof(BytePack) / sizeof(T)>::check(tmp + j);
}
// Note: we separate the check from the load for efficient loading
}

// Align address of `ptr` up, to the alignment of `T`
#define ALIGN_UP(ptr, T) \
(((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T))

template <typename T>
struct checkForNaN {
void operator()(sycl::nd_item<1> item) const {
constexpr int EltPerPack = sizeof(BytePack) / sizeof(T);

size_t offset = item.get_global_id(0);

// Align input address up to BytePack in case it is not
T* ptrAlign = (T*)ALIGN_UP(data, BytePack);
size_t preProcElts =
std::min<size_t>(static_cast<size_t>(ptrAlign - data), size);

size_t size_left = size;

if (offset < preProcElts) {
if (at::_isnan(data[offset]))
assert(0);
}
size_left -= preProcElts;

BytePack* ptr = (BytePack*)ptrAlign;
size_t sizeInBP = size_left * sizeof(T) / sizeof(BytePack);
size_t loopSize = item.get_global_range(0) * UNROLL;

for (; offset + loopSize <= sizeInBP; offset += loopSize) {
checkChunk<T>(ptr + offset, item.get_global_range(0));
}

for (; offset < sizeInBP; offset += item.get_global_range(0)) {
BytePack tmp = ptr[offset];
CheckBytePack<T, EltPerPack>::check(&tmp);
}

if (item.get_local_id(0) < size_left % EltPerPack) {
T* tailPtr = (T*)(ptr + sizeInBP);
if (at::_isnan(tailPtr[item.get_local_id(0)]))
assert(0);
}
}
checkForNaN(T* data, size_t size, int64_t num_group, int64_t max_group_size)
: data(data),
size(size),
num_group_(num_group),
max_group_size_(max_group_size) {}

private:
T* data;
size_t size;
int64_t num_group_;
int64_t max_group_size_;
};

template <typename T>
void checkfornan_impl_xpu(
const at::Tensor& tensor,
at::xpu::XPUStream& stream) {
// skip check for non float types
if (!torch::is_floating_point(tensor)) {
return;
}

int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize<checkForNaN<T>>();

const size_t numThreadsPerBlock =
std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());

if (!(numThreadsPerBlock > 0)) {
return;
}

int64_t numBlocks =
(tensor.numel() + maxNumThreadsPerBlock - 1) / maxNumThreadsPerBlock;
auto global_range{numBlocks * maxNumThreadsPerBlock};
auto local_range{maxNumThreadsPerBlock};

using Kernel = checkForNaN<T>;
auto kfn = Kernel(
tensor.data_ptr<T>(), tensor.numel(), numBlocks, maxNumThreadsPerBlock);

sycl_kernel_submit(global_range, local_range, stream.queue(), kfn);
}

// CHECK if a Tensor contains NAN in any of its element
void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream) {
AT_DISPATCH_FLOATING_TYPES_AND4(
at::ScalarType::Half,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
tensor.scalar_type(),
"checkForNaN_XPU",
[&]() { checkfornan_impl_xpu<scalar_t>(tensor, stream); });
}

} // namespace c10d
14 changes: 14 additions & 0 deletions src/xccl/NanCheck_XPU.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once

#ifdef USE_C10D_XCCL

#include <ATen/ATen.h>
#include <c10/xpu/XPUStream.h>

namespace c10d {

void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream);

} // namespace c10d

#endif // USE_C10D_XCCL
35 changes: 30 additions & 5 deletions src/xccl/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifdef USE_C10D_XCCL

#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#include <xccl/NanCheck_XPU.hpp>
#include <xccl/ProcessGroupXCCL.hpp>

namespace c10d {
Expand Down Expand Up @@ -338,6 +339,7 @@ ProcessGroupXCCL::ProcessGroupXCCL(
local_id_(process_group_id++) {
logPrefix_ = createLogPrefix();
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
enableNanCheck_ = getCvarBool(TORCH_XCCL_NAN_CHECK, false);
init();
const std::string OFF = "OFF";
std::string torch_distributed_debug =
Expand All @@ -349,7 +351,8 @@ ProcessGroupXCCL::ProcessGroupXCCL(
LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: "
<< "XCCL version: " << XcclVersion
<< ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug;
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
<< ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_;
}

ProcessGroupXCCL::~ProcessGroupXCCL() = default;
Expand All @@ -360,6 +363,10 @@ uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() {
return seqCollective_;
}

void ProcessGroupXCCL::setEnableNanCheck(bool enableNanCheck) {
enableNanCheck_ = enableNanCheck;
}

c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
at::Device& device,
int rank,
Expand Down Expand Up @@ -553,7 +560,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle) {
const char* profilingTitle,
bool nanCheck) {
nanCheck &= enableNanCheck_;
seqCollective_++;
auto device = inputs[0].device();
const auto key = std::to_string(device.index());
Expand Down Expand Up @@ -620,6 +629,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(

c10::OptionalDeviceGuard gpuGuard(device);

if (nanCheck) {
Copy link
Preview

Copilot AI Jun 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The NaN check currently only scans inputs before communication. To catch NaNs introduced by the collective or point-to-point operations, consider adding a post-operation loop over outputs.

Copilot uses AI. Check for mistakes.

for (const auto& input : inputs) {
checkForNan(input, stream);
}
}

pre(stream, work);

for (const auto i : c10::irange(inputs.size())) {
Expand Down Expand Up @@ -697,6 +712,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
auto cclstream = xcclStreamsMap_.at(key).second;
syncStream(device, xcclEventsMap_[key], stream);

if (enableNanCheck_ && opType == OpType::SEND) {
checkForNan(tensor, stream);
}

if (!coalescing_state_) {
auto work =
initWork(device, rank_, opType, true, profilingTitle, {tensor}, {});
Expand Down Expand Up @@ -1006,6 +1025,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
this->getSize()); // worldSize

const auto root = opts.rootRank;
bool nanCheck = (rank_ == root);

auto outputs = std::vector<at::Tensor>{outputTensor};
return collective(
Expand Down Expand Up @@ -1059,7 +1079,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
},
OpType::SCATTER,
opts.asyncOp,
"xccl:scatter");
"xccl:scatter",
nanCheck);
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
Expand Down Expand Up @@ -1222,6 +1243,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
this->getSize()); // worldSize

const auto root = opts.rootRank + opts.rootTensor;
bool nanCheck = (root == rank_);

return collective(
tensor,
Expand All @@ -1243,7 +1265,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
},
OpType::BROADCAST,
opts.asyncOp,
"xccl:broadcast");
"xccl:broadcast",
nanCheck);
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
Expand All @@ -1256,6 +1279,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
"Tensor input and output of _broadcast_oop must have the same number of elements ");
}
const auto root = opts.rootRank + opts.rootTensor;
bool nanCheck = (root == rank_);
return collective(
inputTensor,
outputTensor,
Expand All @@ -1277,7 +1301,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
},
OpType::BROADCAST,
opts.asyncOp,
"xccl:_broadcast_oop");
"xccl:_broadcast_oop",
nanCheck);
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
Expand Down
Loading
Loading