From 41273ebec7941ddf12c785cd6f7113cf8c9062c0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 26 Aug 2024 08:31:47 -0700 Subject: [PATCH] norms refractor Signed-off-by: Phuong Nguyen --- transformer_engine/common/CMakeLists.txt | 1 + transformer_engine/common/layer_norm/ln.h | 239 ------------ .../common/layer_norm/ln_api.cpp | 351 ++---------------- .../common/layer_norm/ln_bwd_kernels.cuh | 2 +- .../layer_norm/ln_bwd_semi_cuda_kernel.cu | 40 +- .../common/layer_norm/ln_fwd_cuda_kernel.cu | 39 +- .../common/layer_norm/ln_fwd_kernels.cuh | 2 +- .../common/layer_norm/norms.cpp | 350 +++++++++++++++++ transformer_engine/common/layer_norm/norms.h | 334 +++++++++++++++++ transformer_engine/common/rmsnorm/rmsnorm.h | 89 ----- .../common/rmsnorm/rmsnorm_api.cpp | 307 ++------------- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 40 +- .../common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 38 +- 13 files changed, 838 insertions(+), 994 deletions(-) delete mode 100644 transformer_engine/common/layer_norm/ln.h create mode 100644 transformer_engine/common/layer_norm/norms.cpp create mode 100644 transformer_engine/common/layer_norm/norms.h delete mode 100644 transformer_engine/common/rmsnorm/rmsnorm.h diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a6fd6815c3..06f6c8ae50 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -63,6 +63,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn.cpp fused_attn/utils.cu gemm/cublaslt_gemm.cu + layer_norm/norms.cpp layer_norm/ln_api.cpp layer_norm/ln_bwd_semi_cuda_kernel.cu layer_norm/ln_fwd_cuda_kernel.cu diff --git a/transformer_engine/common/layer_norm/ln.h b/transformer_engine/common/layer_norm/ln.h deleted file mode 100644 index 13543a10aa..0000000000 --- a/transformer_engine/common/layer_norm/ln.h +++ /dev/null @@ -1,239 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ -#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ - -#include - -#include -#include -#include -#include -#include - -#include "../common.h" - -namespace transformer_engine { -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams { - size_t workspace_bytes; - size_t barrier_size; - - int multiprocessorCount; - cudaStream_t stream; - - Params params; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ParamsBase { - ParamsBase() - : ctas_per_col(0), - rows(0), - cols(0), - x(nullptr), - mu(nullptr), - rs(nullptr), - gamma(nullptr), - workspace(nullptr), - barrier(nullptr), - zero_centered_gamma(false) {} - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. - int ctas_per_col; - // Size of CTA group. - int ctas_per_row; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x; - void *mu; - void *rs; - void *gamma; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. 
- int *barrier; - - // Whether gamma is centered around 0 - bool zero_centered_gamma; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FwdParams : public ParamsBase { - FwdParams() : ParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {} - - // Output of LN FWD. - void *z; - void *beta; - float epsilon; - - // Scaling factor - void *scale; - - // AMax output - void *amax; - - // Inverse of scaling factor - void *scale_inv; - - // Whether to compute scale and amax - bool fp8_out; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct BwdParams : public ParamsBase { - BwdParams() - : ParamsBase(), - dz(nullptr), - dbeta_part(nullptr), - dgamma_part(nullptr), - dx(nullptr), - dbeta(nullptr), - dgamma(nullptr) {} - - // Input: gradient wrt. LN FWD output. - void *dz; - - // Workspace for Wgrad pre-reduction. - void *dbeta_part; - void *dgamma_part; - - // Output: Dgrad. - void *dx; - // Output: Wgrad. - void *dbeta; - void *dgamma; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function &, const bool)>; -using BwdFunction = std::function &, const bool)>; -using FunctionKey = uint64_t; -using FwdTunedRegistry = std::unordered_map; -using BwdTunedRegistry = std::unordered_map; -using FwdGeneralRegistry = std::unordered_map>; -using BwdGeneralRegistry = std::unordered_map>; - -extern FwdTunedRegistry FWD_TUNED_FUNCS; -extern BwdTunedRegistry BWD_TUNED_FUNCS; -extern FwdGeneralRegistry FWD_GENERAL_FUNCS; -extern BwdGeneralRegistry BWD_GENERAL_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeId {}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 0; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 1; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 2; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 3; -}; - -template -struct Type2Key { - constexpr static uint32_t Value = TypeId::Value << S; -}; - -template -struct WeightType2Key : public Type2Key {}; - -template -struct InputType2Key : public Type2Key {}; - -template -struct OutputType2Key : public Type2Key {}; - -template -struct ComputeType2Key : public Type2Key {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Types2Key { - constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | - OutputType2Key::Value | ComputeType2Key::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size) { - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdTunedRegistrar { - explicit FwdTunedRegistrar(FwdFunction f) { - uint64_t key = Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({key, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdGeneralRegistrar { - explicit FwdGeneralRegistrar(FwdFunction f) { - uint64_t key = Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - 
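For readers tracing the dispatch logic, here is a worked example of how these registrars compose the lookup key (the same scheme is kept by the new norms.h). This is a minimal, self-contained sketch, assuming the TypeId specializations above map fp16, bf16, fp32 and fp8e4m3 to 0-3 in the order they appear; the helper name make_key and the hex constants are illustrative, not part of the patch.

    #include <cstdint>

    // Illustrative only: assumes TypeId gives fp16 -> 0, bf16 -> 1, fp32 -> 2, fp8e4m3 -> 3.
    constexpr uint32_t kFp32 = 2;

    // Same arithmetic as get_key()/Types2Key: 2 bits per type, hidden size in the low 32 bits.
    constexpr uint64_t make_key(uint32_t w, uint32_t i, uint32_t o, uint32_t c, uint64_t hidden) {
      const uint64_t type_key = w | (i << 2) | (o << 4) | (c << 6);
      return (type_key << 32) | hidden;
    }

    // All-fp32 LayerNorm with hidden size 768 hashes to 0xAA00000300.
    static_assert(make_key(kFp32, kFp32, kFp32, kFp32, 768) == 0xAA00000300ULL, "tuned key");
    // General kernels are registered with hidden_size == 0, so only the type bits remain.
    static_assert(make_key(kFp32, kFp32, kFp32, kFp32, 0) == 0xAA00000000ULL, "general key");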
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdTunedRegistrar { - explicit BwdTunedRegistrar(BwdFunction f) { - uint64_t key = Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({key, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdGeneralRegistrar { - explicit BwdGeneralRegistrar(BwdFunction f) { - uint64_t key = Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp index 8a40450e59..f187aeb36a 100644 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ b/transformer_engine/common/layer_norm/ln_api.cpp @@ -7,10 +7,13 @@ #include #include +#include +#include +#include #include #include "../common.h" -#include "ln.h" +#include "norms.h" /* @@ -32,120 +35,6 @@ Compute always in FP32 */ namespace transformer_engine { -namespace layer_norm { - -using namespace transformer_engine; - -// Create registries and provide runtime versions of config hash functions. - -FwdTunedRegistry FWD_TUNED_FUNCS; -BwdTunedRegistry BWD_TUNED_FUNCS; -FwdGeneralRegistry FWD_GENERAL_FUNCS; -BwdGeneralRegistry BWD_GENERAL_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint32_t get_type_id(DType dtype) { - if (dtype == DType::kFloat16) { - return TypeId::Value; - } else if (dtype == DType::kBFloat16) { - return TypeId::Value; - } else if (dtype == DType::kFloat32) { - return TypeId::Value; - } else if (dtype == DType::kFloat8E4M3) { - return TypeId::Value; - } else { - NVTE_ERROR("Type not supported."); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) { - using namespace layer_norm; - uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | - (get_type_id(ctype) << 6); - uint64_t launcher_key = (type_key << 32) | hidden_size; - return launcher_key; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::FwdParams& params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void* ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && - is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.beta) && - is_aligned(params.z) && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = 
general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::BwdFunction& get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::BwdParams& params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void* ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && - is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.dz) && - is_aligned(params.dx) && is_aligned(params.dbeta) && is_aligned(params.dgamma) && - is_aligned(params.dbeta_part) && is_aligned(params.dgamma_part) && - layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -size_t product(const std::vector& shape) { - size_t ret = 1; - for (auto s : shape) { - ret *= s; - } - return ret; -} - -} // namespace layer_norm - -//////////////////////////////////////////////////////////////////////////////////////////////////// void layernorm_fwd(const Tensor& x, // BxSxhidden_size const Tensor& gamma, // hidden_size @@ -153,112 +42,36 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, cudaStream_t stream, const int multiprocessorCount, Tensor* workspace, Tensor* barrier, const bool zero_centered_gamma) { - const auto itype = x.data.dtype; - const auto wtype = gamma.data.dtype; - const auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - const auto ctype = layer_norm::DType::kFloat32; + using namespace transformer_engine; NVTE_CHECK(x.data.shape.size() == 2); - - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - NVTE_CHECK(gamma.data.shape == beta.data.shape); - NVTE_CHECK(hidden_size == cols); + NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); NVTE_CHECK(epsilon >= 0.f); NVTE_CHECK(z->data.shape == x.data.shape); - NVTE_CHECK(mu->data.shape == std::vector{rows}); - NVTE_CHECK(mu->data.dtype == ctype); + NVTE_CHECK(mu->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(mu->data.dtype == DType::kFloat32); - NVTE_CHECK(rsigma->data.shape == std::vector{rows}); - NVTE_CHECK(rsigma->data.dtype == ctype); + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); - layer_norm::LaunchParams launch_params; + if (workspace->data.dptr != nullptr) { + CheckInputTensor(x, "x"); + 
CheckInputTensor(gamma, "gamma"); + CheckInputTensor(beta, "beta"); - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = stream; - - // Set the kernel runtime parameters. - layer_norm::FwdParams& params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu->data.dptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = beta.data.dptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.scale_inv = z->scale_inv.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*mu, "mu"); + CheckOutputTensor(*rsigma, "rsigma"); } - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - CheckInputTensor(beta, "beta"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*mu, "mu"); - CheckOutputTensor(*rsigma, "rsigma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear buffers - if (params.fp8_out) { - cudaMemsetAsync(params.amax, 0, layer_norm::product(z->amax.shape) * typeToSize(z->amax.dtype), - stream); - } - if (launch_params.barrier_size > 0) { - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. 
- launcher(launch_params, false); - - return; + NormFwdTe NormFwd(x, gamma, beta, epsilon, z, mu, rsigma, stream, + multiprocessorCount, workspace, barrier, + zero_centered_gamma); + norms_launcher(NormFwd, workspace, barrier); } void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma, @@ -267,27 +80,17 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te const int multiprocessorCount, Tensor* workspace, Tensor* barrier, const bool zero_centered_gamma) { using namespace transformer_engine; - - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(mu.data.dtype == ctype); - NVTE_CHECK(rsigma.data.dtype == ctype); + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(mu.data.dtype == DType::kFloat32); + NVTE_CHECK(rsigma.data.dtype == mu.data.dtype); NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(dz.data.shape == x.data.shape); - auto rows = x.data.shape[0]; - auto cols = x.data.shape[1]; - auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(mu.data.shape[0] == rows); + NVTE_CHECK(mu.data.shape[0] == x.data.shape[0]); NVTE_CHECK(mu.data.shape == rsigma.data.shape); - NVTE_CHECK(gamma.data.shape[0] == cols); + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); NVTE_CHECK(dx->data.shape == x.data.shape); NVTE_CHECK(dx->data.dtype == x.data.dtype); @@ -298,93 +101,21 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te NVTE_CHECK(dbeta->data.shape == gamma.data.shape); NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); - layer_norm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. - layer_norm::BwdParams& params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu.data.dptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = dbeta->data.dptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = dbeta_part->data.dptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. 
- launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - NVTE_CHECK(dbeta_part->data.dptr == nullptr); - - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - dbeta_part->data.dtype = ctype; - dbeta_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(dbeta_part->data.dptr != nullptr); - auto pdw_shape = - std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; - - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - NVTE_CHECK(dbeta_part->data.dtype == ctype); - NVTE_CHECK(dbeta_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(mu, "mu"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - CheckOutputTensor(*dbeta, "dbeta"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. 
- launcher(launch_params, false); + if (workspace->data.dptr) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(mu, "mu"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + CheckOutputTensor(*dbeta, "dbeta"); + } + + NormBwdTe BwdNorm(dz, x, mu, rsigma, gamma, dx, dgamma, dbeta, + dgamma_part, dbeta_part, stream, multiprocessorCount, + workspace, barrier, zero_centered_gamma); + norms_launcher(BwdNorm, workspace, barrier, dgamma_part, dbeta_part); } } // namespace transformer_engine diff --git a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh b/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh index dbd0025244..b6a93cfc6a 100644 --- a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh @@ -8,7 +8,7 @@ #define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ #include "../utils.cuh" -#include "ln.h" +#include "norms.h" namespace transformer_engine { namespace layer_norm { diff --git a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu index 17f1256910..4002717b01 100644 --- a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -4,9 +4,9 @@ * See LICENSE for license information. ************************************************************************/ -#include "ln.h" #include "ln_bwd_kernels.cuh" #include "ln_kernel_traits.h" +#include "norms.h" using namespace transformer_engine::layer_norm; @@ -131,27 +131,27 @@ void launch_general_(LaunchParams &launch_params, //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, \ - configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ +#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ + WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ + launch_tuned_(launch_params, \ + configure_params); \ } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + static NormRegistrar \ + reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) + +#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ + BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool 
configure_params) { \ + launch_general_(launch_params, configure_params); \ + } \ + static NormRegistrar \ + reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu index 0c85f4aeb7..cb9436a5b8 100644 --- a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu @@ -4,9 +4,9 @@ * See LICENSE for license information. ************************************************************************/ -#include "ln.h" #include "ln_fwd_kernels.cuh" #include "ln_kernel_traits.h" +#include "norms.h" using namespace transformer_engine::layer_norm; @@ -106,29 +106,30 @@ void launch_general_(LaunchParams &launch_params, //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG) \ - void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ +#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ + WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ + launch_tuned_(launch_params, configure_params); \ } \ - static FwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + static NormRegistrar \ + reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) + +#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ + BYTES_PER_LDG) \ + void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_general_(launch_params, configure_params); \ + } \ + static NormRegistrar \ + reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// +/// // Create tuned launch function and register. 
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG diff --git a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh index bd3741d1d1..b791d71bdb 100644 --- a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh @@ -11,7 +11,7 @@ #include #include "../utils.cuh" -#include "ln.h" +#include "norms.h" namespace transformer_engine { namespace layer_norm { diff --git a/transformer_engine/common/layer_norm/norms.cpp b/transformer_engine/common/layer_norm/norms.cpp new file mode 100644 index 0000000000..328969052d --- /dev/null +++ b/transformer_engine/common/layer_norm/norms.cpp @@ -0,0 +1,350 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* #include */ + +#include "norms.h" + +#include +#include +#include +#include + +/* + +Supported Type combinations: + +input compute weights output +======================================= +fp32 fp32 fp32 fp32 +fp16 fp32 fp16 fp16 +bf16 fp32 bf16 bf16 +fp32 fp32 fp16 fp16 +fp32 fp32 bf16 bf16 +bf16 fp32 bf16 fp8 + +Remarks: +Output type = Weight type +Compute always in FP32 + +*/ + +namespace transformer_engine { + +// Create registries and provide runtime versions of config hash functions. +FwdTunedRegistry LN_FWD_TUNED_FUNCS; +BwdTunedRegistry LN_BWD_TUNED_FUNCS; +FwdGeneralRegistry LN_FWD_GENERAL_FUNCS; +BwdGeneralRegistry LN_BWD_GENERAL_FUNCS; + +FwdTunedRegistry RMS_FWD_TUNED_FUNCS; +BwdTunedRegistry RMS_BWD_TUNED_FUNCS; +FwdGeneralRegistry RMS_FWD_GENERAL_FUNCS; +BwdGeneralRegistry RMS_BWD_GENERAL_FUNCS; + +uint32_t get_type_id(DType dtype) { + if (dtype == DType::kFloat16) { + return TypeId::Value; + } else if (dtype == DType::kBFloat16) { + return TypeId::Value; + } else if (dtype == DType::kFloat32) { + return TypeId::Value; + } else if (dtype == DType::kFloat8E4M3) { + return TypeId::Value; + } else { + NVTE_ERROR("Type not supported."); + } +} + +uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) { + uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | + (get_type_id(ctype) << 6); + uint64_t launcher_key = (type_key << 32) | hidden_size; + return launcher_key; +} + +template +FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, + const FwdParams& params) { + if constexpr (!IF_TE_FWD_NORMS()) NVTE_ERROR("Unexpected NVTE_NORM_TYPE!"); + + auto& FWD_TUNED_FUNCS = GET_REGISTRY(); + auto& FWD_GENERAL_FUNCS = GET_REGISTRY(); + + // Look for tuned kernel + auto tuned_key = get_key(wtype, itype, otype, ctype, params.cols); + auto is_aligned = [](const void* ptr) -> bool { + // Assume vectorized memory accesses are <=16B + return reinterpret_cast(ptr) % 16 == 0; + }; + if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && + is_aligned(params.gamma) && is_aligned(params.z) && FWD_TUNED_FUNCS.count(tuned_key) > 0) { + if constexpr (NormEnum == NVTE_NORM_TYPE::LN_FWD_TE) { + if (is_aligned(params.mu) && is_aligned(params.beta)) return FWD_TUNED_FUNCS.at(tuned_key); + } else + return FWD_TUNED_FUNCS.at(tuned_key); + } + + // Pick general kernel + auto general_key = get_key(wtype, itype, otype, ctype, 0); + if 
(FWD_GENERAL_FUNCS.count(general_key) == 0) { + NVTE_ERROR("FWD: Unsupported types."); + } + auto& general_func_map = FWD_GENERAL_FUNCS.at(general_key); + auto func_iter = general_func_map.lower_bound(params.cols); + if (func_iter == general_func_map.end()) { + // Hidden size is too big, need to use multi-CTA + return general_func_map.rbegin()->second; + } else { + return func_iter->second; + } +} + +template +BwdFunction& get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, + const BwdParams& params) { + if constexpr (!IF_TE_BWD_NORMS()) NVTE_ERROR("Unexpected NVTE_NORM_TYPE!"); + + auto& BWD_TUNED_FUNCS = GET_REGISTRY(); + auto& BWD_GENERAL_FUNCS = GET_REGISTRY(); + + // Look for tuned kernel + auto tuned_key = get_key(wtype, itype, otype, ctype, params.cols); + auto is_aligned = [](const void* ptr) -> bool { + // Assume vectorized memory accesses are <=16B + return reinterpret_cast(ptr) % 16 == 0; + }; + if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && + is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) && + is_aligned(params.dgamma) && is_aligned(params.dgamma_part) && + BWD_TUNED_FUNCS.count(tuned_key) > 0) { + if constexpr (NormEnum == NVTE_NORM_TYPE::LN_BWD_TE) { + if (is_aligned(params.mu) && is_aligned(params.dbeta) && is_aligned(params.dbeta_part)) + return BWD_TUNED_FUNCS.at(tuned_key); + + } else + return BWD_TUNED_FUNCS.at(tuned_key); + } + + // Pick general kernel + auto general_key = get_key(wtype, itype, otype, ctype, 0); + if (BWD_GENERAL_FUNCS.count(general_key) == 0) { + NVTE_ERROR("BWD: Unsupported types."); + } + auto& general_func_map = BWD_GENERAL_FUNCS.at(general_key); + auto func_iter = general_func_map.lower_bound(params.cols); + if (func_iter == general_func_map.end()) { + // Hidden size is too big, need to use multi-CTA + return general_func_map.rbegin()->second; + } else { + return func_iter->second; + } +} + +template +NormFwdTe::NormFwdTe() { + if constexpr (NormEnum == NVTE_NORM_TYPE::LN_FWD_TE) { + NVTE_ERROR("NormFwdTe default constructor is only for its inherited classes!"); + } +} + +template +NormFwdTe::NormFwdTe(const Tensor& x, const Tensor& gamma, const Tensor& beta, + const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, + cudaStream_t stream, const int multiprocessorCount, + Tensor* workspace, Tensor* barrier, const bool zero_centered_gamma) { + if constexpr (!IF_TE_FWD_NORMS()) { + NVTE_ERROR("Unexpected NVTE_NORM_TYPE!"); + } + _launch_params.multiprocessorCount = multiprocessorCount; + _launch_params.stream = stream; + + // Set the kernel runtime parameters. + auto& params = _launch_params.params; + params.rows = x.data.shape[0]; + params.cols = x.data.shape[1]; + params.x = x.data.dptr; + params.rs = rsigma->data.dptr; + params.gamma = gamma.data.dptr; + params.z = z->data.dptr; + params.epsilon = epsilon; + params.amax = z->amax.dptr; + params.amax_byte_size = product(z->amax.shape) * typeToSize(z->amax.dtype); + params.scale = z->scale.dptr; + params.scale_inv = z->scale_inv.dptr; + params.scale_byte_size = product(z->scale.shape) * typeToSize(z->scale.dtype); + params.fp8_out = is_fp8_dtype(z->data.dtype); + params.zero_centered_gamma = zero_centered_gamma; + if constexpr (NormEnum == NVTE_NORM_TYPE::LN_FWD_TE) { + params.mu = mu->data.dptr; + params.beta = beta.data.dptr; + } + + // Request the kernel launcher. 
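  // The launcher obtained below is invoked twice: initialize() calls it with
  // configure_params = true to query workspace_bytes and barrier_size, and
  // execute() calls it with configure_params = false to launch the kernel,
  // i.e. (member names as defined in this file):
  //   _launcher(_launch_params, true);   // size query only
  //   _launcher(_launch_params, false);  // actual launch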
+ _launcher = get_fwd_launcher(gamma.data.dtype, // wtype + x.data.dtype, // itype, + z->data.dtype, // otype, + DType::kFloat32, // ctype, + params); + if (params.fp8_out) set_amax(); +} + +/*** BWD TE ***/ +template +NormBwdTe::NormBwdTe(const Tensor& dz, const Tensor& x, const Tensor& mu, + const Tensor& rsigma, const Tensor& gamma, Tensor* dx, + Tensor* dgamma, Tensor* dbeta, Tensor* dgamma_part, + Tensor* dbeta_part, cudaStream_t stream, + const int multiprocessorCount, Tensor* workspace, Tensor* barrier, + const bool zero_centered_gamma) + : NormFwdTe::NormFwdTe() { + if constexpr (!IF_TE_BWD_NORMS()) NVTE_ERROR("Unexpected NVTE_NORM_TYPE!"); + + auto& _launch_params = NormFwdTe::_launch_params; + _launch_params.stream = stream; + _launch_params.multiprocessorCount = multiprocessorCount; + + // Set the kernel runtime parameters. + auto& params = _launch_params.params; + params.rows = x.data.shape[0]; + params.cols = x.data.shape[1]; + params.x = x.data.dptr; + params.rs = rsigma.data.dptr; + params.gamma = gamma.data.dptr; + params.dz = dz.data.dptr; + params.dx = dx->data.dptr; + params.dgamma = dgamma->data.dptr; + params.dgamma_part = dgamma_part->data.dptr; + params.zero_centered_gamma = zero_centered_gamma; + + if constexpr (NormEnum == NVTE_NORM_TYPE::LN_BWD_TE) { + params.mu = mu.data.dptr; + params.dbeta = dbeta->data.dptr; + params.dbeta_part = dbeta_part->data.dptr; + } + + NormFwdTe::_launcher = get_bwd_launcher(gamma.data.dtype, // wtype, + x.data.dtype, // itype, + gamma.data.dtype, // otype, + DType::kFloat32, // ctype, + params); +} + +template +void NormFwdTe::initialize() { + // Query the kernel-specific launch parameters. + _launcher(_launch_params, true); + if (_launch_params.workspace_bytes == 0) { + _launch_params.workspace_bytes = 1; + } +} + +template +void NormFwdTe::set_workspace_and_barrier(void* workspace_ptr, void* barrier_ptr) { + NVTE_CHECK(_launch_params.workspace_bytes); + _launch_params.params.workspace = workspace_ptr; + + if (_launch_params.barrier_size > 0) { + _launch_params.params.barrier = reinterpret_cast(barrier_ptr); + cudaMemsetAsync(_launch_params.params.barrier, 0, + _launch_params.barrier_size * typeToSize(DType::kFloat32), + _launch_params.stream); + } +} + +template +void NormFwdTe::set_amax() { + cudaMemsetAsync(_launch_params.params.amax, 0, _launch_params.params.amax_byte_size, + _launch_params.stream); +} + +template +void NormFwdTe::execute() { + _launcher(_launch_params, false); +} + +template +std::vector NormFwdTe::get_workspace_shape() { + return {_launch_params.workspace_bytes}; +} + +template +std::vector NormFwdTe::get_barrier_shape() { + return {_launch_params.barrier_size}; +} + +template +std::vector NormBwdTe::get_dgamma_shape() { + if constexpr (!IF_TE_BWD_NORMS()) NVTE_ERROR("Unexpected NVTE_NORM_TYPE!"); + return {static_cast(NormBwdTe::_launch_params.params.ctas_per_col), + static_cast(NormBwdTe::_launch_params.params.cols)}; +} + +template +void norms_launcher(NormType& Norm, Tensor* workspace, Tensor* barrier, Tensor* dgamma_part, + Tensor* dbeta_part) { + Norm.initialize(); + + // Populate shape and dtypes for FW to allocate memory + void* test_ptr = IF_TE_BWD_NORMS() ? 
dgamma_part->data.dptr : workspace->data.dptr; + if (test_ptr == nullptr) { + if constexpr (IF_TE_BWD_NORMS()) { + NVTE_CHECK(dgamma_part->data.dptr == nullptr); + dgamma_part->data.dtype = DType::kFloat32; + dgamma_part->data.shape = Norm.get_dgamma_shape(); + } + if constexpr (NormEnum == NVTE_NORM_TYPE::LN_BWD_TE) { + NVTE_CHECK(dbeta_part->data.dptr == nullptr); + dbeta_part->data.dtype = DType::kFloat32; + dbeta_part->data.shape = Norm.get_dgamma_shape(); + } + if constexpr (IF_TE_NORMS()) { + barrier->data.dtype = DType::kInt32; + barrier->data.shape = Norm.get_barrier_shape(); + } + workspace->data.dtype = DType::kByte; + workspace->data.shape = Norm.get_workspace_shape(); + + return; + } else { + if constexpr (IF_TE_BWD_NORMS()) { + NVTE_CHECK(dgamma_part->data.dtype == DType::kFloat32); + NVTE_CHECK(dgamma_part->data.shape == Norm.get_dgamma_shape()); + } + if constexpr (NormEnum == NVTE_NORM_TYPE::LN_BWD_TE) { + NVTE_CHECK(dbeta_part->data.dptr != nullptr); + NVTE_CHECK(dbeta_part->data.dtype == DType::kFloat32); + NVTE_CHECK(dbeta_part->data.shape == Norm.get_dgamma_shape()); + } + if constexpr (IF_TE_NORMS()) { + NVTE_CHECK(barrier->data.dtype == DType::kInt32); + NVTE_CHECK(barrier->data.shape == Norm.get_barrier_shape()); + } + NVTE_CHECK(workspace->data.dtype == DType::kByte); + NVTE_CHECK(workspace->data.shape == Norm.get_workspace_shape()); + } + + auto barrier_ptr = barrier != nullptr ? barrier->data.dptr : nullptr; + Norm.set_workspace_and_barrier(workspace->data.dptr, barrier_ptr); + + Norm.execute(); +} + +template class NormFwdTe; +template class NormFwdTe; +template class NormBwdTe; +template class NormBwdTe; + +template void norms_launcher>( + NormFwdTe&, Tensor*, Tensor*, Tensor*, Tensor*); +template void norms_launcher>( + NormBwdTe&, Tensor*, Tensor*, Tensor*, Tensor*); +template void norms_launcher>( + NormFwdTe&, Tensor*, Tensor*, Tensor*, Tensor*); +template void norms_launcher>( + NormBwdTe&, Tensor*, Tensor*, Tensor*, Tensor*); + +} // namespace transformer_engine diff --git a/transformer_engine/common/layer_norm/norms.h b/transformer_engine/common/layer_norm/norms.h new file mode 100644 index 0000000000..6181dc4921 --- /dev/null +++ b/transformer_engine/common/layer_norm/norms.h @@ -0,0 +1,334 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ +#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ + +#include + +#include +#include +#include +#include +#include + +#include "../common.h" + +namespace transformer_engine { + +template +struct LaunchParams { + size_t workspace_bytes; + size_t barrier_size; + + int multiprocessorCount; + cudaStream_t stream; + + Params params; +}; + +struct ParamsBase { + ParamsBase() + : ctas_per_col(0), + rows(0), + cols(0), + x(nullptr), + mu(nullptr), + rs(nullptr), + gamma(nullptr), + workspace(nullptr), + barrier(nullptr), + zero_centered_gamma(false) {} + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + // Size of CTA group. + int ctas_per_row; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers. + void* x; + void* mu; + void* rs; + void* gamma; + + // Multi-CTA workspace in gmem. 
+ void* workspace; + + // Multi-CTA sync barriers in gmem. + int* barrier; + + // Whether gamma is centered around 0 + bool zero_centered_gamma; +}; + +struct FwdParams : public ParamsBase { + FwdParams() : ParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {} + + // Output of LN FWD. + void* z; + void* beta; + float epsilon; + + // Scaling factor + void* scale; + int scale_byte_size; + + // Inverse of scaling factor + void* scale_inv; + + // AMax output + void* amax; + int amax_byte_size; + + // Whether to compute scale and amax + bool fp8_out; +}; + +struct BwdParams : public ParamsBase { + BwdParams() + : ParamsBase(), + dz(nullptr), + dbeta_part(nullptr), + dgamma_part(nullptr), + dx(nullptr), + dbeta(nullptr), + dgamma(nullptr) {} + + // Input: gradient wrt. LN FWD output. + void* dz; + + // Workspace for Wgrad pre-reduction. + void* dbeta_part; + void* dgamma_part; + + // Output: Dgrad. + void* dx; + // Output: Wgrad. + void* dbeta; + void* dgamma; +}; + +enum NVTE_NORM_TYPE { + LN_FWD_TE, + LN_BWD_TE, + LN_FWD_CUDNN, + LN_BWD_CUDNN, + RMS_FWD_TE, + RMS_BWD_TE, + RMS_FWD_CUDNN, + RMS_BWD_CUDNN, +}; + +template +constexpr bool IF_TE_NORMS() { + return (NormEnum == NVTE_NORM_TYPE::LN_FWD_TE || NormEnum == NVTE_NORM_TYPE::LN_BWD_TE || + NormEnum == NVTE_NORM_TYPE::RMS_FWD_TE || NormEnum == NVTE_NORM_TYPE::RMS_BWD_TE); +}; + +template +constexpr bool IF_TE_FWD_NORMS() { + return (NormEnum == NVTE_NORM_TYPE::LN_FWD_TE || NormEnum == NVTE_NORM_TYPE::RMS_FWD_TE); +}; + +template +constexpr bool IF_TE_BWD_NORMS() { + return (NormEnum == NVTE_NORM_TYPE::LN_BWD_TE || NormEnum == NVTE_NORM_TYPE::RMS_BWD_TE); +}; + +using FwdFunction = std::function&, const bool)>; +using BwdFunction = std::function&, const bool)>; +using FunctionKey = uint64_t; +using FwdTunedRegistry = std::unordered_map; +using BwdTunedRegistry = std::unordered_map; +using FwdGeneralRegistry = std::unordered_map>; +using BwdGeneralRegistry = std::unordered_map>; + +template +struct LauncherType; + +template +struct LauncherType()>::type> { + using ParamsType = std::conditional_t(), LaunchParams, + LaunchParams>; + using FunctionType = std::conditional_t(), FwdFunction, BwdFunction>; +}; + +extern FwdTunedRegistry LN_FWD_TUNED_FUNCS; +extern BwdTunedRegistry LN_BWD_TUNED_FUNCS; +extern FwdGeneralRegistry LN_FWD_GENERAL_FUNCS; +extern BwdGeneralRegistry LN_BWD_GENERAL_FUNCS; + +extern FwdTunedRegistry RMS_FWD_TUNED_FUNCS; +extern BwdTunedRegistry RMS_BWD_TUNED_FUNCS; +extern FwdGeneralRegistry RMS_FWD_GENERAL_FUNCS; +extern BwdGeneralRegistry RMS_BWD_GENERAL_FUNCS; + +template +struct RegistryType {}; + +template +struct RegistryType()>::type> { + using type = std::conditional_t< + IF_TUNED, std::conditional_t(), FwdTunedRegistry, BwdTunedRegistry>, + std::conditional_t(), FwdGeneralRegistry, BwdGeneralRegistry>>; +}; + +template +constexpr typename RegistryType::type& GET_REGISTRY() { + if constexpr (!IF_TE_NORMS()) NVTE_ERROR("Unexpected NVTE_NORM_TYPE!"); + if constexpr (IF_TUNED) { + if constexpr (NormEnum == NVTE_NORM_TYPE::LN_FWD_TE) + return LN_FWD_TUNED_FUNCS; + else if constexpr (NormEnum == NVTE_NORM_TYPE::LN_BWD_TE) + return LN_BWD_TUNED_FUNCS; + else if constexpr (NormEnum == NVTE_NORM_TYPE::RMS_FWD_TE) + return RMS_FWD_TUNED_FUNCS; + else if constexpr (NormEnum == NVTE_NORM_TYPE::RMS_BWD_TE) + return RMS_BWD_TUNED_FUNCS; + } else { + if constexpr (NormEnum == NVTE_NORM_TYPE::LN_FWD_TE) + return LN_FWD_GENERAL_FUNCS; + else if constexpr (NormEnum == NVTE_NORM_TYPE::LN_BWD_TE) + return 
LN_BWD_GENERAL_FUNCS; + else if constexpr (NormEnum == NVTE_NORM_TYPE::RMS_FWD_TE) + return RMS_FWD_GENERAL_FUNCS; + else if constexpr (NormEnum == NVTE_NORM_TYPE::RMS_BWD_TE) + return RMS_BWD_GENERAL_FUNCS; + } +}; + +template +struct TypeId {}; + +template <> +struct TypeId { + constexpr static uint32_t Value = 0; +}; + +template <> +struct TypeId { + constexpr static uint32_t Value = 1; +}; + +template <> +struct TypeId { + constexpr static uint32_t Value = 2; +}; + +template <> +struct TypeId { + constexpr static uint32_t Value = 3; +}; + +template +struct Type2Key { + constexpr static uint32_t Value = TypeId::Value << S; +}; + +template +struct WeightType2Key : public Type2Key {}; + +template +struct InputType2Key : public Type2Key {}; + +template +struct OutputType2Key : public Type2Key {}; + +template +struct ComputeType2Key : public Type2Key {}; + +template +struct Types2Key { + constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | + OutputType2Key::Value | ComputeType2Key::Value; + constexpr static inline uint64_t get(const uint64_t hidden_size) { + constexpr uint64_t type_key = Value; + return (type_key << 32) | hidden_size; + } +}; + +template +struct NormRegistrar {}; + +template +struct NormRegistrar()>::type> { + explicit NormRegistrar(typename LauncherType::FunctionType f) { + auto& registry = GET_REGISTRY(); + if constexpr (IF_TUNED) { + uint64_t key = Types2Key::get(HIDDEN_SIZE); + registry.insert({key, f}); + } else { + uint64_t key = Types2Key::get(0); + registry[key].insert({HIDDEN_SIZE, f}); + } + } +}; + +class NormBase { + public: + virtual void initialize() = 0; + + virtual void execute() = 0; + + virtual void set_workspace_and_barrier(void* workspace_ptr, void* barrier_ptr) {}; + + virtual std::vector get_workspace_shape() { return {0}; }; + + virtual std::vector get_barrier_shape() { return {0}; }; +}; + +template +class NormFwdTe : public NormBase { + public: + NormFwdTe(); + + NormFwdTe(const Tensor& x, const Tensor& gamma, const Tensor& beta, const float epsilon, + Tensor* z, Tensor* mu, Tensor* rsigma, cudaStream_t stream, + const int multiprocessorCount, Tensor* workspace, Tensor* barrier, + const bool zero_centered_gamma); + + void initialize() override; + + void set_workspace_and_barrier(void* workspace_ptr, void* barrier_ptr) override; + + void set_amax(); + + void execute() override; + + std::vector get_workspace_shape() override; + + std::vector get_barrier_shape() override; + + typename LauncherType::ParamsType _launch_params; + + typename LauncherType::FunctionType _launcher; +}; + +template +class NormBwdTe : public NormFwdTe { + public: + NormBwdTe(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma, + const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta, Tensor* dgamma_part, + Tensor* dbeta_part, cudaStream_t stream, const int multiprocessorCount, + Tensor* workspace, Tensor* barrier, const bool zero_centered_gamma); + + std::vector get_dgamma_shape(); +}; + +template +void norms_launcher(NormType& Norm, Tensor* workspace, Tensor* barrier = nullptr, + Tensor* dgamma_part = nullptr, Tensor* dbeta_part = nullptr); + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ diff --git a/transformer_engine/common/rmsnorm/rmsnorm.h b/transformer_engine/common/rmsnorm/rmsnorm.h deleted file mode 100644 index 8b4e1cf24e..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm.h +++ /dev/null @@ -1,89 +0,0 @@ 
-/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ -#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ - -#include - -#include -#include -#include -#include -#include - -#include "../common.h" -#include "../layer_norm/ln.h" - -namespace transformer_engine { -namespace rmsnorm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams : public transformer_engine::layer_norm::LaunchParams {}; -struct FwdParams : public transformer_engine::layer_norm::FwdParams {}; -struct BwdParams : public transformer_engine::layer_norm::BwdParams {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function &, const bool)>; -using BwdFunction = std::function &, const bool)>; -using FunctionKey = uint64_t; -using FwdTunedRegistry = std::unordered_map; -using BwdTunedRegistry = std::unordered_map; -using FwdGeneralRegistry = std::unordered_map>; -using BwdGeneralRegistry = std::unordered_map>; - -extern FwdTunedRegistry FWD_TUNED_FUNCS; -extern BwdTunedRegistry BWD_TUNED_FUNCS; -extern FwdGeneralRegistry FWD_GENERAL_FUNCS; -extern BwdGeneralRegistry BWD_GENERAL_FUNCS; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdTunedRegistrar { - explicit FwdTunedRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({key, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdGeneralRegistrar { - explicit FwdGeneralRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdTunedRegistrar { - explicit BwdTunedRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({key, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdGeneralRegistrar { - explicit BwdGeneralRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -} // namespace rmsnorm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp index 9b143b2f85..6f2ed26ed5 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp @@ -5,219 +5,44 @@ ************************************************************************/ #include +#include +#include #include #include #include "../common.h" -#include "rmsnorm.h" +#include "../layer_norm/norms.h" #include "transformer_engine/rmsnorm.h" -/* - -Supported Type combinations: - -input compute weights output -======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp32 fp16 -fp32 fp32 fp32 bf16 -fp32 
fp32 fp32 fp8 -fp16 fp32 fp16 fp8 -bf16 fp32 bf16 fp8 - -Remarks: -Input type = Weight type -Compute always in FP32 - -*/ - namespace transformer_engine { -namespace layer_norm { -uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size); -} - -namespace rmsnorm { - -using namespace transformer_engine; - -FwdTunedRegistry FWD_TUNED_FUNCS; -BwdTunedRegistry BWD_TUNED_FUNCS; -FwdGeneralRegistry FWD_GENERAL_FUNCS; -BwdGeneralRegistry BWD_GENERAL_FUNCS; - -FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::FwdParams ¶ms) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && - is_aligned(params.gamma) && is_aligned(params.z) && FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::BwdParams ¶ms) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && - is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) && - is_aligned(params.dgamma) && is_aligned(params.dgamma_part) && - BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -// //////////////////////////////////////////////////////////////////////////////////////////////////// - -inline size_t product(const std::vector &shape) { - return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>()); -} - -} // namespace rmsnorm - -//////////////////////////////////////////////////////////////////////////////////////////////////// - void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount, Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) { - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - 
auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - auto ctype = DType::kFloat32; - NVTE_CHECK(x.data.shape.size() == 2); - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(hidden_size == cols); + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); NVTE_CHECK(epsilon >= 0.f); NVTE_CHECK(z->data.shape == x.data.shape); - NVTE_CHECK(rsigma->data.shape == std::vector{rows}); - NVTE_CHECK(rsigma->data.dtype == ctype); - - rmsnorm::LaunchParams launch_params; + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = stream; + if (workspace->data.dptr != nullptr) { + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); - // Set the kernel runtime parameters. - rmsnorm::FwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = nullptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.scale_inv = z->scale_inv.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } - - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*rsigma, "rsigma"); } - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*rsigma, "rsigma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear buffers - if (params.fp8_out) { - cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype), - stream); - } - if (launch_params.barrier_size > 0) { - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } + Tensor empty; - // Launch the kernel. 
- launcher(launch_params, false); + NormFwdTe NormFwd(x, gamma, empty, epsilon, z, &empty, rsigma, stream, + multiprocessorCount, workspace, barrier, + zero_centered_gamma); + norms_launcher(NormFwd, workspace, barrier); return; } @@ -228,22 +53,13 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const const bool zero_centered_gamma) { using namespace transformer_engine; - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(rsigma.data.dtype == ctype); + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(rsigma.data.dtype == DType::kFloat32); NVTE_CHECK(x.data.shape.size() == 2); NVTE_CHECK(dz.data.shape == x.data.shape); - const auto rows = x.data.shape[0]; - const auto cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(gamma.data.shape[0] == cols); + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); NVTE_CHECK(dx->data.shape == x.data.shape); NVTE_CHECK(dx->data.dtype == x.data.dtype); @@ -251,82 +67,21 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const NVTE_CHECK(dgamma->data.shape == gamma.data.shape); NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); - rmsnorm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. - rmsnorm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = nullptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = nullptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. 
- launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - auto pdw_shape = - std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); + if (workspace->data.dptr != nullptr) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); } - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } + Tensor empty; - // Launch the kernel. - launcher(launch_params, false); + NormBwdTe BwdNorm(dz, x, empty, rsigma, gamma, dx, dgamma, &empty, + dgamma_part, &empty, stream, multiprocessorCount, + workspace, barrier, zero_centered_gamma); + norms_launcher(BwdNorm, workspace, barrier, dgamma_part); } } // namespace transformer_engine diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 3215a6a9d4..9bb8e56a0a 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -4,7 +4,7 @@ * See LICENSE for license information. 
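[Reviewer note] Taken together, the two hunks above replace the hand-rolled LaunchParams/FwdParams setup in rmsnorm_fwd and rmsnorm_bwd with thin wrappers: the tensors and launch options are packed into NormFwdTe / NormBwdTe objects and handed to norms_launcher, which is assumed (based on the deleted code and the null-pointer checks that remain) to take over the configure-then-launch sequence, including the existing convention that a workspace/barrier tensor with a null data pointer is a size query. A minimal sketch of the new calling pattern, mirroring the argument order added in this patch; the template parameters of NormFwdTe/NormBwdTe and the internals of norms_launcher live in layer_norm/norms.h and are not reproduced here:

  // Sketch only: argument order follows the calls added in this patch.
  Tensor empty;  // placeholder for the mu/beta tensors that RMSNorm does not use

  // Forward: a first call with workspace->data.dptr == nullptr only fills in the
  // required workspace/barrier shapes; a second call with allocated buffers launches.
  NormFwdTe NormFwd(x, gamma, empty, epsilon, z, &empty, rsigma, stream,
                    multiprocessorCount, workspace, barrier, zero_centered_gamma);
  norms_launcher(NormFwd, workspace, barrier);

  // Backward: dgamma_part follows the same query-then-launch convention.
  NormBwdTe BwdNorm(dz, x, empty, rsigma, gamma, dx, dgamma, &empty, dgamma_part,
                    &empty, stream, multiprocessorCount, workspace, barrier,
                    zero_centered_gamma);
  norms_launcher(BwdNorm, workspace, barrier, dgamma_part);
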
************************************************************************/ -#include "rmsnorm.h" +#include "../layer_norm/norms.h" #include "rmsnorm_bwd_kernels.cuh" #include "rmsnorm_kernel_traits.h" @@ -132,27 +132,27 @@ void launch_general_(LaunchParams &launch_params, //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, \ - configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ +#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ + WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_tuned_(launch_params, \ + configure_params); \ + } \ + static NormRegistrar \ + reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) -#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ +#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ + BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_general_(launch_params, configure_params); \ + } \ + static NormRegistrar \ + reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 3c8e121540..c309d913f8 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -4,7 +4,7 @@ * See LICENSE for license information. 
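[Reviewer note] These registration macros feed the same tuned/general registries that the dispatch logic shown earlier in this patch consults: a tuned kernel is keyed on the exact hidden size and only used when the row count and pointer alignment permit the widest loads; otherwise the general registry is searched with lower_bound for the smallest registered hidden size that still covers the request, falling back to the largest entry when the hidden size exceeds every registered kernel. The fallback is plain ordered-map lookup; a small self-contained illustration with hypothetical sizes, using std::map to stand in for the registry:

  #include <cstdint>
  #include <iostream>
  #include <map>

  // Hypothetical registry: hidden size -> kernel name, mimicking the
  // general-function map searched with lower_bound in get_*_launcher.
  int main() {
    std::map<uint64_t, const char*> general = {
        {1024, "general_1024"}, {4096, "general_4096"}, {8192, "general_8192"}};

    for (uint64_t cols : {768u, 4096u, 50000u}) {
      auto it = general.lower_bound(cols);
      // Oversized hidden sizes fall back to the largest registered kernel,
      // which handles arbitrary sizes via multi-CTA.
      const char* picked = (it == general.end()) ? general.rbegin()->second : it->second;
      std::cout << cols << " -> " << picked << "\n";
    }
  }
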
************************************************************************/ -#include "rmsnorm.h" +#include "../layer_norm/norms.h" #include "rmsnorm_fwd_kernels.cuh" #include "rmsnorm_kernel_traits.h" @@ -106,26 +106,26 @@ void launch_general_(LaunchParams &launch_params, //////////////////////////////////////////////////////////////////////////////////////////////////// -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ +#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ + WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_tuned_(launch_params, configure_params); \ + } \ + static NormRegistrar \ + reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG) \ - void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static FwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ +#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ + BYTES_PER_LDG) \ + void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_general_(launch_params, configure_params); \ + } \ + static NormRegistrar \ + reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) ////////////////////////////////////////////////////////////////////////////////////////////////////
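[Reviewer note] The static NormRegistrar objects emitted by the forward macros above (and their backward counterparts) rely on the usual static-initialization self-registration idiom: a namespace-scope object whose constructor stores the launcher in a registry keyed by the type combination and hidden size, so every translation unit's kernels are visible to the dispatcher before main() runs and no central list has to be maintained. A minimal, self-contained sketch of the idiom; the names are illustrative, not the actual norms.h types:

  #include <cstdint>
  #include <functional>
  #include <iostream>
  #include <unordered_map>

  // Toy registry keyed the same way as the real one: an integer key derived
  // from the type combination and hidden size.
  using Launcher = std::function<void()>;
  std::unordered_map<uint64_t, Launcher>& registry() {
    static std::unordered_map<uint64_t, Launcher> r;  // avoids init-order issues
    return r;
  }

  struct Registrar {
    Registrar(uint64_t key, Launcher f) { registry().emplace(key, std::move(f)); }
  };

  // What a REGISTER_*_LAUNCHER expansion boils down to: a launcher function plus a
  // static Registrar that files it under its key at program start-up.
  void toy_launcher_768() { std::cout << "launch hidden_size=768\n"; }
  static Registrar reg_768(768, toy_launcher_768);

  int main() {
    registry().at(768)();  // dispatcher side: look up by key and launch
  }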