Update FP8 scale-inverse in kernels with FP8 output (#1083)
* Perform scale-inv update in cast-transpose kernels

Signed-off-by: Tim Moon <[email protected]>

* Perform scale-inv update in cast and activation kernels

Signed-off-by: Tim Moon <[email protected]>

* Perform scale-inv update in LayerNorm and RMSNorm kernels

Signed-off-by: Tim Moon <[email protected]>

* Perform scale-inv update after FP8 GEMMs

Signed-off-by: Tim Moon <[email protected]>

* Fuse casts and scale-inv updates in linear module

Signed-off-by: Tim Moon <[email protected]>

* Fuse casts and scale-inv updates in layernorm-linear module

Signed-off-by: Tim Moon <[email protected]>

* Simplify kernel to update FP8 scale-inv

Signed-off-by: Tim Moon <[email protected]>

* Fix typos

Signed-off-by: Tim Moon <[email protected]>

* Debug amax update in layernorm kernels

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug test failures

Signed-off-by: Tim Moon <[email protected]>

* Debug ONNX export

Use quantization scaling factor in ONNX quantize op.

Signed-off-by: Tim Moon <[email protected]>

* Review suggestion from @ptrendx

Signed-off-by: Tim Moon <[email protected]>

* Debug mismatched dtypes

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
timmoon10 and pre-commit-ci[bot] committed Aug 21, 2024
1 parent 5d5fe81 commit 8e3561b
Showing 34 changed files with 824 additions and 380 deletions.
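The commit message above summarizes the approach: every kernel that already produces an FP8 output has the quantization scale available on the device, so it can also refresh the dequantization factor (scale_inv = 1/scale) before it finishes, rather than relying on a separate update step. A minimal standalone sketch of the idea in CUDA, with illustrative names that are not part of Transformer Engine:

// Sketch only: quantize to FP8 with `scale` and refresh `scale_inv` in the
// same launch, so consumers can later dequantize as x ~= fp8_value * scale_inv.
#include <cuda_fp8.h>
#include <cuda_runtime.h>

__global__ void cast_to_fp8_sketch(const float* __restrict__ in,
                                   __nv_fp8_e4m3* __restrict__ out,
                                   const float* __restrict__ scale_ptr,
                                   float* __restrict__ scale_inv_ptr,
                                   size_t n) {
  const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1.f;
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = static_cast<__nv_fp8_e4m3>(in[i] * scale);  // quantize
  }
  // Exactly one thread in the grid updates the dequantization factor,
  // fused into this launch instead of a separate kernel.
  if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
    *scale_inv_ptr = 1.f / scale;
  }
}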
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_cast_transpose.cu
@@ -69,6 +69,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_cast_transpose_dbias.cu
@@ -116,6 +116,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu
@@ -132,6 +132,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}

auto [atol, rtol] = getTolerances(otype);
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_layernorm.cu
@@ -230,6 +230,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}

auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
4 changes: 4 additions & 0 deletions tests/cpp/operator/test_multi_cast_transpose.cu
@@ -139,6 +139,10 @@ void performTest() {
output_c_list[tensor_id].amax(),
ref_amax_list[tensor_id],
atol_amax, rtol_amax);
compareResults("scale_inv",
output_c_list[tensor_id].scale_inv(),
1.f / output_c_list[tensor_id].scale(),
atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c",
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_rmsnorm.cu
@@ -187,6 +187,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}

auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -44,6 +44,7 @@ set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
pycudnn.cpp
transformer_engine.cpp
common.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
9 changes: 6 additions & 3 deletions transformer_engine/common/activation/activation_template.h
@@ -27,7 +27,8 @@ void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
@@ -50,7 +51,8 @@ void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
@@ -74,7 +76,8 @@ void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), output->data.shape[0],
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), output->data.shape[0],
output->data.shape[1], {},
stream);); // NOLINT(*)
); // NOLINT(*)
32 changes: 32 additions & 0 deletions transformer_engine/common/common.cu
@@ -0,0 +1,32 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include "./common.h"
#include "./utils.cuh"

namespace transformer_engine {

namespace {

__global__ void __launch_bounds__(1)
update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr,
float* __restrict__ scale_inv_ptr) {
const float scale = scale_ptr == nullptr ? 1 : *scale_ptr;
reciprocal<float>(scale_inv_ptr, scale);
}

} // namespace

void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) {
if (t->scale_inv.dptr != nullptr) {
update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float*>(t->scale.dptr), reinterpret_cast<float*>(t->scale_inv.dptr));
}
}

} // namespace transformer_engine
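The new kernel above relies on a reciprocal<float> device helper from utils.cuh, which is not part of this diff. A plausible minimal form, stated as an assumption rather than the actual utils.cuh implementation:

// Assumed shape of the helper used by update_tensor_scale_inv_kernel; the
// real utils.cuh version may differ (e.g. use __frcp_rn or support more types).
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  *value_inv = static_cast<T>(1) / value;
}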
7 changes: 7 additions & 0 deletions transformer_engine/common/common.h
@@ -262,6 +262,13 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt

bool is_fp8_dtype(const DType t);

/*! \brief Update a tensor's FP8 scale-inverse
*
* The FP8 scale-inverse (dequantization scaling factor) is updated
* with the reciprocal of the FP8 scale (quantization scaling factor).
*/
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);

#define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);

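The doc comment above describes when the helper is meant to run: after any operation that fills an FP8 tensor without refreshing scale_inv inside its own kernel. A hedged host-side usage sketch (the wrapper function is hypothetical; the real call site added by this commit is in cublaslt_gemm.cu below):

// Hypothetical caller, for illustration only.
#include <cuda_runtime.h>
#include "common.h"

void finish_fp8_output(transformer_engine::Tensor *out, cudaStream_t stream) {
  if (transformer_engine::is_fp8_dtype(out->data.dtype)) {
    // Enqueued on the same stream, so it runs after the producer kernel and
    // before any consumer that reads out->scale_inv.
    transformer_engine::update_tensor_scale_inv(out, stream);
  }
}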
5 changes: 5 additions & 0 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -269,6 +269,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
workspace, /* workspace */
workspaceSize, stream)); /* stream */

// Update FP8 scale-inv in output tensor
if (is_fp8_dtype(outputD->data.dtype)) {
update_tensor_scale_inv(outputD, stream);
}

NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
3 changes: 3 additions & 0 deletions transformer_engine/common/layer_norm/ln.h
@@ -89,6 +89,9 @@ struct FwdParams : public ParamsBase {
// AMax output
void *amax;

// Inverse of scaling factor
void *scale_inv;

// Whether to compute scale and amax
bool fp8_out;
};
1 change: 1 addition & 0 deletions transformer_engine/common/layer_norm/ln_api.cpp
@@ -196,6 +196,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;

32 changes: 24 additions & 8 deletions transformer_engine/common/layer_norm/ln_fwd_kernels.cuh
@@ -132,10 +132,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
}
}
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
@@ -291,10 +299,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne

// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
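Both LayerNorm forward kernels (tuned and general), and the RMSNorm kernels below, now share the same FP8 epilogue: the block-wide amax reduction is guarded on params.amax being non-null, and a single thread of the first block writes the reciprocal of the scale into params.scale_inv. Distilled into a standalone device function, with a simplified warp-shuffle reduction standing in for the Ktraits-based reduce_max (a sketch of the pattern, not code from this diff):

#include <cuda_runtime.h>

// Valid for non-negative values such as amax.
__device__ __forceinline__ void atomic_max_float_sketch(float *addr, float value) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
}

__device__ void fp8_epilogue_sketch(float thread_amax, float scale,
                                    float *amax_ptr, float *scale_inv_ptr) {
  // Reduce amax only if an amax buffer was provided.
  if (amax_ptr != nullptr) {
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
      thread_amax = fmaxf(thread_amax,
                          __shfl_down_sync(0xffffffffu, thread_amax, offset));
    }
    if (threadIdx.x % warpSize == 0) {
      atomic_max_float_sketch(amax_ptr, thread_amax);
    }
  }
  // One thread in the grid refreshes the dequantization factor.
  if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
    *scale_inv_ptr = 1.f / scale;
  }
}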
1 change: 1 addition & 0 deletions transformer_engine/common/rmsnorm/rmsnorm_api.cpp
@@ -159,6 +159,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;

32 changes: 24 additions & 8 deletions transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh
@@ -125,10 +125,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
}
}
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
@@ -267,10 +275,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_

// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
28 changes: 18 additions & 10 deletions transformer_engine/common/transpose/cast_transpose.cu
@@ -101,14 +101,11 @@ struct KernelConfig {
};

template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void __launch_bounds__(block_size)
cast_transpose_general_kernel(const IType *__restrict__ const input,
const CType *__restrict__ const noop,
OType *__restrict__ const output_c,
OType *__restrict__ const output_t,
const CType *__restrict__ const scale_ptr,
CType *__restrict__ const amax_ptr, const size_t row_length,
const size_t num_rows) {
__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel(
const IType *__restrict__ const input, const CType *__restrict__ const noop,
OType *__restrict__ const output_c, OType *__restrict__ const output_t,
const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr,
CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;

// Vectorized load/store sizes
@@ -207,9 +204,15 @@ __global__ void __launch_bounds__(block_size)
if (amax_ptr != nullptr) {
amax = reduce_max<warps_per_tile>(amax, tidy);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax_ptr, amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
reciprocal<CType>(scale_inv_ptr, scale);
}
}

} // namespace
@@ -255,6 +258,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
"Cast and transposed outputs need to share amax tensor.");
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
"Cast and transposed outputs need to share scale tensor.");
NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr,
"Cast and transposed outputs need to share scale-inverse tensor.");

TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
@@ -324,7 +329,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length,
num_rows);
} else { // Statically-compiled general kernel
constexpr size_t load_size = 4;
constexpr size_t store_size = 4;
@@ -339,7 +346,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length, num_rows);
}); // NOLINT(*)
); // NOLINT(*)
}