-
Notifications
You must be signed in to change notification settings - Fork 42
[wip] Xccl/nan #1756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Chao1Han
wants to merge
6
commits into
main
Choose a base branch
from
xccl/nan
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
[wip] Xccl/nan #1756
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
eda6607
add nan check for xccl
Chao1Han 595a1dc
cmake and format
Chao1Han aef4be3
add nan check
Chao1Han e2bb25d
Merge branch 'main' into xccl/nan
mengfei25 b3a2e94
Merge branch 'main' into xccl/nan
mengfei25 38c1c10
update
Chao1Han File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/NumericUtils.h> | ||
#include <ATen/native/xpu/sycl/MemoryAccessUtils.h> | ||
#include <ATen/xpu/XPUContext.h> | ||
#include <comm/SYCLContext.h> | ||
#include <stdint.h> | ||
#include <torch/torch.h> | ||
#include <xccl/NanCheck_XPU.hpp> | ||
#include <algorithm> | ||
|
||
namespace c10d { | ||
|
||
using BytePack = at::native::memory::aligned_vector<uint64_t, 2>; | ||
|
||
template <typename T, int EltPerPack> | ||
struct CheckBytePack { | ||
static void check(BytePack* tmp) { | ||
T* data = (T*)tmp; | ||
#pragma unroll 8 | ||
for (int i = 0; i < EltPerPack; i++) { | ||
if (at::_isnan(data[i])) | ||
assert(0); | ||
} | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct CheckBytePack<T, /*EltPerPack*/ 2> { | ||
static void check(BytePack* tmp) { | ||
T* data = (T*)tmp; | ||
if (at::_isnan(data[0]) || at::_isnan(data[1])) | ||
assert(0); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct CheckBytePack<T, /*EltPerPack*/ 4> { | ||
static void check(BytePack* tmp) { | ||
T* data = (T*)tmp; | ||
if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) || | ||
at::_isnan(data[3])) | ||
assert(0); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct CheckBytePack<T, /*EltPerPack*/ 8> { | ||
static void check(BytePack* tmp) { | ||
T* data = (T*)tmp; | ||
if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) || | ||
at::_isnan(data[3]) || at::_isnan(data[4]) || at::_isnan(data[5]) || | ||
at::_isnan(data[6]) || at::_isnan(data[7])) { | ||
assert(0); | ||
} | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct HasNanFP8x8 { | ||
static bool check(uint64_t fp8x8) = delete; | ||
/* | ||
{ | ||
// `static_assert` in template definition requires c++23 onwards. | ||
// But the error message still applies if you find yourself here. | ||
static_assert( | ||
false, | ||
"You should never call this template definition because it is empty. You " | ||
"can follow the example of Float8_e4m3fn below to implement the check for | ||
" "your new datatype." | ||
); | ||
} | ||
*/ | ||
}; | ||
|
||
template <> | ||
struct HasNanFP8x8<c10::Float8_e4m3fn> { | ||
static bool check(uint64_t fp8x8) { | ||
auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; | ||
auto incremented = t + 0x0101010101010101ULL; | ||
auto overflow = incremented & 0x8080808080808080ULL; | ||
return overflow != 0; | ||
} | ||
}; | ||
|
||
template <> | ||
struct HasNanFP8x8<c10::Float8_e5m2> { | ||
static bool check(uint64_t fp8x8) { | ||
auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; | ||
auto incremented = t + 0x0303030303030303ULL; | ||
auto overflow = incremented & 0x8080808080808080ULL; | ||
return overflow != 0; | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct CheckBytePack<T, /*EltPerPack*/ 16> { | ||
static void check(BytePack* tmp) { | ||
if (HasNanFP8x8<T>::check(tmp->val[0]) || | ||
HasNanFP8x8<T>::check(tmp->val[1])) | ||
assert(0); | ||
} | ||
}; | ||
|
||
#define UNROLL 8 | ||
|
||
template <typename T> | ||
void checkChunk(BytePack* ptr, int nWorkers) { | ||
BytePack tmp[UNROLL]; | ||
|
||
#pragma unroll 8 | ||
for (int j = 0; j < UNROLL; j++) { | ||
tmp[j] = ptr[nWorkers * j]; | ||
} | ||
// Then check each BytePack in the tmp buffer | ||
#pragma unroll 8 | ||
for (int j = 0; j < UNROLL; j++) { | ||
CheckBytePack<T, sizeof(BytePack) / sizeof(T)>::check(tmp + j); | ||
} | ||
// Note: we separate the check from the load for efficient loading | ||
} | ||
|
||
// Align address of `ptr` up, to the alignment of `T` | ||
#define ALIGN_UP(ptr, T) \ | ||
(((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T)) | ||
|
||
template <typename T> | ||
struct checkForNaN { | ||
void operator()(sycl::nd_item<1> item) const { | ||
constexpr int EltPerPack = sizeof(BytePack) / sizeof(T); | ||
|
||
size_t offset = item.get_global_id(0); | ||
|
||
// Align input address up to BytePack in case it is not | ||
T* ptrAlign = (T*)ALIGN_UP(data, BytePack); | ||
size_t preProcElts = | ||
std::min<size_t>(static_cast<size_t>(ptrAlign - data), size); | ||
|
||
size_t size_left = size; | ||
|
||
if (offset < preProcElts) { | ||
if (at::_isnan(data[offset])) | ||
assert(0); | ||
} | ||
size_left -= preProcElts; | ||
|
||
BytePack* ptr = (BytePack*)ptrAlign; | ||
size_t sizeInBP = size_left * sizeof(T) / sizeof(BytePack); | ||
size_t loopSize = item.get_global_range(0) * UNROLL; | ||
|
||
for (; offset + loopSize <= sizeInBP; offset += loopSize) { | ||
checkChunk<T>(ptr + offset, item.get_global_range(0)); | ||
} | ||
|
||
for (; offset < sizeInBP; offset += item.get_global_range(0)) { | ||
BytePack tmp = ptr[offset]; | ||
CheckBytePack<T, EltPerPack>::check(&tmp); | ||
} | ||
|
||
if (item.get_local_id(0) < size_left % EltPerPack) { | ||
T* tailPtr = (T*)(ptr + sizeInBP); | ||
if (at::_isnan(tailPtr[item.get_local_id(0)])) | ||
assert(0); | ||
} | ||
} | ||
checkForNaN(T* data, size_t size, int64_t num_group, int64_t max_group_size) | ||
: data(data), | ||
size(size), | ||
num_group_(num_group), | ||
max_group_size_(max_group_size) {} | ||
|
||
private: | ||
T* data; | ||
size_t size; | ||
int64_t num_group_; | ||
int64_t max_group_size_; | ||
}; | ||
|
||
template <typename T> | ||
void checkfornan_impl_xpu( | ||
const at::Tensor& tensor, | ||
at::xpu::XPUStream& stream) { | ||
// skip check for non float types | ||
if (!torch::is_floating_point(tensor)) { | ||
return; | ||
} | ||
|
||
int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize<checkForNaN<T>>(); | ||
|
||
const size_t numThreadsPerBlock = | ||
std::min<size_t>(maxNumThreadsPerBlock, tensor.numel()); | ||
|
||
if (!(numThreadsPerBlock > 0)) { | ||
return; | ||
} | ||
|
||
int64_t numBlocks = | ||
(tensor.numel() + maxNumThreadsPerBlock - 1) / maxNumThreadsPerBlock; | ||
auto global_range{numBlocks * maxNumThreadsPerBlock}; | ||
auto local_range{maxNumThreadsPerBlock}; | ||
|
||
using Kernel = checkForNaN<T>; | ||
auto kfn = Kernel( | ||
tensor.data_ptr<T>(), tensor.numel(), numBlocks, maxNumThreadsPerBlock); | ||
|
||
sycl_kernel_submit(global_range, local_range, stream.queue(), kfn); | ||
} | ||
|
||
// CHECK if a Tensor contains NAN in any of its element | ||
void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream) { | ||
AT_DISPATCH_FLOATING_TYPES_AND4( | ||
at::ScalarType::Half, | ||
at::ScalarType::BFloat16, | ||
at::ScalarType::Float8_e4m3fn, | ||
at::ScalarType::Float8_e5m2, | ||
tensor.scalar_type(), | ||
"checkForNaN_XPU", | ||
[&]() { checkfornan_impl_xpu<scalar_t>(tensor, stream); }); | ||
} | ||
|
||
} // namespace c10d |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#pragma once | ||
|
||
#ifdef USE_C10D_XCCL | ||
|
||
#include <ATen/ATen.h> | ||
#include <c10/xpu/XPUStream.h> | ||
|
||
namespace c10d { | ||
|
||
void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream); | ||
|
||
} // namespace c10d | ||
|
||
#endif // USE_C10D_XCCL |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The NaN check currently only scans
inputs
before communication. To catch NaNs introduced by the collective or point-to-point operations, consider adding a post-operation loop overoutputs
.Copilot uses AI. Check for mistakes.