diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 8c168c76f4..39a6614179 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -69,6 +69,8 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu index a2c8594730..651508c871 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias.cu @@ -116,6 +116,8 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index 5920f23f38..38ac955bc9 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -132,6 +132,8 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); diff --git a/tests/cpp/operator/test_layernorm.cu b/tests/cpp/operator/test_layernorm.cu index 07a8a8884c..cdd8e7846c 100644 --- a/tests/cpp/operator/test_layernorm.cu +++ b/tests/cpp/operator/test_layernorm.cu @@ -230,6 +230,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) { compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / z.scale(); + compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index 85fc3a573a..e7fb183217 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -139,6 +139,10 @@ void performTest() { output_c_list[tensor_id].amax(), ref_amax_list[tensor_id], atol_amax, rtol_amax); + compareResults("scale_inv", + output_c_list[tensor_id].scale_inv(), + 1.f / output_c_list[tensor_id].scale(), + atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", diff --git a/tests/cpp/operator/test_rmsnorm.cu b/tests/cpp/operator/test_rmsnorm.cu index e4e34bac8a..0ec3a877e5 100644 --- a/tests/cpp/operator/test_rmsnorm.cu +++ 
b/tests/cpp/operator/test_rmsnorm.cu @@ -187,6 +187,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) { compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / z.scale(); + compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 7fab75dca0..a4497751f4 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -44,6 +44,7 @@ set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES pycudnn.cpp transformer_engine.cpp + common.cu transpose/cast_transpose.cu transpose/transpose.cu transpose/cast_transpose_fusion.cu diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index fc93705dff..6184e235bd 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -27,7 +27,8 @@ void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), tot_elts, {}, + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, stream);); // NOLINT(*) ); // NOLINT(*) } @@ -50,7 +51,8 @@ void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), tot_elts, {}, + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, stream);); // NOLINT(*) ); // NOLINT(*) } @@ -74,7 +76,8 @@ void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), output->data.shape[0], + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], output->data.shape[1], {}, stream);); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu new file mode 100644 index 0000000000..4e95fc24de --- /dev/null +++ b/transformer_engine/common/common.cu @@ -0,0 +1,32 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "./common.h" +#include "./utils.cuh" + +namespace transformer_engine { + +namespace { + +__global__ void __launch_bounds__(1) + update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr, + float* __restrict__ scale_inv_ptr) { + const float scale = scale_ptr == nullptr ? 
1 : *scale_ptr; + reciprocal(scale_inv_ptr, scale); +} + +} // namespace + +void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) { + if (t->scale_inv.dptr != nullptr) { + update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( + reinterpret_cast(t->scale.dptr), reinterpret_cast(t->scale_inv.dptr)); + } +} + +} // namespace transformer_engine diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 42b529f388..7e72e1b031 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -262,6 +262,13 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt bool is_fp8_dtype(const DType t); +/*! \brief Update a tensor's FP8 scale-inverse + * + * The FP8 scale-inverse (dequantization scaling factor) is updated + * with the reciprocal of the FP8 scale (quantization scaling factor). + */ +void update_tensor_scale_inv(Tensor *t, cudaStream_t stream); + #define NVTE_API_CALL(api_name) \ transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index c9b57752e2..8667b64e65 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -269,6 +269,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, workspace, /* workspace */ workspaceSize, stream)); /* stream */ + // Update FP8 scale-inv in output tensor + if (is_fp8_dtype(outputD->data.dtype)) { + update_tensor_scale_inv(outputD, stream); + } + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc)); diff --git a/transformer_engine/common/layer_norm/ln.h b/transformer_engine/common/layer_norm/ln.h index 45839ed75b..13543a10aa 100644 --- a/transformer_engine/common/layer_norm/ln.h +++ b/transformer_engine/common/layer_norm/ln.h @@ -89,6 +89,9 @@ struct FwdParams : public ParamsBase { // AMax output void *amax; + // Inverse of scaling factor + void *scale_inv; + // Whether to compute scale and amax bool fp8_out; }; diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp index 115422e94e..8a40450e59 100644 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ b/transformer_engine/common/layer_norm/ln_api.cpp @@ -196,6 +196,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size params.epsilon = epsilon; params.amax = z->amax.dptr; params.scale = z->scale.dptr; + params.scale_inv = z->scale_inv.dptr; params.fp8_out = fp8_out; params.zero_centered_gamma = zero_centered_gamma; diff --git a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh index 9fe4c16373..bd3741d1d1 100644 --- a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh @@ -132,10 +132,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( } } if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0 && threadIdx.y == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + 
// Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } @@ -291,10 +299,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne // Finalize fp8 factors if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp index e9a6ff483d..9b143b2f85 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp @@ -159,6 +159,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens params.epsilon = epsilon; params.amax = z->amax.dptr; params.scale = z->scale.dptr; + params.scale_inv = z->scale_inv.dptr; params.fp8_out = fp8_out; params.zero_centered_gamma = zero_centered_gamma; diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh index a1cfc2293c..c435ae3744 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -125,10 +125,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke } } if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0 && threadIdx.y == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } @@ -267,10 +275,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ // Finalize fp8 factors if (params.fp8_out) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); + // Reduce amax over block + if (params.amax != nullptr) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { + reciprocal(reinterpret_cast(params.scale_inv), scale); } } } diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 6cbd4daade..dd45d0a668 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -101,14 +101,11 @@ struct KernelConfig { }; template -__global__ void __launch_bounds__(block_size) - cast_transpose_general_kernel(const IType *__restrict__ const input, - const CType 
*__restrict__ const noop, - OType *__restrict__ const output_c, - OType *__restrict__ const output_t, - const CType *__restrict__ const scale_ptr, - CType *__restrict__ const amax_ptr, const size_t row_length, - const size_t num_rows) { +__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel( + const IType *__restrict__ const input, const CType *__restrict__ const noop, + OType *__restrict__ const output_c, OType *__restrict__ const output_t, + const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr, + CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -207,9 +204,15 @@ __global__ void __launch_bounds__(block_size) if (amax_ptr != nullptr) { amax = reduce_max(amax, tidy); if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(amax_ptr, amax); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); + } } } // namespace @@ -255,6 +258,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output "Cast and transposed outputs need to share amax tensor."); NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, "Cast and transposed outputs need to share scale tensor."); + NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr, + "Cast and transposed outputs need to share scale-inverse tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, InputType, @@ -324,7 +329,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output static_cast(cast_output.data.dptr), static_cast(transposed_output.data.dptr), static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), row_length, num_rows); + static_cast(cast_output.amax.dptr), + static_cast(cast_output.scale_inv.dptr), row_length, + num_rows); } else { // Statically-compiled general kernel constexpr size_t load_size = 4; constexpr size_t store_size = 4; @@ -339,7 +346,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output static_cast(cast_output.data.dptr), static_cast(transposed_output.data.dptr), static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), row_length, num_rows); + static_cast(cast_output.amax.dptr), + static_cast(cast_output.scale_inv.dptr), row_length, num_rows); }); // NOLINT(*) ); // NOLINT(*) } diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index d96757990f..a8361d57ea 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -433,15 +433,19 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) } } - /* warp tile amax reduce*/ - amax = reduce_max(amax, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (param.amax != nullptr) { + // Reduce amax over block + if (param.amax != nullptr) { + amax = reduce_max(amax, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(param.amax, amax); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) { + reciprocal(param.scale_inv, scale); + } } static const char *ActTypeToString[] = { @@ -870,17 +874,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) 
__syncthreads(); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(amax, max); } - if (scale_inv != nullptr) { - reciprocal(scale_inv, scale); - } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, scale); } } @@ -1079,17 +1084,18 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) __syncthreads(); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); atomicMaxFloat(amax, max); } - if (scale_inv != nullptr) { - reciprocal(scale_inv, scale); - } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, scale); } } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 8e6e90a7bf..4026016519 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -36,6 +36,8 @@ struct MultiCastTransposeArgs { void* scale_list[kMaxTensorsPerKernel]; // (output) AMAX's of input tensors void* amax_list[kMaxTensorsPerKernel]; + // (output) Inverse of scaling factor for output tensors + void* scale_inv_list[kMaxTensorsPerKernel]; // Input matrix heights int num_rows_list[kMaxTensorsPerKernel]; // Input matrix widths @@ -82,7 +84,8 @@ __global__ void __launch_bounds__(threads_per_block) OType* output_t = reinterpret_cast(args.output_t_list[tensor_id]); const CType* scale_ptr = reinterpret_cast(args.scale_list[tensor_id]); const CType scale = scale_ptr == nullptr ? 
1 : *scale_ptr; - CType* amax = reinterpret_cast(args.amax_list[tensor_id]); + CType* amax_ptr = reinterpret_cast(args.amax_list[tensor_id]); + CType* scale_inv_ptr = reinterpret_cast(args.scale_inv_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; @@ -183,7 +186,10 @@ __global__ void __launch_bounds__(threads_per_block) local_amax = reduce_max(local_amax, tidy); if (tid == 0) { static_assert(std::is_same::value); - if (amax != nullptr) atomicMaxFloat(amax, local_amax); + if (amax_ptr != nullptr) atomicMaxFloat(amax_ptr, local_amax); + } + if (tile_id == 0 && tid == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); } } @@ -285,6 +291,7 @@ void multi_cast_transpose(const std::vector input_list, kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr; kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr; kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr; + kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr; kernel_args.num_rows_list[pos] = num_rows; kernel_args.row_length_list[pos] = row_length; kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; diff --git a/transformer_engine/common/transpose/rtc/cast_transpose.cu b/transformer_engine/common/transpose/rtc/cast_transpose.cu index 6ea8326147..07244a42e9 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose.cu @@ -25,7 +25,7 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( const IType* __restrict__ const input, const CType* __restrict__ const noop, OType* __restrict__ const output_c, OType* __restrict__ const output_t, const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr, - const size_t row_length, const size_t num_rows) { + CType* __restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; // Vectorized load/store sizes @@ -121,4 +121,9 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( atomicMaxFloat(amax_ptr, amax); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); + } } diff --git a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu index c005be98ef..4ba1cb4c69 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu @@ -229,12 +229,16 @@ __global__ void __launch_bounds__(BLOCK_SIZE) } } - // warp tile amax reduce - const CType max_block = reduce_max(amax, warp_id); - - if (threadIdx.x == 0) { - if (param.amax != nullptr) { + // Reduce amax over block + if (param.amax != nullptr) { + const CType max_block = reduce_max(amax, warp_id); + if (threadIdx.x == 0) { atomicMaxFloat(param.amax, max_block); } } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) { + reciprocal(param.scale_inv, scale); + } } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 8dd2b98ebf..dd03afd21b 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -46,7 +46,8 @@ void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t 
stream) { reinterpret_cast(input.data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), N, {}, + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream);); // NOLINT(*) ); // NOLINT(*) } @@ -68,7 +69,7 @@ void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { p.scale_inv = reinterpret_cast(input.scale_inv.dptr); VectorizedUnaryKernelLauncher( reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), nullptr, nullptr, N, p, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, stream);); // NOLINT(*) ); // NOLINT(*) } diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 63ad1857cf..8653bf45a4 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -168,12 +168,12 @@ template __launch_bounds__(unary_kernel_threads) __global__ void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale, - ComputeType *amax, Param p, const size_t N, + ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N, const size_t num_aligned_elements) { VectorizedLoader loader(input, N); VectorizedStorer storer(output, N); ComputeType max = 0; - ComputeType s = 0; + ComputeType s = 1; if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -199,12 +199,18 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.store(tid, N); } if constexpr (is_fp8::value) { - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } - if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -214,13 +220,13 @@ template __launch_bounds__(unary_kernel_threads) __global__ void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output, - const ComputeType *scale, ComputeType *amax, Param p, const size_t N, - const size_t num_aligned_elements) { + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, + Param p, const size_t N, const size_t num_aligned_elements) { VectorizedLoader loader(input, N); VectorizedLoader grad_loader(grad, N); VectorizedStorer storer(output, N); ComputeType max = 0; - ComputeType s = 0; + ComputeType s = 1; if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -248,12 +254,18 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.store(tid, N); } if constexpr (is_fp8::value) { - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } - if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -311,7 +323,7 @@ Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... 
ptrs) template void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, - fp32 *amax, const size_t N, const Param params, + fp32 *amax, fp32 *scale_inv, const size_t N, const Param params, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -325,16 +337,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, c switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, output, scale, amax, params, N, num_aligned_elements); + input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, output, scale, amax, params, N, num_aligned_elements); + input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_kernel<1, true, fp32, Param, OP> - <<>>(input, output, scale, amax, params, N, N); + unary_kernel<1, true, fp32, Param, OP><<>>( + input, output, scale, amax, scale_inv, params, N, N); break; } } @@ -345,7 +357,8 @@ template void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input, OutputType *output, const fp32 *scale, fp32 *amax, - const size_t N, const Param params, cudaStream_t stream) { + fp32 *scale_inv, const size_t N, const Param params, + cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, grad, output); @@ -358,16 +371,16 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp switch (align) { case Alignment::SAME_ALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, params, N, num_aligned_elements); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, params, N, num_aligned_elements); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_grad_kernel<1, true, fp32, Param, OP> - <<>>(grad, input, output, scale, amax, params, N, N); + unary_grad_kernel<1, true, fp32, Param, OP><<>>( + grad, input, output, scale, amax, scale_inv, params, N, N); break; } } @@ -379,8 +392,8 @@ template __launch_bounds__(unary_kernel_threads) __global__ void gated_act_kernel(const InputType *input, OutputType *output, const ComputeType *scale, - ComputeType *amax, const size_t m, const size_t n, const Param p, - const size_t num_aligned_elements) { + ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, + const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; @@ -389,7 +402,7 @@ __launch_bounds__(unary_kernel_threads) __global__ VectorizedLoader loader1(input + id_y * n * 2 + n, n); VectorizedStorer storer(output + id_y * n, n); ComputeType max = 0; - ComputeType s = 0; + ComputeType s = 1; if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -412,12 +425,18 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.store(id_x, n); if constexpr (is_fp8::value) { - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + 
if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } - if (threadIdx.x == 0 && amax != nullptr) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -427,8 +446,8 @@ template void GatedActivationKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, - fp32 *amax, const size_t m, const size_t n, const Param &p, - cudaStream_t stream) { + fp32 *amax, fp32 *scale_inv, const size_t m, const size_t n, + const Param &p, cudaStream_t stream) { if (m != 0 && n != 0) { size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; @@ -439,18 +458,18 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) { case Alignment::SAME_ALIGNED: gated_act_kernel - <<>>(input, output, scale, amax, m, n, p, + <<>>(input, output, scale, amax, scale_inv, m, n, p, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: gated_act_kernel - <<>>(input, output, scale, amax, m, n, p, + <<>>(input, output, scale, amax, scale_inv, m, n, p, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize gated_act_kernel<1, true, ComputeType, Param, Activation> - <<>>(input, output, scale, amax, m, n, p, n); + <<>>(input, output, scale, amax, scale_inv, m, n, p, n); break; } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index bcfc0c608d..6703ce728c 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -852,6 +852,11 @@ __device__ __forceinline__ void reciprocal(T *value_inv, const T value) { *value_inv = 1 / value; } +template <> +__device__ __forceinline__ void reciprocal(float *value_inv, const float value) { + *value_inv = __frcp_rn(value); +} + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py new file mode 100644 index 0000000000..6ab7d95138 --- /dev/null +++ b/transformer_engine/pytorch/cpp_extensions/_common.py @@ -0,0 +1,87 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Helper functions for C++ extensions""" +import functools +from typing import Dict, Optional, Tuple, Union + +import torch + +import transformer_engine_torch as tex + + +@functools.lru_cache(maxsize=None) +def empty_tensor() -> torch.Tensor: + """Get tensor with no entries and no data""" + return torch.Tensor() + + +def canonicalize_fp8_scales( + *, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + fp8_meta: Optional[tex.FP8TensorMeta] = None, + fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, + allow_multiple_offsets: bool = True, +) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]: + """Canonicalize FP8 scaling factors (scale, amax, scale-inverse) + + If a scaling factor is not provided, try to access it within the + FP8 meta tensors. Returns dict with tensors and dict with tensor + offsets. 
+ + """ + + # Default: use provided scales with no offsets + scale_offset = 0 + amax_offset = 0 + scale_inv_offset = 0 + + # Get scales from FP8 meta tensors if needed + if (fp8_meta is not None) and any(arg is None for arg in (scale, amax, scale_inv)): + if fp8_meta_index is None: + raise ValueError("Provided `fp8_meta` without corresponding `fp8_meta_index`") + fp8_meta_index = int(fp8_meta_index) + if scale is None: + scale = fp8_meta.scale + scale_offset = fp8_meta_index + if amax is None: + amax = fp8_meta.amax_history + amax_offset = fp8_meta_index + if scale_inv is None: + scale_inv = fp8_meta.scale_inv + scale_inv_offset = fp8_meta_index + + # Construct empty tensors if needed + if scale is None: + scale = empty_tensor() + scale_offset = 0 + if amax is None: + amax = empty_tensor() + amax_offset = 0 + if scale_inv is None: + scale_inv = empty_tensor() + scale_inv_offset = 0 + + # Force offsets to be the same if needed + if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset: + if scale_offset != 0: + scale = scale[scale_offset] + scale_offset = 0 + if amax_offset != 0: + amax = amax[0][amax_offset] + amax_offset = 0 + if scale_inv_offset != 0: + scale_inv = scale_inv[scale_inv_offset] + scale_inv_offset = 0 + + # Pack tensors and offsets into dicts + tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv) + offsets = dict( + scale_offset=scale_offset, + amax_offset=amax_offset, + scale_inv_offset=scale_inv_offset, + ) + return tensors, offsets diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py index 767fe25291..f204982aa0 100644 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ b/transformer_engine/pytorch/cpp_extensions/activation.py @@ -3,192 +3,235 @@ # See LICENSE for license information. 
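Before the per-wrapper changes that follow, it may help to see how the new `canonicalize_fp8_scales` helper from `transformer_engine/pytorch/cpp_extensions/_common.py` behaves. The snippet below is a minimal usage sketch, not part of the patch: the hand-populated `tex.FP8TensorMeta`, its tensor shapes, and the `meta` name are assumptions for illustration, and the second call mirrors the `allow_multiple_offsets=False` path used by the activation wrappers.

```python
# Illustrative sketch, not part of the patch. Assumes tex.FP8TensorMeta can be
# default-constructed and populated by hand; shapes and device are arbitrary.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions._common import canonicalize_fp8_scales

meta = tex.FP8TensorMeta()
meta.scale = torch.ones(4, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(4, dtype=torch.float32, device="cuda")
meta.amax_history = torch.zeros(1, 4, dtype=torch.float32, device="cuda")

# All three factors come from the meta tensors; the offsets select the row.
tensors, offsets = canonicalize_fp8_scales(
    fp8_meta=meta,
    fp8_meta_index=tex.FP8FwdTensors.GEMM1_INPUT,
)
assert offsets["scale_offset"] == int(tex.FP8FwdTensors.GEMM1_INPUT)

# Mix a caller-provided scale-inverse buffer with meta-derived scale/amax.
# With allow_multiple_offsets=False the helper slices as needed so that every
# offset collapses to 0.
my_scale_inv = torch.empty(1, dtype=torch.float32, device="cuda")
tensors, offsets = canonicalize_fp8_scales(
    scale_inv=my_scale_inv,
    fp8_meta=meta,
    fp8_meta_index=tex.FP8FwdTensors.GEMM1_INPUT,
    allow_multiple_offsets=False,
)
assert offsets == {"scale_offset": 0, "amax_offset": 0, "scale_inv_offset": 0}
```

Collapsing every offset to zero matches what the TorchScript ops (`*_ts`) in the wrappers below expect, since they take a single scale-offset argument.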
"""Python interface for activation extensions""" -from typing import Union +from typing import Optional, Union + import torch -import transformer_engine_torch as tex +import transformer_engine_torch as tex +from ._common import canonicalize_fp8_scales __all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] def gelu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """GeLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.gelu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def relu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ReLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.relu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def geglu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """GeGLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + 
allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.geglu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def reglu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ReGLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.reglu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def swiglu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """SwiGLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.swiglu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def qgelu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """QuickGELU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel 
return torch.ops.tex_ts.qgelu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) def srelu( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ReLU with FP8 output""" - empty_tensor = torch.Tensor() - if fp8_meta_tensor is not None: - scale = fp8_meta_tensor.scale - amax_history = fp8_meta_tensor.amax_history - scale_inv = fp8_meta_tensor.scale_inv - else: - scale = empty_tensor - amax_history = empty_tensor - scale_inv = empty_tensor + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel return torch.ops.tex_ts.srelu_ts( inp, - scale, - amax_history, - scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, ) diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py index 2856d4727b..0c78a65a6c 100644 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ b/transformer_engine/pytorch/cpp_extensions/cast.py @@ -4,57 +4,91 @@ """Python interface for cast extensions""" from typing import Optional, Union + import torch -import transformer_engine_torch as tex +import transformer_engine_torch as tex +from ._common import canonicalize_fp8_scales, empty_tensor __all__ = ["cast_to_fp8", "cast_from_fp8"] def cast_to_fp8( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, out: Optional[torch.Tensor] = None, -) -> Optional[torch.Tensor]: + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, +) -> torch.Tensor: """Cast input to FP8""" - if out is not None: - if inp.nelement() > 0: - torch.ops.tex_ts.cast_to_fp8_noalloc_ts( - inp, - fp8_meta_tensor.scale, - out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, - otype, - ) - return None - - return torch.ops.tex_ts.cast_to_fp8_ts( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, - otype, + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, ) + # Launch FP8 cast kernel + if inp.nelement() == 0: + if out is None: + out = torch.empty_like(inp, dtype=torch.uint8) + elif out is None: + out = torch.ops.tex_ts.cast_to_fp8_ts( + inp, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], + otype, + ) + else: + torch.ops.tex_ts.cast_to_fp8_noalloc_ts( + inp, + fp8_scales["scale"], + out, + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], + otype, + ) + 
return out + def cast_from_fp8( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], itype: tex.DType, otype: tex.DType, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Cast input from FP8""" + + # Get scaling factors from FP8 meta tensors if needed + scale_inv_offset = 0 + if (fp8_meta_tensor is not None) and (scale_inv is None): + if fp8_tensor is None: + raise ValueError("Provided `fp8_meta_tensor` without corresponding `fp8_tensor`") + scale_inv = fp8_meta_tensor.scale_inv + scale_inv_offset = int(fp8_tensor) + + # Construct empty tensors if needed + if scale_inv is None: + scale_inv = empty_tensor() + scale_inv_offset = 0 + + # Launch FP8 cast kernel return torch.ops.tex_ts.cast_from_fp8_ts( inp, - fp8_meta_tensor.scale_inv, - fp8_tensor, + scale_inv, + scale_inv_offset, itype, otype, ) diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py index dd90bb0b66..50fd6b7709 100644 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ b/transformer_engine/pytorch/cpp_extensions/normalization.py @@ -4,8 +4,11 @@ """Python interface for normalization extensions""" from typing import Optional, Tuple, Union + import torch + import transformer_engine_torch as tex +from ._common import canonicalize_fp8_scales __all__ = [ @@ -23,46 +26,55 @@ def layernorm_fwd_fp8( weight: torch.Tensor, bias: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma: bool, ln_out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """LayerNorm with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel if ln_out is not None: return tex.layernorm_fwd_fp8_noalloc( inp, weight, bias, eps, - fp8_meta_tensor.scale, + fp8_scales["scale"], ln_out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) - return tex.layernorm_fwd_fp8( inp, weight, bias, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -71,26 +83,41 @@ def layernorm_fwd_fp8_inf( weight: torch.Tensor, bias: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma, + scale: Optional[torch.Tensor] = 
None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """LayerNorm with FP8 output. This version of layernorm_fwd_fp8 is specialized for inference, and returns only the normalized output. """ + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts( inp, weight, bias, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, sm_margin, zero_centered_gamma, @@ -121,44 +148,53 @@ def rmsnorm_fwd_fp8( inp: torch.Tensor, weight: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma: bool, rmsnorm_out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """RMSNorm with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel if rmsnorm_out is not None: return tex.rmsnorm_fwd_fp8_noalloc( inp, weight, eps, - fp8_meta_tensor.scale, + fp8_scales["scale"], rmsnorm_out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) - return tex.rmsnorm_fwd_fp8( inp, weight, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, sm_margin, zero_centered_gamma, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -166,25 +202,40 @@ def rmsnorm_fwd_fp8_inf( inp: torch.Tensor, weight: torch.Tensor, eps: float, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, sm_margin: int, zero_centered_gamma, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> torch.Tensor: """RMSNorm with FP8 output. This version of rmsnorm_fwd_fp8 is specialized for inference, and returns only the normalized output. 
""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + allow_multiple_offsets=False, + ) + + # Launch kernel ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts( inp, weight, eps, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + fp8_scales_offsets["scale_offset"], otype, sm_margin, zero_centered_gamma, diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index d96b743b9e..37a1b59da2 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -4,9 +4,12 @@ """Python interface for transpose extensions""" from typing import List, Optional, Tuple, Union + import torch + import transformer_engine_torch as tex from ..constants import TE_DType +from ._common import canonicalize_fp8_scales, empty_tensor __all__ = [ @@ -20,83 +23,115 @@ def fp8_cast_transpose_fused( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, cast_out: Optional[torch.Tensor] = None, transpose_out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, noop_flag: Optional[torch.Tensor] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor], None]: +) -> Tuple[torch.Tensor, torch.Tensor]: """Cast + Transpose with FP8 output""" - return_outputs = False + # Allocate outputs if needed if transpose_out is None: transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8) - return_outputs = True if cast_out is None: cast_out = torch.empty_like(inp, dtype=torch.uint8) - return_outputs = True + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Construct no-op flag if needed if noop_flag is None: - noop_flag = torch.Tensor() + noop_flag = empty_tensor() + # Launch kernel if needed if inp.nelement() > 0: tex.fused_cast_transpose_noop( inp, noop_flag, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], cast_out, transpose_out, otype, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) - if return_outputs: - return cast_out, transpose_out - return None + return cast_out, transpose_out def fp8_cast_transpose_bgrad_fused( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Cast + Transpose + BGRAD with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + 
amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel return tex.fused_cast_transpose_bgrad( inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) def fp8_transpose_bgrad_fused( inp: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], + fp8_meta_tensor: Optional[tex.FP8TensorMeta], + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], otype: tex.DType, grad_bias_type: torch.dtype, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Transpose + BGRAD with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel return tex.fused_fp8_transpose_bgrad( inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, TE_DType[grad_bias_type], - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) @@ -106,18 +141,30 @@ def fp8_cast_transpose_bgrad_dgelu_fused( fp8_meta_tensor: tex.FP8TensorMeta, fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], otype: tex.DType, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Cast + Transpose + BGRAD + DGELU with FP8 output""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta_tensor, + fp8_meta_index=fp8_tensor, + ) + + # Launch kernel return tex.fused_cast_transpose_bgrad_dgelu( grad_output, gelu_input, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], otype, - scale_offset=int(fp8_tensor), - amax_offset=int(fp8_tensor), - scale_inv_offset=int(fp8_tensor), + **fp8_scales_offsets, ) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index b7f87ad397..d531979868 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -117,13 +117,6 @@ def forward( scale_inv: Optional[torch.Tensor] = None, ) -> Float8Tensor: - # Manually compute scale-inverse if needed - if scale is not None and scale_inv is None: - if isinstance(scale, torch.Tensor): - scale_inv = scale.reciprocal() - else: - scale_inv = 1 / scale - # Extract data from FP8 meta tensors if provided if fp8_meta is not None: fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( @@ -138,9 +131,6 @@ def forward( scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index] if amax is None: amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - if scale_inv is None: - scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] - scale_inv = scale_inv.detach().view(1).clone() # Check input tensor tensor = 
tensor.contiguous().cuda().detach() @@ -163,8 +153,9 @@ def forward( # Check scale-inverse if scale_inv is None: - scale_inv = scale.reciprocal() - scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) + scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) + else: + scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) # Check amax if amax is None: @@ -737,19 +728,9 @@ def cast_transpose_( self._fp8_dtype, cast_out=data, transpose_out=transpose, + scale_inv=self._scale_inv, noop_flag=noop_flag, ) - scale = fp8_meta.scale[fp8_meta_index : fp8_meta_index + 1] - scale_inv = self._scale_inv - if noop_flag is None: - torch.reciprocal(scale, out=scale_inv) - else: - torch.where( - noop_flag.bool(), - scale_inv, - scale.reciprocal(), - out=scale_inv, - ) self._transpose_invalid = False @torch.no_grad() @@ -853,7 +834,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_meta_index = dst._fp8_meta_index scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - dst._scale_inv.copy_(scale.detach().reciprocal()) # Cast to FP8 if not dst._data.is_contiguous(): diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 281e3fe104..23a06e318f 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -52,6 +52,9 @@ def _apply_normalization( fwd_ln_sm_margin: int, zero_centered_gamma: bool, is_grad_enabled: bool, + fp8_scale: Optional[torch.Tensor] = None, + fp8_amax: Optional[torch.Tensor] = None, + fp8_scale_inv: Optional[torch.Tensor] = None, ): normalization_func = _get_normalization_func(normalization, fp8_out, is_grad_enabled, True) @@ -70,6 +73,9 @@ def _apply_normalization( fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, + scale=fp8_scale, + amax=fp8_amax, + scale_inv=fp8_scale_inv, **output_kwarg, ) else: @@ -82,6 +88,9 @@ def _apply_normalization( fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, + scale=fp8_scale, + amax=fp8_amax, + scale_inv=fp8_scale_inv, ), None, None, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 10560cdad6..d6045d8e77 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -46,6 +46,7 @@ from ..graph import is_graph_capturing from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode __all__ = ["LayerNormLinear"] @@ -126,8 +127,13 @@ def forward( inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format ) + # Objects for FP8 cast fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + ln_out_scale_inv = None + if fp8: + ln_out_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) + # Launch normalization kernel ln_out, mu, rsigma = _apply_normalization( inputmat, ln_out, @@ -140,6 +146,7 @@ def forward( fwd_ln_sm_margin, zero_centered_gamma, is_grad_enabled, + fp8_scale_inv=ln_out_scale_inv, ) # Column Parallel Linear @@ -172,6 +179,7 @@ def forward( tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, out=ln_out_fp8, + scale_inv=ln_out_scale_inv, ) ln_out = torch.empty_like(ln_out_fp8) else: @@ -180,6 +188,7 @@ def forward( fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + scale_inv=ln_out_scale_inv, ) if 
ln_out_gathered: rank = torch.distributed.get_rank(tp_group) @@ -199,6 +208,18 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. + if is_in_onnx_export_mode(): + ln_out_scale_inv.fill_(ln_out_scale_inv.item()) + if fp8_meta["recipe"].fp8_mha: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, @@ -219,8 +240,8 @@ def forward( 0, weight_fp8._fp8_dtype, ln_out_total, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, + ln_out_scale_inv, + 0, fp8_dtype_forward, output_dtype, get_workspace(), @@ -306,7 +327,7 @@ def forward( weight_fp8, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, ln_out if weight.requires_grad else None, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + ln_out_scale_inv, ) ctx.activation_dtype = activation_dtype @@ -377,7 +398,7 @@ def backward( weight_fp8, main_grad, ln_out, - fwd_scale_inverses, + ln_out_scale_inv, ) = ctx.saved_tensors # Gather intermediate/activation tensors if needed @@ -570,8 +591,8 @@ def backward( ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) wgrad, _ = tex.fp8_gemm( ln_out_total_t, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, + ln_out_scale_inv, + 0, fp8_dtype_forward, ( grad_output_t._data @@ -596,8 +617,8 @@ def backward( else: ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( ln_out_total, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, + ln_out_scale_inv, + 0, fp8_dtype_forward, TE_DType[ctx.activation_dtype], ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 68d333262d..175e5ab5cf 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -48,6 +48,7 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode __all__ = ["Linear"] @@ -103,10 +104,12 @@ def forward( inputmat = cast_if_needed(inputmat, activation_dtype) inputmat_t = None inputmat_no_fp8 = inputmat + inputmat_scale_inv = None if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if isinstance(inputmat, Float8Tensor): + inputmat_scale_inv = inputmat._scale_inv if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled @@ -116,6 +119,7 @@ def forward( # FP8 input for forward, FP8 input transpose for backward wgrad inputmat_t = inputmat.transpose_2d() else: + inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled @@ -128,6 +132,7 @@ def forward( fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) else: # FP8 input for forward @@ -136,8 +141,21 @@ def forward( fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + 
# operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. + if is_in_onnx_export_mode(): + inputmat_scale_inv.fill_(inputmat_scale_inv.item()) + # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) @@ -206,8 +224,8 @@ def forward( if isinstance(inputmat_total, Float8Tensor) else inputmat_total ), - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, proj_out_pttype, get_workspace(), @@ -312,10 +330,10 @@ def forward( ctx.save_for_backward( saved_inputmat, saved_inputmat_t, + inputmat_scale_inv, weight, weight_fp8, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) ctx.activation_dtype = activation_dtype @@ -364,10 +382,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ( inputmat, inputmat_t, + inputmat_scale_inv, weight, weight_fp8, main_grad, - fwd_scale_inverses, ) = ctx.saved_tensors # Gather intermediate/activation tensors if needed @@ -520,8 +538,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(inputmat_t_total, Float8Tensor) else inputmat_t_total ), - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index 05c1a5a0f5..0fa9401163 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -74,7 +74,7 @@ def is_dtype_bf16(t): return t.type().scalarType() == "BFloat16" -def quantize(g, inputs, scale_inv, fp8_tensor): +def quantize(g, inputs, scale, fp8_tensor): """Helper Function for Quantization""" output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) @@ -83,7 +83,7 @@ def quantize(g, inputs, scale_inv, fp8_tensor): if not is_dtype_fp32(inputs): inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) - scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor])) + scale = g.op("Constant", value_t=torch.tensor(1 / scale[fp8_tensor])) q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType( inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) ) @@ -124,18 +124,18 @@ def compute_in_fp32(g, inp, subgraph, *args, **kwargs): return sg_out -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for cast_to_fp8""" # pylint: disable=unused-argument - return quantize(g, inputs, scale_inv, fp8_tensor) + return quantize(g, inputs, scale, fp8_tensor) -@symbolic_helper.parse_args("v", "v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "v", "i", "i") def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype): """ONNX graph for cast_to_fp8_noalloc""" # pylint: disable=unused-argument - return quantize(g, inputs, 
scale_inv, fp8_tensor) + return quantize(g, inputs, scale, fp8_tensor) @symbolic_helper.parse_args("v", "fs", "i", "i", "i") @@ -145,25 +145,25 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): return dequantize(g, inputs, scale_inv, fp8_tensor, otype) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_gelu""" # pylint: disable=unused-argument # TE computes GELU using float32 precision so wrap the GELU subgraph with # conversion to/from float32. gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh") - if scale_inv: - gelu = quantize(g, gelu, scale_inv, fp8_tensor) + if scale: + gelu = quantize(g, gelu, scale, fp8_tensor) return gelu -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_relu""" # pylint: disable=unused-argument relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu) - if scale_inv: - relu = quantize(g, relu, scale_inv, fp8_tensor) + if scale: + relu = quantize(g, relu, scale, fp8_tensor) return relu @@ -178,13 +178,13 @@ def onnx_swiglu(g: jit_utils.GraphContext, inp, dim): return g.op("Mul", g.op("Sigmoid", first), second) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_swiglu""" # pylint: disable=unused-argument swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1) - if scale_inv: - swiglu = quantize(g, swiglu, scale_inv, fp8_tensor) + if scale: + swiglu = quantize(g, swiglu, scale, fp8_tensor) return swiglu @@ -199,13 +199,13 @@ def onnx_reglu(g: jit_utils.GraphContext, inp, dim): return g.op("Mul", g.op("Relu", first), second) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_reglu""" # pylint: disable=unused-argument reglu = compute_in_fp32(g, inputs, onnx_reglu, 1) - if scale_inv: - reglu = quantize(g, reglu, scale_inv, fp8_tensor) + if scale: + reglu = quantize(g, reglu, scale, fp8_tensor) return reglu @@ -221,13 +221,13 @@ def onnx_geglu(g: jit_utils.GraphContext, inp, dim): return g.op("Mul", first_gelu, second) -@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") +@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): """ONNX graph for fp8_geglu""" # pylint: disable=unused-argument geglu = compute_in_fp32(g, inputs, onnx_geglu, 1) - if scale_inv: - geglu = quantize(g, geglu, scale_inv, fp8_tensor) + if scale: + geglu = quantize(g, geglu, scale, fp8_tensor) return geglu @@ -245,7 +245,7 @@ def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): "v", "fs", "i", - "fs", + "v", "v", "i", "v", @@ -330,7 +330,7 @@ def _ones_like(g, inp, dtype): return one -@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") +@symbolic_helper.parse_args("v", "v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") def onnx_layernorm_fwd_fp8( g, inputs, @@ -355,7 +355,7 @@ def onnx_layernorm_fwd_fp8( bias = g.op("Cast", bias, to_i=inp_dtype) ln = 
onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) + fp8_ln = quantize(g, ln, scale, fp8_tensor) return fp8_ln @@ -391,7 +391,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga return ln -@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") +@symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") def onnx_rmsnorm_fwd_fp8( g, inputs, @@ -413,7 +413,7 @@ def onnx_rmsnorm_fwd_fp8( weight = g.op("Cast", weight, to_i=inp_dtype) ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) + fp8_ln = quantize(g, ln, scale, fp8_tensor) return fp8_ln
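The common thread in the ONNX symbolic changes above is that the exporter now receives the FP8 `scale` and folds `1 / scale` into a `Constant` op, rather than reading a separately tracked `scale_inv` input. A rough standalone illustration of that relationship in plain PyTorch arithmetic (not the actual symbolic-tracing code; the scale value is made up):

```python
import torch

scale = torch.tensor([448.0])   # per-tensor FP8 scale; value is illustrative only
x = torch.randn(4, 4)

# Runtime side after this change: the fused kernels write scale_inv == 1 / scale
scale_inv = scale.reciprocal()

# Export side: quantize() emits a Constant holding 1 / scale, so dequantizing
# with that constant matches what the kernels' scale_inv would give.
q = x * scale                   # quantization step (FP8 rounding omitted)
deq = q * (1.0 / scale)         # the constant the symbolic folds into the graph
assert torch.allclose(deq, x)
assert torch.allclose(1.0 / scale, scale_inv)
```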
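Earlier in the patch, the `cpp_extensions` wrappers such as `fp8_cast_transpose_fused` gain optional `scale`/`amax`/`scale_inv` arguments as an alternative to an `FP8TensorMeta`. A hypothetical usage sketch, assuming the wrapper is importable from `transformer_engine.pytorch.cpp_extensions` on a CUDA build and that the kernel fills the provided `scale_inv` buffer; the shapes, dtypes, and scale value here are illustrative only:

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import fp8_cast_transpose_fused

# Illustrative input and per-tensor FP8 metadata (assumed shapes/dtypes)
inp = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
scale = torch.full([1], 448.0, dtype=torch.float32, device="cuda")
amax = torch.zeros([1], dtype=torch.float32, device="cuda")
scale_inv = torch.empty([1], dtype=torch.float32, device="cuda")  # written by the kernel

cast_out, transpose_out = fp8_cast_transpose_fused(
    inp,
    None,                     # fp8_meta_tensor is optional when raw tensors are given
    None,                     # fp8_tensor index is likewise unused here
    tex.DType.kFloat8E4M3,
    scale=scale,
    amax=amax,
    scale_inv=scale_inv,
)
# scale_inv is expected to hold 1 / scale, matching the FP8 data in cast_out/transpose_out
```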