From 6566f163819422286fbd5b28e5892e20e0a83e65 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Mon, 17 Feb 2025 13:12:28 -0800
Subject: [PATCH 01/15] Use Welford LayerNorm

---
 src/ATen/native/xpu/LayerNorm.cpp             |  11 +-
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 318 ++++++++++++++++++
 src/ATen/native/xpu/sycl/LayerNormKernels.h   |   8 +-
 3 files changed, 328 insertions(+), 9 deletions(-)

diff --git a/src/ATen/native/xpu/LayerNorm.cpp b/src/ATen/native/xpu/LayerNorm.cpp
index f9e032122..8ee463b84 100644
--- a/src/ATen/native/xpu/LayerNorm.cpp
+++ b/src/ATen/native/xpu/LayerNorm.cpp
@@ -49,18 +49,19 @@ ::std::tuple<Tensor, Tensor, Tensor> layer_norm_xpu(
 
   Tensor Y = at::native::empty_like(
       *X,
-      c10::nullopt /* dtype */,
-      c10::nullopt /* layout */,
-      c10::nullopt /* device */,
-      c10::nullopt /* pin_memory */,
+      std::nullopt /* dtype */,
+      std::nullopt /* layout */,
+      std::nullopt /* device */,
+      std::nullopt /* pin_memory */,
       LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
   auto acc_type = at::toAccumulateType(input.scalar_type(), true);
   Tensor mean = at::empty({M}, X->options().dtype(acc_type));
   Tensor rstd = at::empty({M}, X->options().dtype(acc_type));
 
   native::xpu::layer_norm_kernel(
-      *X, *gamma, *beta, M, N, epsilon, Y, mean, rstd);
+      *X, *gamma, *beta, M, N, epsilon, &Y, &mean, &rstd);
 
   const auto input_shape = input.sizes();
   const size_t axis = input.dim() - normalized_shape.size();

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 171c54736..8d67ed88a 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -268,6 +268,300 @@ class LayerNormBackward : public NormBackward<scalar_t, mean_t, weight_t> {
   int64_t numel;
 };
 
+constexpr int vec_size =
+    4; // we could make it dependent on dtype, but that would lead to different
+       // results between float and low-p types
+
+// Checks alignment of buffers for using vectorized loads / stores
+template <typename T>
+bool can_vectorize(const T* ptr, int alignment) {
+  uint64_t addr = reinterpret_cast<uint64_t>(ptr);
+  return addr % alignment == 0;
+};
+
+struct WelfordDataLN {
+  float mean;
+  float sigma2;
+  float count;
+  WelfordDataLN() : mean(0.f), sigma2(0.f), count(0.f) {}
+  WelfordDataLN(float mean, float sigma2, float count)
+      : mean(mean), sigma2(sigma2), count(count) {}
+};
+
+template <typename U>
+WelfordDataLN WelfordOnlineSum(const U val, const WelfordDataLN& curr_sum) {
+  U delta = val - curr_sum.mean;
+  U new_count = curr_sum.count + 1.f;
+  U new_mean = curr_sum.mean +
+      delta * (1.f / new_count); // proper division is slow, this is less
+                                 // accurate but noticeably faster
+  return WelfordDataLN(
+      new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count);
+}
+
+WelfordDataLN WelfordCombine(
+    const WelfordDataLN dataB,
+    const WelfordDataLN dataA) {
+  using U = decltype(dataB.count);
+  U delta = dataB.mean - dataA.mean;
+  U count = dataA.count + dataB.count;
+  U mean, sigma2;
+  if (count > decltype(dataB.count){0}) {
+    auto coef = 1.f / count; // NB we don't use --use_fast_math, but this is
+                             // emulation, 1./count goes to intrinsic, `* coef`
+                             // is multiplication, instead of slow fp division
+    auto nA = dataA.count * coef;
+    auto nB = dataB.count * coef;
+    mean = nA * dataA.mean + nB * dataB.mean;
+    sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB;
+  } else {
+    mean = U(0);
+    sigma2 = U(0);
+  }
+  return {mean, sigma2, count};
+}
+
+template <int SIMD, typename T, typename shared_t>
+WelfordDataLN compute_stats(
+    const T* __restrict__ X,
+    const int N,
+    shared_t meansigmabuf,
+    shared_t countbuf,
+    sycl::nd_item<2>& item_id) {
+  int tid = item_id.get_local_linear_id();
+  // X points to the row to read
+  using vec_t = aligned_vector<T, vec_size>;
+  using acc_t = acc_type<T, true>;
+  const vec_t* X_vec = reinterpret_cast<const vec_t*>(X);
+  const int numx = item_id.get_local_range(1) * item_id.get_local_range(0);
+  const int thrx = item_id.get_local_linear_id();
+  const int n_vec_to_read = N / vec_size;
+  WelfordDataLN wd(0.f, 0.f, 0.f);
+  // no tail, we check that N is multiple of vec_size
+  for (int i = thrx; i < n_vec_to_read; i += numx) {
+    vec_t data = X_vec[i];
+#pragma unroll
+    for (int ii = 0; ii < vec_size; ii++) {
+      wd = WelfordOnlineSum(static_cast<acc_t>(data.val[ii]), wd);
+    }
+  }
+  // intra-warp reduction
+  auto sg = item_id.get_sub_group();
+  int sgSize = static_cast<int>(sg.get_local_range()[1]);
+  for (int offset = (sgSize >> 1); offset > 0; offset >>= 1) {
+    WelfordDataLN wdB{
+        sycl::shift_group_left(sg, wd.mean, offset),
+        sycl::shift_group_left(sg, wd.sigma2, offset),
+        sycl::shift_group_left(sg, wd.count, offset)};
+    wd = WelfordCombine(wd, wdB);
+  }
+  // threadIdx.x == 0 has correct values for each warp
+  // inter-warp reductions
+  if (item_id.get_local_range(0) > 1) {
+    for (int offset = item_id.get_local_range(0) / 2; offset > 0; offset /= 2) {
+      // upper half of warps write to shared
+      if (item_id.get_local_id(1) == 0 && item_id.get_local_id(0) >= offset &&
+          item_id.get_local_id(0) < 2 * offset) {
+        const int wrt_y = item_id.get_local_id(0) - offset;
+        meansigmabuf[2 * wrt_y] = wd.mean;
+        meansigmabuf[2 * wrt_y + 1] = wd.sigma2;
+        countbuf[wrt_y] = wd.count;
+      }
+      item_id.barrier(sycl_local_fence);
+      // lower half merges
+      if (item_id.get_local_id(1) == 0 && item_id.get_local_id(0) < offset) {
+        const int local_y = item_id.get_local_id(0);
+        WelfordDataLN wdB{
+            static_cast<float>(meansigmabuf[2 * local_y]),
+            static_cast<float>(meansigmabuf[2 * local_y + 1]),
+            static_cast<float>(countbuf[local_y])};
+        wd = WelfordCombine(wd, wdB);
+      }
+      item_id.barrier(sycl_local_fence);
+    }
+    if (item_id.get_local_id(1) == 0 && item_id.get_local_id(0) == 0) {
+      meansigmabuf[0] = wd.mean;
+      meansigmabuf[1] = wd.sigma2 / float(N);
+    }
+    item_id.barrier(sycl_local_fence);
+    return WelfordDataLN{
+        static_cast<float>(meansigmabuf[0]),
+        static_cast<float>(meansigmabuf[1]),
+        0.f};
+  } else {
+    return WelfordDataLN{
+        sycl::shift_group_left(sg, wd.mean, 0),
+        sycl::shift_group_left(sg, wd.sigma2, 0) / float(N),
+        0.f};
+  }
+}
+
+template <int SIMD, typename T, typename T_ACC>
+struct vectorized_layer_norm_kernel_impl {
+  [[intel::reqd_sub_group_size(SIMD)]] void operator()(
+      sycl::nd_item<2> item_id) const {
+    auto i1 = item_id.get_group(1);
+    const T* block_row = X_ + i1 * N_;
+    WelfordDataLN wd = compute_stats<SIMD>(
+        block_row, N_, meansigmabuf_, countbuf_, item_id);
+
+    using vec_t = aligned_vector<T, vec_size>;
+    const vec_t* X_vec = reinterpret_cast<const vec_t*>(block_row);
+    const vec_t* gamma_vec =
+        (gamma_ != nullptr) ? reinterpret_cast<const vec_t*>(gamma_) : nullptr;
+    const vec_t* beta_vec =
+        (beta_ != nullptr) ? reinterpret_cast<const vec_t*>(beta_) : nullptr;
+    vec_t* Y_vec = reinterpret_cast<vec_t*>(Y_ + i1 * N_);
+
+    const int numx = item_id.get_local_range(1) * item_id.get_local_range(0);
+    const int thrx = item_id.get_local_linear_id();
+    const int n_vec_to_read = N_ / vec_size;
+
+    T_ACC rstd_val = c10::xpu::compat::rsqrt(wd.sigma2 + eps_);
+
+    // No tail, N is guaranteed to be multiple of vec size
+    for (int i = thrx; i < n_vec_to_read; i += numx) {
+      vec_t data = X_vec[i];
+      vec_t out;
+
+      // Computation is performed in T_ACC, X is cast to T_ACC and result is
+      // implicitly cast to T
+      if (gamma_vec != nullptr && beta_vec != nullptr) {
+#pragma unroll
+        for (int ii = 0; ii < vec_size; ii++) {
+          out.val[ii] = static_cast<T_ACC>(gamma_vec[i].val[ii]) *
+                  (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean)) +
+              static_cast<T_ACC>(beta_vec[i].val[ii]);
+        }
+      } else if (gamma_vec != nullptr) {
+#pragma unroll
+        for (int ii = 0; ii < vec_size; ii++) {
+          out.val[ii] = static_cast<T_ACC>(gamma_vec[i].val[ii]) *
+              (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean));
+        }
+      } else if (beta_vec != nullptr) {
+#pragma unroll
+        for (int ii = 0; ii < vec_size; ii++) {
+          out.val[ii] =
+              (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean)) +
+              static_cast<T_ACC>(beta_vec[i].val[ii]);
+        }
+      } else {
+#pragma unroll
+        for (int ii = 0; ii < vec_size; ii++) {
+          out.val[ii] =
+              rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean);
+        }
+      }
+      Y_vec[i] = out;
+    }
+    if (thrx == 0) {
+      mean_[i1] = wd.mean;
+      rstd_[i1] = rstd_val;
+    }
+  }
+
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    meansigmabuf_ = sycl_local_acc_t<float>(2 * 16, cgh);
+    countbuf_ = sycl_local_acc_t<float>(16, cgh);
+  }
+
+  vectorized_layer_norm_kernel_impl(
+      const int N,
+      T_ACC eps,
+      const T* __restrict__ X,
+      const T* gamma,
+      const T* beta,
+      T_ACC* mean,
+      T_ACC* rstd,
+      T* Y)
+      : N_(N),
+        eps_(eps),
+        X_(X),
+        gamma_(gamma),
+        beta_(beta),
+        mean_(mean),
+        rstd_(rstd),
+        Y_(Y) {}
+
+ private:
+  const int N_;
+  T_ACC eps_;
+  const T* __restrict__ X_;
+  const T* gamma_;
+  const T* beta_;
+  T_ACC* mean_;
+  T_ACC* rstd_;
+  T* Y_;
+  sycl_local_acc_t<float> meansigmabuf_;
+  sycl_local_acc_t<float> countbuf_;
+};
+
+template <int SIMD, typename T, typename T_ACC>
+void launch_vectorized_layer_norm_kernel(
+    int N,
+    int64_t M,
+    T_ACC eps,
+    const T* X_data,
+    const T* gamma_data,
+    const T* beta_data,
+    T* Y_data,
+    T_ACC* mean_data,
+    T_ACC* rstd_data) {
+  vectorized_layer_norm_kernel_impl<SIMD, T, T_ACC> kfn(
+      N, eps, X_data, gamma_data, beta_data, mean_data, rstd_data, Y_data);
+  int64_t sg_size = syclMaxSubGroupSize();
+  int64_t wg_size = syclMaxWorkGroupSize(kfn);
+  sycl::range<2> local_range{size_t(wg_size / sg_size), size_t(sg_size)};
+  sycl::range<2> global_range(M * size_t(wg_size / sg_size), size_t(sg_size));
+  auto queue = getCurrentSYCLQueue();
+  sycl_kernel_submit(global_range, local_range, queue, kfn);
+}
+
+template <typename T, typename T_ACC>
+void LayerNormKernelImplInternalXPU(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t M,
+    int64_t N,
+    T_ACC eps,
+    Tensor* Y,
+    Tensor* mean,
+    Tensor* rstd) {
+  const T* X_data = X.const_data_ptr<T>();
+  const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
+  const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
+  T* Y_data = Y->data_ptr<T>();
+  T_ACC* mean_data = mean->data_ptr<T_ACC>();
+  T_ACC* rstd_data = rstd->data_ptr<T_ACC>();
+
+  constexpr int num_vec_elems = vec_size;
+  constexpr int alignment = num_vec_elems * sizeof(T);
+  bool can_vec_X = can_vectorize(X_data, alignment);
+  bool can_vec_Y = can_vectorize(Y_data, alignment);
+  bool can_vec_gamma =
+      gamma.defined() ? can_vectorize(gamma_data, alignment) : true;
+  bool can_vec_beta =
+      beta.defined() ? can_vectorize(beta_data, alignment) : true;
+
+  if ((std::is_same_v<T, float> || std::is_same_v<T, at::Half> ||
+       std::is_same_v<T, at::BFloat16>)&&N <=
+          static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
+      N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma &&
+      can_vec_beta) {
+    launch_vectorized_layer_norm_kernel(
+        static_cast<int>(N),
+        M,
+        eps,
+        X_data,
+        gamma_data,
+        beta_data,
+        Y_data,
+        mean_data,
+        rstd_data);
+  }
+}
+
 template <typename scalar_t, typename mean_t, typename weight_t>
 void _layer_norm_kernel(
     const Tensor& X,
@@ -596,6 +890,29 @@ void _layer_norm_backward_kernel(
       dY, X, mean_data, var_data, dgamma, dbeta, config_w);
 }
 
+void layer_norm_kernel(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t M,
+    int64_t N,
+    double eps,
+    Tensor* Y,
+    Tensor* mean,
+    Tensor* rstd) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      X.scalar_type(),
+      "layer_norm_xpu",
+      [&]() {
+        using acc_t = acc_type_device<scalar_t, kXPU>;
+        LayerNormKernelImplInternalXPU<scalar_t, acc_t>(
+            X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
+      });
+}
+
+/*
 std::tuple<Tensor, Tensor, Tensor> layer_norm_kernel(
     const Tensor& X,
     const Tensor& gamma,
    const Tensor& beta,
     int64_t M,
     int64_t N,
     double eps,
@@ -621,6 +938,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_kernel(
 
   return std::make_tuple(Y, mean, rstd);
 }
+*/
 
 std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_kernel(
     const Tensor& dY,

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.h b/src/ATen/native/xpu/sycl/LayerNormKernels.h
index 0c57a61ba..1f6f2f974 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.h
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.h
@@ -6,16 +6,16 @@ namespace at {
 namespace native {
 namespace xpu {
 
-TORCH_XPU_API std::tuple<Tensor, Tensor, Tensor> layer_norm_kernel(
+TORCH_XPU_API void layer_norm_kernel(
     const Tensor& X,
     const Tensor& gamma,
     const Tensor& beta,
     int64_t M,
     int64_t N,
     double eps,
-    Tensor& Y,
-    Tensor& mean,
-    Tensor& rstd);
+    Tensor* Y,
+    Tensor* mean,
+    Tensor* rstd);
 
 TORCH_XPU_API std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_kernel(
     const Tensor& dY,

From 712011fbff7f6097e227748a8fe54ab966dc39d0 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Mon, 17 Feb 2025 18:35:10 -0800
Subject: [PATCH 02/15] Use Welford LayerNorm

---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 179 +-----------------
 1 file changed, 5 insertions(+), 174 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 8d67ed88a..774aee41b 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -15,105 +15,6 @@ namespace at {
 namespace native {
 namespace xpu {
 
-template <typename scalar_t, typename mean_t, typename weight_t>
-class LayerNormForward : public NormForward<scalar_t, mean_t, weight_t> {
- public:
-  using accscalar_t = acc_type_device<scalar_t, kXPU>;
-  typedef NormForward<scalar_t, mean_t, weight_t> NF;
-  LayerNormForward() = delete;
-  LayerNormForward(
-      const scalar_t* X_data,
-      scalar_t* Y_data,
-      mean_t* mean_data,
-      mean_t* var_data,
-      const weight_t* gamma_data,
-      const weight_t* beta_data,
-      accscalar_t eps,
-      int64_t M,
-      int64_t N)
-      : NormForward<scalar_t, mean_t, weight_t>(
-            X_data,
-            Y_data,
-            mean_data,
-            var_data,
-            gamma_data,
-            beta_data,
-            eps),
-        M(M),
-        N(N) {
-    numel = M * N;
-  };
-
-  template <
-      int vec_size,
-      typename index_t,
-      typename vec_t,
-      typename weight_vec_t,
-      typename nd_item_id>
-  void update(
-      nd_item_id item_id,
-      const NormConfig& cfg,
-      accscalar_t sum1 = 0,
-      accscalar_t sum2 = 0) const {
-    auto group_id = item_id.get_group(0);
-    auto group_id_foreach = item_id.get_group(1);
-    auto local_id = item_id.get_local_id(2);
-
-    index_t group_offset = group_id * cfg.problem_size;
-    if (cfg.workgroup_num_foreach == 1) {
-      if (local_id == 0) {
-        NF::reduce_project(item_id, sum1, sum2, cfg);
-      }
-      item_id.barrier(sycl_global_fence);
-    }
-
-    mean_t mean_val = NF::mean_data[group_id];
-    mean_t var_val = NF::var_data[group_id];
-    for (index_t j = local_id * vec_size; j < cfg.workgroup_work_size;
-         j += cfg.workgroup_size * vec_size) {
-      index_t plane_offset = group_id_foreach * cfg.workgroup_work_size + j;
-      if (plane_offset < (index_t)cfg.problem_size) {
-        vec_t X_val = *(reinterpret_cast<vec_t*>(
-            NF::X_data + group_offset + plane_offset));
-        weight_vec_t gamma_val, beta_val;
-        vec_t Y_val;
-        if (NF::gamma_data != nullptr) {
-          gamma_val = *(reinterpret_cast<weight_vec_t*>(
-              NF::gamma_data + plane_offset));
-        }
-        if (NF::beta_data != nullptr) {
-          beta_val = *(reinterpret_cast<weight_vec_t*>(
-              NF::beta_data + plane_offset));
-        }
-
-        for (int v = 0; v < vec_size; ++v) {
-          if (NF::gamma_data != nullptr && NF::beta_data != nullptr) {
-            Y_val[v] = static_cast<accscalar_t>(gamma_val[v]) *
-                    (var_val * static_cast<accscalar_t>(X_val[v] - mean_val)) +
-                static_cast<accscalar_t>(beta_val[v]);
-          } else if (NF::gamma_data != nullptr) {
-            Y_val[v] = static_cast<accscalar_t>(gamma_val[v]) *
-                (var_val * static_cast<accscalar_t>(X_val[v] - mean_val));
-          } else if (NF::beta_data != nullptr) {
-            Y_val[v] =
-                (var_val * static_cast<accscalar_t>(X_val[v] - mean_val)) +
-                static_cast<accscalar_t>(beta_val[v]);
-          } else {
-            Y_val[v] =
-                (var_val * static_cast<accscalar_t>(X_val[v] - mean_val));
-          }
-        }
-        *(reinterpret_cast<vec_t*>(NF::Y_data + group_offset + plane_offset)) =
-            Y_val;
-      }
-    }
-  };
-
-  int64_t M;
-  int64_t N;
-  int64_t numel;
-};
-
 template <typename scalar_t, typename mean_t, typename weight_t>
 class LayerNormBackward : public NormBackward<scalar_t, mean_t, weight_t> {
  public:
@@ -397,7 +298,7 @@ WelfordDataLN compute_stats(
 }
 
 template <int SIMD, typename T, typename T_ACC>
-struct vectorized_layer_norm_kernel_impl {
+struct VectorizedLayerNormKernelFunctor {
   [[intel::reqd_sub_group_size(SIMD)]] void operator()(
       sycl::nd_item<2> item_id) const {
     auto i1 = item_id.get_group(1);
@@ -465,7 +366,7 @@ struct vectorized_layer_norm_kernel_impl {
     countbuf_ = sycl_local_acc_t<float>(16, cgh);
   }
 
-  vectorized_layer_norm_kernel_impl(
+  VectorizedLayerNormKernelFunctor(
       const int N,
       T_ACC eps,
      const T* __restrict__ X,
@@ -507,7 +408,7 @@ void launch_vectorized_layer_norm_kernel(
     T* Y_data,
     T_ACC* mean_data,
     T_ACC* rstd_data) {
-  vectorized_layer_norm_kernel_impl<SIMD, T, T_ACC> kfn(
+  VectorizedLayerNormKernelFunctor<SIMD, T, T_ACC> kfn(
       N, eps, X_data, gamma_data, beta_data, mean_data, rstd_data, Y_data);
   int64_t sg_size = syclMaxSubGroupSize();
   int64_t wg_size = syclMaxWorkGroupSize(kfn);
@@ -518,7 +419,7 @@ void launch_vectorized_layer_norm_kernel(
 }
 
 template <typename T, typename T_ACC>
-void LayerNormKernelImplInternalXPU(
+void _layer_norm_kernel(
     const Tensor& X,
     const Tensor& gamma,
     const Tensor& beta,
@@ -562,48 +463,6 @@ void LayerNormKernelImplInternalXPU(
   }
 }
 
-template <typename scalar_t, typename mean_t, typename weight_t>
-void _layer_norm_kernel(
-    const Tensor& X,
-    const Tensor& gamma,
-    const Tensor& beta,
-    int64_t M,
-    int64_t N,
-    acc_type_device<scalar_t, kXPU> eps,
-    Tensor& Y,
-    Tensor& mean,
-    Tensor& rstd) {
-  TORCH_CHECK(X.numel() == M * N);
-  TORCH_CHECK(!gamma.defined() || gamma.numel() == N);
-  TORCH_CHECK(!beta.defined() || beta.numel() == N);
-
-  const scalar_t* X_data = X.const_data_ptr<scalar_t>();
-  scalar_t* Y_data = Y.data_ptr<scalar_t>();
-  mean_t* mean_data = mean.data_ptr<mean_t>();
-  mean_t* var_data = rstd.data_ptr<mean_t>();
-  const weight_t* gamma_data =
-      gamma.defined() ? gamma.const_data_ptr<weight_t>() : nullptr;
-  const weight_t* beta_data =
-      beta.defined() ? beta.const_data_ptr<weight_t>() : nullptr;
-
-  auto config = NormConfig(M, N, 1, sizeof(scalar_t));
-  bool can_use_32bit_index = canUse32BitIndexMath(X);
-  LayerNormForward<scalar_t, mean_t, weight_t> norm(
-      X_data, Y_data, mean_data, var_data, gamma_data, beta_data, eps, M, N);
-
-  if (config.workgroup_num_foreach == 1) {
-    vectorized_fused_norm_kernel<scalar_t, mean_t, weight_t, LayerNormForward>(
-        norm, config, can_use_32bit_index);
-  } else {
-    Tensor semaphores, scratchpad;
-    config.template init_global_reduce<scalar_t>(X, semaphores, scratchpad);
-    rowwise_moments_kernel<scalar_t, mean_t, weight_t, LayerNormForward>(
-        norm, config, can_use_32bit_index);
-    norm_update_kernel<scalar_t, mean_t, weight_t, LayerNormForward>(
-        norm, config, can_use_32bit_index);
-  }
-}
-
 template <
     typename scalar_t,
     typename accscalar_t,
@@ -907,39 +766,11 @@ void layer_norm_kernel(
       X.scalar_type(),
       "layer_norm_xpu",
       [&]() {
         using acc_t = acc_type_device<scalar_t, kXPU>;
-        LayerNormKernelImplInternalXPU<scalar_t, acc_t>(
+        _layer_norm_kernel<scalar_t, acc_t>(
             X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
       });
 }
 
-/*
-std::tuple<Tensor, Tensor, Tensor> layer_norm_kernel(
-    const Tensor& X,
-    const Tensor& gamma,
-    const Tensor& beta,
-    int64_t M,
-    int64_t N,
-    double eps,
-    Tensor& Y,
-    Tensor& mean,
-    Tensor& rstd) {
-  if (M > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        at::ScalarType::Half,
-        at::ScalarType::BFloat16,
-        X.scalar_type(),
-        "layer_norm_xpu",
-        [&]() {
-          using acc_t = acc_type_device<scalar_t, kXPU>;
-          _layer_norm_kernel<scalar_t, acc_t, scalar_t>(
-              X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
-        });
-  }
-
-  return std::make_tuple(Y, mean, rstd);
-}
-*/
-
 std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_kernel(
     const Tensor& dY,

From affa9ce6b6f5516964ab34606de21a8e25bba345 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Mon, 17 Feb 2025 23:33:46 -0800
Subject: [PATCH 03/15] Use Welford LayerNorm

---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 28 +++++++++++++------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 774aee41b..17e418e0f 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -305,7 +305,6 @@ struct VectorizedLayerNormKernelFunctor {
     const T* block_row = X_ + i1 * N_;
     WelfordDataLN wd = compute_stats<SIMD>(
         block_row, N_, meansigmabuf_, countbuf_, item_id);
-
     using vec_t = aligned_vector<T, vec_size>;
     const vec_t* X_vec = reinterpret_cast<const vec_t*>(block_row);
@@ -362,8 +361,10 @@ struct VectorizedLayerNormKernelFunctor {
   }
 
   void sycl_ker_config_convention(sycl::handler& cgh) {
-    meansigmabuf_ = sycl_local_acc_t<float>(2 * 16, cgh);
-    countbuf_ = sycl_local_acc_t<float>(16, cgh);
+    meansigmabuf_ =
+        sycl_local_acc_t<float>(2 * get_group_reduce_group_size(sg_size_), cgh);
+    countbuf_ =
+        sycl_local_acc_t<float>(get_group_reduce_group_size(sg_size_), cgh);
   }
 
   VectorizedLayerNormKernelFunctor(
@@ -374,7 +375,8 @@ struct VectorizedLayerNormKernelFunctor {
       const T* beta,
       T_ACC* mean,
       T_ACC* rstd,
-      T* Y)
+      T* Y,
+      int64_t sg_size)
       : N_(N),
         eps_(eps),
         X_(X),
@@ -382,7 +384,8 @@ struct VectorizedLayerNormKernelFunctor {
         beta_(beta),
         mean_(mean),
         rstd_(rstd),
-        Y_(Y) {}
+        Y_(Y),
+        sg_size_(sg_size) {}
 
  private:
   const int N_;
@@ -393,6 +396,7 @@ struct VectorizedLayerNormKernelFunctor {
   T_ACC* mean_;
   T_ACC* rstd_;
   T* Y_;
+  int64_t sg_size_;
   sycl_local_acc_t<float> meansigmabuf_;
   sycl_local_acc_t<float> countbuf_;
 };
@@ -408,12 +412,20 @@ void launch_vectorized_layer_norm_kernel(
     T* Y_data,
     T_ACC* mean_data,
     T_ACC* rstd_data) {
-  VectorizedLayerNormKernelFunctor<SIMD, T, T_ACC> kfn(
-      N, eps, X_data, gamma_data, beta_data, mean_data, rstd_data, Y_data);
   int64_t sg_size = syclMaxSubGroupSize();
+  VectorizedLayerNormKernelFunctor<SIMD, T, T_ACC> kfn(
+      N,
+      eps,
+      X_data,
+      gamma_data,
+      beta_data,
+      mean_data,
+      rstd_data,
+      Y_data,
+      sg_size);
   int64_t wg_size = syclMaxWorkGroupSize(kfn);
   sycl::range<2> local_range{size_t(wg_size / sg_size), size_t(sg_size)};
-  sycl::range<2> global_range(M * size_t(wg_size / sg_size), size_t(sg_size));
+  sycl::range<2> global_range(size_t(wg_size / sg_size), M * size_t(sg_size));
   auto queue = getCurrentSYCLQueue();
   sycl_kernel_submit(global_range, local_range, queue, kfn);
 }

From db37b210677a86adc647f13db6a739b2d7ed6213 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Mon, 17 Feb 2025 23:42:31 -0800
Subject: [PATCH 04/15] Update src/ATen/native/xpu/sycl/LayerNormKernels.cpp

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 17e418e0f..6f8483862 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -224,7 +224,7 @@ WelfordDataLN WelfordCombine(
 
 template <int SIMD, typename T, typename shared_t>
 WelfordDataLN compute_stats(
-    const T* __restrict__ X,
+    const T* RESTRICT X,
     const int N,
     shared_t meansigmabuf,
     shared_t countbuf,

From 737f0e844b104372c0102a670aafb43797e146fa Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Mon, 17 Feb 2025 23:43:51 -0800
Subject: [PATCH 05/15] Use Welford LayerNorm

---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 6f8483862..8bb7899c2 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -370,7 +370,7 @@ struct VectorizedLayerNormKernelFunctor {
   VectorizedLayerNormKernelFunctor(
       const int N,
       T_ACC eps,
-      const T* __restrict__ X,
+      const T* RESTRICT X,
       const T* gamma,
       const T* beta,
       T_ACC* mean,
@@ -390,7 +390,7 @@ struct VectorizedLayerNormKernelFunctor {
  private:
   const int N_;
   T_ACC eps_;
-  const T* __restrict__ X_;
+  const T* RESTRICT X_;
   const T* gamma_;
   const T* beta_;
   T_ACC* mean_;

From 6dcf05979960ad49d22c379ad94ce83b77e4bde2 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Tue, 18 Feb 2025 18:05:52 -0800
Subject: [PATCH 06/15] Use Welford LayerNorm

---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 146 ++++++++++++++++++
 1 file changed, 146 insertions(+)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 8bb7899c2..17e9c0406 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -180,6 +181,139 @@ bool can_vectorize(const T* ptr, int alignment) {
   return addr % alignment == 0;
 };
 
+template <int SIMD, typename T, typename T_ACC>
+struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
+  using WelfordType = WelfordData<T_ACC, int64_t>;
+  using WelfordOp =
+      WelfordOps<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
+
+  [[intel::reqd_sub_group_size(SIMD)]] void operator()(
+      sycl::nd_item<1> item_id) const {
+    const int64_t i = item_id.get_group(0);
+    WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
+    WelfordType val(0, 0, 0, 0);
+    for (int64_t j = item_id.get_local_id(0); j < N_;
+         j += item_id.get_local_range(0)) {
+      const int64_t index = i * N_ + j;
+      val = welford_op.reduce(val, static_cast<T_ACC>(X_[index]), index);
+    }
+
+    val = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
+        item_id, val, welford_op, shared_);
+
+    if (item_id.get_local_id(0) == 0) {
+      T_ACC m1;
+      T_ACC m2;
+      std::tie(m2, m1) = welford_op.project(val);
+      mean_[i] = m1;
+      rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_);
+    }
+  }
+
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    shared_ = sycl_local_acc_t<WelfordType>(SIMD, cgh);
+  }
+
+  RowwiseMomentsFunctor(
+      int64_t N,
+      T_ACC eps,
+      const T* X,
+      T_ACC* mean,
+      T_ACC* rstd)
+      : N_(N), eps_(eps), X_(X), mean_(mean), rstd_(rstd) {}
+
+ private:
+  int64_t N_;
+  T_ACC eps_;
+  const T* X_;
+  T_ACC* mean_;
+  T_ACC* rstd_;
+  sycl_local_acc_t<WelfordType> shared_;
+};
+
+template <int SIMD, typename T, typename T_ACC>
+void launch_rowwise_moments_kernel(
+    int N,
+    int64_t M,
+    T_ACC eps,
+    const T* X_data,
+    T_ACC* mean_data,
+    T_ACC* rstd_data) {
+  RowwiseMomentsFunctor<SIMD, T, T_ACC> kfn(N, eps, X_data, mean_data, rstd_data);
+
+  int64_t sg_size = syclMaxSubGroupSize();
+  int64_t wg_size = get_group_reduce_group_size(sg_size);
+  sycl::range<1> local_range{size_t(wg_size)};
+  sycl::range<1> global_range{size_t(M * wg_size)};
+  auto queue = getCurrentSYCLQueue();
+
+  sycl_kernel_submit(global_range, local_range, queue, kfn);
+}
+
+template <typename T, typename T_ACC>
+struct LayerNormForwardKernelFunctor {
+  void operator()(sycl::nd_item<1> item_id) const {
+    const int64_t i = item_id.get_group(0);
+    for (int64_t j = item_id.get_local_id(0); j < N_;
+         j += item_id.get_local_range(0)) {
+      const int64_t index = i * N_ + j;
+      const T_ACC gamma_v =
+          gamma_ == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma_[j]);
+      const T_ACC beta_v =
+          beta_ == nullptr ? T_ACC(0) : static_cast<T_ACC>(beta_[j]);
+      Y_[index] =
+          (static_cast<T_ACC>(X_[index]) - static_cast<T_ACC>(mean_[i])) *
+              static_cast<T_ACC>(rstd_[i]) * gamma_v +
+          beta_v;
+    }
+  }
+  LayerNormForwardKernelFunctor(
+      int64_t N,
+      const T* X,
+      const T_ACC* mean,
+      const T_ACC* rstd,
+      const T* gamma,
+      const T* beta,
+      T* Y)
+      : N_(N),
+        X_(X),
+        mean_(mean),
+        rstd_(rstd),
+        gamma_(gamma),
+        beta_(beta),
+        Y_(Y) {}
+
+ private:
+  int64_t N_;
+  const T* X_;
+  const T_ACC* mean_;
+  const T_ACC* rstd_;
+  const T* gamma_;
+  const T* beta_;
+  T* Y_;
+};
+
+template <int SIMD, typename T, typename T_ACC>
+void launch_layer_norm_forward_kernel(
+    int N,
+    int64_t M,
+    const T* X_data,
+    const T_ACC* mean_data,
+    const T_ACC* rstd_data,
+    const T* gamma_data,
+    const T* beta_data,
+    T* Y_data) {
+  LayerNormForwardKernelFunctor<T, T_ACC> kfn(
+      N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data);
+
+  int64_t sg_size = syclMaxSubGroupSize();
+  int64_t wg_size = get_group_reduce_group_size(sg_size);
+  sycl::range<1> local_range{size_t(wg_size)};
+  sycl::range<1> global_range(M * size_t(wg_size));
+  auto queue = getCurrentSYCLQueue();
+
+  sycl_kernel_submit(global_range, local_range, queue, kfn);
+}
+
 struct WelfordDataLN {
   float mean;
   float sigma2;
@@ -472,6 +606,18 @@ void _layer_norm_kernel(
         Y_data,
         mean_data,
         rstd_data);
+  } else {
+    launch_rowwise_moments_kernel(
+        static_cast<int>(N), M, eps, X_data, mean_data, rstd_data);
+    launch_layer_norm_forward_kernel(
+        static_cast<int>(N),
+        M,
+        X_data,
+        mean_data,
+        rstd_data,
+        gamma_data,
+        beta_data,
+        Y_data);
   }
 }

From 4358e06dfdb98324c04e83064e950f416bbfff95 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Wed, 19 Feb 2025 04:30:29 -0800
Subject: [PATCH 07/15] Use Welford LayerNorm

---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 17e9c0406..075a2fa56 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -382,7 +382,7 @@ WelfordDataLN compute_stats(
   }
   // intra-warp reduction
   auto sg = item_id.get_sub_group();
-  int sgSize = static_cast<int>(sg.get_local_range()[1]);
+  int sgSize = static_cast<int>(sg.get_local_range()[0]);
   for (int offset = (sgSize >> 1); offset > 0; offset >>= 1) {
     WelfordDataLN wdB{
         sycl::shift_group_left(sg, wd.mean, offset),

From 7cf71a52c17ccfbcd54c7f855f04d9706c089e94 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Thu, 20 Feb 2025 22:34:55 -0800
Subject: [PATCH 08/15] Use Welford LayerNorm

---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 075a2fa56..042ddc424 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -382,8 +382,7 @@ WelfordDataLN compute_stats(
   }
   // intra-warp reduction
   auto sg = item_id.get_sub_group();
-  int sgSize = static_cast<int>(sg.get_local_range()[0]);
-  for (int offset = (sgSize >> 1); offset > 0; offset >>= 1) {
+  for (int offset = (SIMD >> 1); offset > 0; offset >>= 1) {
     WelfordDataLN wdB{
         sycl::shift_group_left(sg, wd.mean, offset),
         sycl::shift_group_left(sg, wd.sigma2, offset),

From 1692cd77381ab008bec7eed99df8b6956446d7b9 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Thu, 20 Feb 2025 22:52:17 -0800
Subject: [PATCH 09/15] Use Welford LayerNorm

---
 test/regressions/test_layer_norm.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)
 create mode 100644 test/regressions/test_layer_norm.py

diff --git a/test/regressions/test_layer_norm.py b/test/regressions/test_layer_norm.py
new file mode 100644
index 000000000..b78fb9663
--- /dev/null
+++ b/test/regressions/test_layer_norm.py
@@ -0,0 +1,20 @@
+# Owner(s): ["module: intel"]
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import TestCase
+
+cpu_device = torch.device("cpu")
+xpu_device = torch.device("xpu")
+
+
+class TestLayerNorm(TestCase):
+    def test_layer_norm_no_nan(self, dtype=torch.float):
+        dim = [5]
+        x_cpu = torch.tensor([[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]], dtype=torch.float32)
+        layernorm_cpu = nn.LayerNorm(dim)
+        y_cpu = layernorm_cpu(x_cpu)
+
+        x_xpu = x_cpu.to(xpu_device)
+        layernorm_xpu = nn.LayerNorm(dim).to(xpu_device)
+        y_xpu = layernorm_xpu(x_xpu)
+        self.assertEqual(y_cpu, y_xpu.to(cpu_device))

From 61a837b58fe46f024fd9849c328dfad43453151a Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Thu, 20 Feb 2025 22:59:29 -0800
Subject: [PATCH 10/15] Use Welford LayerNorm

---
 test/regressions/test_layer_norm.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/test/regressions/test_layer_norm.py b/test/regressions/test_layer_norm.py
index b78fb9663..f042352a1 100644
--- a/test/regressions/test_layer_norm.py
+++ b/test/regressions/test_layer_norm.py
@@ -10,7 +10,9 @@ class TestLayerNorm(TestCase):
     def test_layer_norm_no_nan(self, dtype=torch.float):
         dim = [5]
-        x_cpu = torch.tensor([[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]], dtype=torch.float32)
+        x_cpu = torch.tensor(
+            [[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]], dtype=torch.float32
+        )
         layernorm_cpu = nn.LayerNorm(dim)
         y_cpu = layernorm_cpu(x_cpu)

From fe0b4bc902ff8373cd3e3f1235e2ef8e8425b41d Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Thu, 20 Feb 2025 23:03:05 -0800
Subject: [PATCH 11/15] Use Welford LayerNorm

---
 test/regressions/test_layer_norm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/regressions/test_layer_norm.py b/test/regressions/test_layer_norm.py
index f042352a1..b3bb47584 100644
--- a/test/regressions/test_layer_norm.py
+++ b/test/regressions/test_layer_norm.py
@@ -11,7 +11,7 @@ class TestLayerNorm(TestCase):
     def test_layer_norm_no_nan(self, dtype=torch.float):
         dim = [5]
         x_cpu = torch.tensor(
-            [[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]], dtype=torch.float32
+            [[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]]
         )
         layernorm_cpu = nn.LayerNorm(dim)
         y_cpu = layernorm_cpu(x_cpu)

From f1c78eba545d9bc9da93cc7af3d28b158c084e91 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Thu, 20 Feb 2025 23:05:38 -0800
Subject: [PATCH 12/15] Use Welford LayerNorm

---
 test/regressions/test_layer_norm.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/test/regressions/test_layer_norm.py b/test/regressions/test_layer_norm.py
index b3bb47584..254ccb08c 100644
--- a/test/regressions/test_layer_norm.py
+++ b/test/regressions/test_layer_norm.py
@@ -10,9 +10,7 @@ class TestLayerNorm(TestCase):
     def test_layer_norm_no_nan(self, dtype=torch.float):
         dim = [5]
-        x_cpu = torch.tensor(
-            [[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]]
-        )
+        x_cpu = torch.tensor([[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]])
         layernorm_cpu = nn.LayerNorm(dim)
         y_cpu = layernorm_cpu(x_cpu)

From d405b59e0fdea0516bb08ba72f9a2c30631a4015 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Fri, 21 Feb 2025 00:00:05 -0800
Subject: [PATCH 13/15] Use Welford LayerNorm

---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 72 ++++++++++---------
 1 file changed, 38 insertions(+), 34 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 042ddc424..ea1c1ea06 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -330,8 +330,10 @@ WelfordDataLN WelfordOnlineSum(const U val, const WelfordDataLN& curr_sum) {
   U new_mean = curr_sum.mean +
       delta * (1.f / new_count); // proper division is slow, this is less
                                  // accurate but noticeably faster
-  return WelfordDataLN(
-      new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count);
+  return {
+      static_cast<float>(new_mean),
+      static_cast<float>(curr_sum.sigma2 + delta * (val - new_mean)),
+      static_cast<float>(new_count)};
 }
 
 WelfordDataLN WelfordCombine(
@@ -356,17 +358,15 @@ WelfordDataLN WelfordCombine(
   return {mean, sigma2, count};
 }
 
-template <int SIMD, typename T, typename shared_t>
+template <int SIMD, typename T, typename T_ACC>
 WelfordDataLN compute_stats(
     const T* RESTRICT X,
     const int N,
-    shared_t meansigmabuf,
-    shared_t countbuf,
+    T_ACC& buf,
     sycl::nd_item<2>& item_id) {
-  int tid = item_id.get_local_linear_id();
   // X points to the row to read
   using vec_t = aligned_vector<T, vec_size>;
-  using acc_t = acc_type<T, true>;
+  using acc_t = acc_type_device<T, kXPU>;
   const vec_t* X_vec = reinterpret_cast<const vec_t*>(X);
   const int numx = item_id.get_local_range(1) * item_id.get_local_range(0);
   const int thrx = item_id.get_local_linear_id();
@@ -384,43 +386,45 @@ WelfordDataLN compute_stats(
         sycl::shift_group_left(sg, wd.count, offset)};
     wd = WelfordCombine(wd, wdB);
   }
+
   // threadIdx.x == 0 has correct values for each warp
   // inter-warp reductions
   if (item_id.get_local_range(0) > 1) {
+    auto addr_offset = item_id.get_local_range(0);
     for (int offset = item_id.get_local_range(0) / 2; offset > 0; offset /= 2) {
       // upper half of warps write to shared
       if (item_id.get_local_id(1) == 0 && item_id.get_local_id(0) >= offset &&
           item_id.get_local_id(0) < 2 * offset) {
         const int wrt_y = item_id.get_local_id(0) - offset;
-        meansigmabuf[2 * wrt_y] = wd.mean;
-        meansigmabuf[2 * wrt_y + 1] = wd.sigma2;
-        countbuf[wrt_y] = wd.count;
+        buf[2 * wrt_y] = wd.mean;
+        buf[2 * wrt_y + 1] = wd.sigma2;
+        buf[wrt_y + addr_offset] = wd.count;
       }
       item_id.barrier(sycl_local_fence);
+
       // lower half merges
       if (item_id.get_local_id(1) == 0 && item_id.get_local_id(0) < offset) {
-        const int local_y = item_id.get_local_id(0);
+        const int rd_y = item_id.get_local_id(0);
         WelfordDataLN wdB{
-            static_cast<float>(meansigmabuf[2 * local_y]),
-            static_cast<float>(meansigmabuf[2 * local_y + 1]),
-            static_cast<float>(countbuf[local_y])};
+            static_cast<float>(buf[2 * rd_y]),
+            static_cast<float>(buf[2 * rd_y + 1]),
+            static_cast<float>(buf[rd_y + addr_offset])};
         wd = WelfordCombine(wd, wdB);
       }
       item_id.barrier(sycl_local_fence);
     }
+
     if (item_id.get_local_id(1) == 0 && item_id.get_local_id(0) == 0) {
-      meansigmabuf[0] = wd.mean;
-      meansigmabuf[1] = wd.sigma2 / float(N);
+      buf[0] = wd.mean;
+      buf[1] = wd.sigma2 / float(N);
     }
     item_id.barrier(sycl_local_fence);
     return WelfordDataLN{
-        static_cast<float>(meansigmabuf[0]),
-        static_cast<float>(meansigmabuf[1]),
-        0.f};
+        static_cast<float>(buf[0]), static_cast<float>(buf[1]), 0.f};
   } else {
     return WelfordDataLN{
-        sycl::shift_group_left(sg, wd.mean, 0),
-        sycl::shift_group_left(sg, wd.sigma2, 0) / float(N),
+        sycl::select_from_group(sg, wd.mean, 0),
+        sycl::select_from_group(sg, wd.sigma2, 0) / float(N),
         0.f};
   }
 }
@@ -433,8 +435,8 @@ struct VectorizedLayerNormKernelFunctor {
       sycl::nd_item<2> item_id) const {
     auto i1 = item_id.get_group(1);
     const T* block_row = X_ + i1 * N_;
-    WelfordDataLN wd = compute_stats<SIMD>(
-        block_row, N_, meansigmabuf_, countbuf_, item_id);
+    WelfordDataLN wd = compute_stats<SIMD>(block_row, N_, buf_, item_id);
+
     using vec_t = aligned_vector<T, vec_size>;
     const vec_t* X_vec = reinterpret_cast<const vec_t*>(block_row);
     const vec_t* gamma_vec =
@@ -494,10 +496,8 @@ struct VectorizedLayerNormKernelFunctor {
   }
 
   void sycl_ker_config_convention(sycl::handler& cgh) {
-    meansigmabuf_ =
-        sycl_local_acc_t<float>(2 * get_group_reduce_group_size(sg_size_), cgh);
-    countbuf_ =
-        sycl_local_acc_t<float>(get_group_reduce_group_size(sg_size_), cgh);
+    buf_ =
+        sycl_local_acc_t<float>(sycl::range<1>((wg_size_ / sg_size_) * 2), cgh);
   }
 
   VectorizedLayerNormKernelFunctor(
@@ -507,7 +509,8 @@ struct VectorizedLayerNormKernelFunctor {
       T_ACC* mean,
       T_ACC* rstd,
       T* Y,
-      int64_t sg_size)
+      int64_t sg_size,
+      int64_t wg_size)
       : N_(N),
         eps_(eps),
         X_(X),
@@ -517,7 +519,8 @@ struct VectorizedLayerNormKernelFunctor {
         mean_(mean),
         rstd_(rstd),
         Y_(Y),
-        sg_size_(sg_size) {}
+        sg_size_(sg_size),
+        wg_size_(wg_size) {}
 
  private:
   const int N_;
@@ -530,7 +533,7 @@ struct VectorizedLayerNormKernelFunctor {
   T_ACC* mean_;
   T_ACC* rstd_;
   T* Y_;
   int64_t sg_size_;
-  sycl_local_acc_t<float> meansigmabuf_;
-  sycl_local_acc_t<float> countbuf_;
+  int64_t wg_size_;
+  sycl_local_acc_t<float> buf_;
 };
@@ -543,20 +546,22 @@ void launch_vectorized_layer_norm_kernel(
     T* Y_data,
     T_ACC* mean_data,
     T_ACC* rstd_data) {
+  using KernelClass = VectorizedLayerNormKernelFunctor<SIMD, T, T_ACC>;
   int64_t sg_size = syclMaxSubGroupSize();
-  VectorizedLayerNormKernelFunctor<SIMD, T, T_ACC> kfn(
+  int64_t wg_size = syclMaxWorkGroupSize<KernelClass>();
+  KernelClass kfn(
       N,
       eps,
       X_data,
       gamma_data,
       beta_data,
       mean_data,
       rstd_data,
       Y_data,
-      sg_size);
-  int64_t wg_size = syclMaxWorkGroupSize(kfn);
+      sg_size,
+      wg_size);
   sycl::range<2> local_range{size_t(wg_size / sg_size), size_t(sg_size)};
   sycl::range<2> global_range(size_t(wg_size / sg_size), M * size_t(sg_size));
   auto queue = getCurrentSYCLQueue();
   sycl_kernel_submit(global_range, local_range, queue, kfn);
 }

From 7395b883c90debd05cf78a0cdf608a015e4377b6 Mon Sep 17 00:00:00 2001
From: min-jean-cho
Date: Fri, 21 Feb 2025 11:15:32 -0800
Subject: [PATCH 14/15] Use Welford LayerNorm

---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 14 +++++----
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index ea1c1ea06..1f9a48870 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -433,7 +433,8 @@ WelfordDataLN compute_stats(
 }
 
 template <int SIMD, typename T, typename T_ACC>
-struct VectorizedLayerNormKernelFunctor {
+struct VectorizedLayerNormKernelFunctor
+    : public __SYCL_KER_CONFIG_CONVENTION__ {
   [[intel::reqd_sub_group_size(SIMD)]] void operator()(
       sycl::nd_item<2> item_id) const {
     auto i1 = item_id.get_group(1);
@@ -496,8 +497,7 @@ struct VectorizedLayerNormKernelFunctor {
   }
 
   void sycl_ker_config_convention(sycl::handler& cgh) {
-    buf_ =
-        sycl_local_acc_t<float>(sycl::range<1>((wg_size_ / sg_size_) * 2), cgh);
+    buf_ = sycl_local_acc_t<float>((wg_size_ / SIMD) * 2, cgh);
   }
 
   VectorizedLayerNormKernelFunctor(
@@ -509,7 +509,6 @@ struct VectorizedLayerNormKernelFunctor {
       T_ACC* mean,
       T_ACC* rstd,
      T* Y,
-      int64_t sg_size,
       int64_t wg_size)
       : N_(N),
         eps_(eps),
@@ -519,7 +518,6 @@ struct VectorizedLayerNormKernelFunctor {
         mean_(mean),
         rstd_(rstd),
         Y_(Y),
-        sg_size_(sg_size),
         wg_size_(wg_size) {}
 
  private:
@@ -532,7 +530,6 @@ struct VectorizedLayerNormKernelFunctor {
   T_ACC* mean_;
   T_ACC* rstd_;
   T* Y_;
-  int64_t sg_size_;
   int64_t wg_size_;
   sycl_local_acc_t<float> buf_;
 };
@@ -548,7 +546,6 @@ void launch_vectorized_layer_norm_kernel(
     T_ACC* mean_data,
     T_ACC* rstd_data) {
   using KernelClass = VectorizedLayerNormKernelFunctor<SIMD, T, T_ACC>;
-  int64_t sg_size = syclMaxSubGroupSize();
   int64_t wg_size = syclMaxWorkGroupSize<KernelClass>();
   KernelClass kfn(
       N,
@@ -559,10 +556,9 @@ void launch_vectorized_layer_norm_kernel(
       mean_data,
       rstd_data,
       Y_data,
-      sg_size,
       wg_size);
-  sycl::range<2> local_range{size_t(wg_size / sg_size), size_t(sg_size)};
-  sycl::range<2> global_range(size_t(wg_size / sg_size), M * size_t(sg_size));
+  sycl::range<2> local_range{size_t(wg_size / SIMD), SIMD};
+  sycl::range<2> global_range(size_t(wg_size / SIMD), M * SIMD);
   auto queue = getCurrentSYCLQueue();
   sycl_kernel_submit(global_range, local_range, queue, kfn);
 }

From bdac8dde7bfc56cc1e057004cd6488b686b1200c Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Mon, 24 Feb 2025 15:56:09 +0800
Subject: [PATCH 15/15] Update LayerNormKernels.cpp

---
 src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 20 ++++++------------
 1 file changed, 6 insertions(+), 14 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 1f9a48870..6982354f8 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -232,7 +232,7 @@ struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
 
 template <int SIMD, typename T, typename T_ACC>
 void launch_rowwise_moments_kernel(
-    int N,
+    int64_t N,
     int64_t M,
     T_ACC eps,
     const T* X_data,
@@ -240,7 +240,7 @@ void launch_rowwise_moments_kernel(
     T_ACC* rstd_data) {
   RowwiseMomentsFunctor<SIMD, T, T_ACC> kfn(N, eps, X_data, mean_data, rstd_data);
 
-  int64_t sg_size = syclMaxSubGroupSize();
+  int64_t sg_size = SIMD;
   int64_t wg_size = get_group_reduce_group_size(sg_size);
   sycl::range<1> local_range{size_t(wg_size)};
@@ -294,7 +294,7 @@ struct LayerNormForwardKernelFunctor {
 
 template <int SIMD, typename T, typename T_ACC>
 void launch_layer_norm_forward_kernel(
-    int N,
+    int64_t N,
     int64_t M,
     const T* X_data,
     const T_ACC* mean_data,
@@ -305,7 +305,7 @@ void launch_layer_norm_forward_kernel(
   LayerNormForwardKernelFunctor<T, T_ACC> kfn(
       N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data);
 
-  int64_t sg_size = syclMaxSubGroupSize();
+  int64_t sg_size = SIMD;
   int64_t wg_size = get_group_reduce_group_size(sg_size);
   sycl::range<1> local_range{size_t(wg_size)};
   sycl::range<1> global_range(M * size_t(wg_size));
@@ -606,17 +606,9 @@ void _layer_norm_kernel(
         mean_data,
         rstd_data);
   } else {
-    launch_rowwise_moments_kernel(
-        static_cast<int>(N), M, eps, X_data, mean_data, rstd_data);
+    launch_rowwise_moments_kernel(N, M, eps, X_data, mean_data, rstd_data);
     launch_layer_norm_forward_kernel(
-        static_cast<int>(N),
-        M,
-        X_data,
-        mean_data,
-        rstd_data,
-        gamma_data,
-        beta_data,
-        Y_data);
+        N, M, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data);
   }
 }
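
The arithmetic the series builds on is Welford's online variance update (WelfordOnlineSum) plus the pairwise merge of partial statistics due to Chan et al. (WelfordCombine); the merge is what lets sub-group shuffles and the shared-local-memory stage combine partials in any order. Below is a minimal standalone sketch of the same two rules in plain C++ (no SYCL); the naive_variance comparison, the lo/hi split, and the sample data are illustrative only, chosen to mirror the 1e15 values in test_layer_norm_no_nan, and are not part of the patches.

#include <cstdio>
#include <vector>

// Partial statistics for a chunk of a row: running mean, sum of squared
// deviations from that mean (sigma2), and element count. Mirrors WelfordDataLN.
struct Welford {
  double mean = 0.0, sigma2 = 0.0, count = 0.0;
};

// One-sample update, the same recurrence as WelfordOnlineSum.
Welford online_sum(double val, const Welford& w) {
  double delta = val - w.mean;
  double count = w.count + 1.0;
  double mean = w.mean + delta / count;
  return {mean, w.sigma2 + delta * (val - mean), count};
}

// Pairwise merge of two partials (Chan et al.), as in WelfordCombine: lanes
// of a sub-group and then whole sub-groups can reduce in a tree without
// revisiting the input data.
Welford combine(const Welford& a, const Welford& b) {
  double count = a.count + b.count;
  if (count == 0.0)
    return {};
  double delta = b.mean - a.mean;
  double nb = b.count / count;
  return {
      a.mean * (a.count / count) + b.mean * nb,
      a.sigma2 + b.sigma2 + delta * delta * a.count * nb,
      count};
}

// The naive one-pass formula E[x^2] - E[x]^2 cancels catastrophically when
// the mean dwarfs the spread; the result is rounding noise and can even go
// negative, which would turn a later rsqrt into NaN.
double naive_variance(const std::vector<double>& x) {
  double s = 0.0, s2 = 0.0;
  for (double v : x) {
    s += v;
    s2 += v * v;
  }
  double m = s / x.size();
  return s2 / x.size() - m * m;
}

int main() {
  std::vector<double> x{1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4};
  // Accumulate two disjoint partials, then merge, as two lanes would.
  Welford lo, hi;
  lo = online_sum(x[0], lo);
  lo = online_sum(x[1], lo);
  hi = online_sum(x[2], hi);
  hi = online_sum(x[3], hi);
  hi = online_sum(x[4], hi);
  Welford all = combine(lo, hi);
  std::printf("welford variance = %g\n", all.sigma2 / all.count); // exactly 2
  std::printf("naive   variance = %g\n", naive_variance(x)); // rounding noise
  return 0;
}

The merged sigma2 here is 10 over 5 elements, i.e. the exact variance 2, while the naive formula subtracts two numbers near 5e30 whose difference is far below one unit in the last place.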
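
Patches 03, 13, and 14 also settle the nd_range mapping: the row index rides in dimension 1 of the global range, each work-group covers exactly one row, and the local range is (wg_size / SIMD) sub-groups of SIMD lanes each. A small sketch of that shape arithmetic follows; wg_size, SIMD, and M are assumed sample values here, not queried device properties.

#include <cstdio>

int main() {
  // Stand-ins for syclMaxWorkGroupSize<KernelClass>() and the kernel's
  // reqd_sub_group_size; both values are illustrative.
  const long wg_size = 1024;
  const long SIMD = 32;
  const long M = 4096; // rows to normalize

  // Mirrors patch 14: local {wg_size / SIMD, SIMD} and
  // global {wg_size / SIMD, M * SIMD}; group count is global / local
  // in each dimension.
  const long groups_dim0 = (wg_size / SIMD) / (wg_size / SIMD); // 1
  const long groups_dim1 = (M * SIMD) / SIMD;                   // M

  std::printf("work-groups: %ld x %ld\n", groups_dim0, groups_dim1);
  std::printf("items per work-group (one row): %ld\n",
              (wg_size / SIMD) * SIMD); // wg_size
  return 0;
}

So get_group(1) in VectorizedLayerNormKernelFunctor enumerates rows 0..M-1, and the wg_size work-items of each group cooperatively stride over that row's N / vec_size vector loads.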
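
Finally, the dispatch in _layer_norm_kernel takes the vectorized path only when every buffer is aligned to a full vector, N divides evenly into vec_size-wide loads, and N is small enough for the fp32 Welford count; otherwise it falls back to the two-kernel rowwise-moments path added in patch 06. The sketch below restates that guard; the function name and standalone form are ours, while the conditions and constants come from the patches.

#include <cstdint>
#include <cstdio>
#include <limits>

// vec_size = 4 as in the patches, so the alignment is 16 bytes for float.
constexpr int vec_size = 4;

template <typename T>
bool use_vectorized_path(const T* x, const T* y, int64_t n) {
  constexpr int alignment = vec_size * sizeof(T);
  auto aligned = [](const T* p) {
    return reinterpret_cast<std::uintptr_t>(p) % alignment == 0;
  };
  // The kernel reads no tail, so N must be a multiple of vec_size; and the
  // Welford element count is carried in fp32, which stops representing
  // consecutive integers past 2^24 (float's 24 mantissa digits), hence the
  // upper bound on N.
  return n % vec_size == 0 &&
      n <= (int64_t{1} << std::numeric_limits<float>::digits) &&
      aligned(x) && aligned(y);
}

int main() {
  alignas(16) static float row[8] = {};
  std::printf("vectorizable: %d\n", use_vectorized_path(row, row, 8));     // 1
  std::printf("vectorizable: %d\n", use_vectorized_path(row + 1, row, 8)); // 0
  return 0;
}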