-
Notifications
You must be signed in to change notification settings - Fork 42
[wip] Xccl/nan #1756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[wip] Xccl/nan #1756
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds optional NaN checks to XCCL collective and point-to-point operations on XPU, driven by a new TORCH_XCCL_NAN_CHECK
CVar.
Key changes:
- Introduce
nanCheck
flag in collective/P2P APIs andenableNanCheck_
member - Implement XPU-side NaN detection kernel (
checkForNan
) - Update build (CMake) to compile the new SYCL-based checker separately
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
Show a summary per file
File | Description |
---|---|
src/xccl/ProcessGroupXCCL.hpp | Added CVar vector, nanCheck parameters on collective APIs, and enableNanCheck_ with setter |
src/xccl/ProcessGroupXCCL.cpp | Initialized enableNanCheck_ , passed nanCheck through calls, and inserted pre-communication NaN checks |
src/xccl/NanCheck_XPU.hpp | Declared checkForNan interface for XPU streams |
src/xccl/NanCheck_XPU.cpp | Implemented a SYCL kernel to scan tensors for NaNs on XPU |
src/xccl/CMakeLists.txt | Updated source lists to compile NanCheck_XPU.cpp under SYCL target |
Comments suppressed due to low confidence (1)
src/xccl/NanCheck_XPU.cpp:177
- [nitpick] Function name
checkfornan_impl_xpu
is inconsistent with the CamelCase style elsewhere; consider renaming tocheckForNanImplXPU
for clarity.
void checkfornan_impl_xpu(
src/xccl/NanCheck_XPU.cpp
Outdated
void operator()(sycl::nd_item<1> item) const { | ||
constexpr int EltPerPack = sizeof(BytePack) / sizeof(T); | ||
|
||
size_t offset = item.get_global_id()[2]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indexing get_global_id()[2]
on a 1D sycl::nd_item<1>
is out-of-bounds; you should use get_global_id(0)
or get_global_id()[0]
instead.
size_t offset = item.get_global_id()[2]; | |
size_t offset = item.get_global_id(0); |
Copilot uses AI. Check for mistakes.
src/xccl/NanCheck_XPU.cpp
Outdated
if (item.get_local_id(1) < size_left % EltPerPack) { | ||
T* tailPtr = (T*)(ptr + sizeInBP); | ||
if (isnan(tailPtr[item.get_local_id(1)])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling get_local_id(1)
on a 1D work-group is invalid; this should be get_local_id(0)
to avoid out-of-bounds access.
if (item.get_local_id(1) < size_left % EltPerPack) { | |
T* tailPtr = (T*)(ptr + sizeInBP); | |
if (isnan(tailPtr[item.get_local_id(1)])) | |
if (item.get_local_id(0) < size_left % EltPerPack) { | |
T* tailPtr = (T*)(ptr + sizeInBP); | |
if (isnan(tailPtr[item.get_local_id(0)])) |
Copilot uses AI. Check for mistakes.
@@ -27,6 +27,9 @@ static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = { | |||
"XCCL_BLOCKING_WAIT"}; | |||
|
|||
using xcclComm_t = ccl::communicator; | |||
|
|||
static std::vector<std::string> TORCH_XCCL_NAN_CHECK = {"TORCH_XCCL_NAN_CHECK"}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Using a std::vector<std::string>
for a single CVar name adds runtime allocation; consider using a static const char*[]
or std::array<std::string_view, 1>
to avoid overhead.
static std::vector<std::string> TORCH_XCCL_NAN_CHECK = {"TORCH_XCCL_NAN_CHECK"}; | |
static constexpr std::array<std::string_view, 1> TORCH_XCCL_NAN_CHECK = {"TORCH_XCCL_NAN_CHECK"}; |
Copilot uses AI. Check for mistakes.
@@ -620,6 +629,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective( | |||
|
|||
c10::OptionalDeviceGuard gpuGuard(device); | |||
|
|||
if (nanCheck) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The NaN check currently only scans inputs
before communication. To catch NaNs introduced by the collective or point-to-point operations, consider adding a post-operation loop over outputs
.
Copilot uses AI. Check for mistakes.
No description provided.