From 3c808fbbbac26379201f088c47f79922fc4f5f4a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 21:38:35 -0400 Subject: [PATCH] [ Kernel ] FP8 Dynamic-Per-Token Quant Kernel (#6511) Co-authored-by: Varun Sundar Rabindranath --- csrc/ops.h | 10 ++- csrc/quantization/fp8/common.cu | 144 ++++++++++++++++++++++++++----- csrc/torch_bindings.cpp | 10 ++- tests/kernels/quant_utils.py | 56 ++++++++++++ tests/kernels/test_fp8_quant.py | 54 ++++++++++++ tests/kernels/test_int8_quant.py | 26 +++--- vllm/_custom_ops.py | 11 +++ 7 files changed, 271 insertions(+), 40 deletions(-) create mode 100644 tests/kernels/quant_utils.py create mode 100644 tests/kernels/test_fp8_quant.py diff --git a/csrc/ops.h b/csrc/ops.h index 1e94a9f45ef08..c0f924c09b515 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,12 +128,16 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); -void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, - torch::Tensor& scale); +void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& scale); -void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, +void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale); +void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, + torch::Tensor const& input, + torch::Tensor& scale); + void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 6120086d72df2..0938c0707679f 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -7,6 +7,8 @@ #include "cuda_compat.h" #include "dispatch_utils.h" +#include "../../reduction_utils.cuh" + namespace vllm { __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { @@ -88,25 +90,48 @@ typedef struct __align__(4) { float8x4_t; template -__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, - const scalar_t* __restrict__ input, - const float* __restrict__ scale, - int64_t num_elems) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; +__device__ float thread_max_vec(scalar_t const* __restrict__ input, + int64_t const num_elems, int const tid, + int const step) { + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); - // Invert the scale so that we can use multiplications to avoid expensive - // division. - const float inverted_scale = 1.0f / (*scale); + int const num_vec_elems = num_elems >> 2; + float absmax_val = 0.0f; + +#pragma unroll 4 + for (int i = tid; i < num_vec_elems; i += step) { + vec4_t in_vec = vectorized_in[i]; + absmax_val = max(absmax_val, fabs(in_vec.x)); + absmax_val = max(absmax_val, fabs(in_vec.y)); + absmax_val = max(absmax_val, fabs(in_vec.z)); + absmax_val = max(absmax_val, fabs(in_vec.w)); + } + // Handle the remaining elements if num_elems is not divisible by 4 + for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) { + absmax_val = max(absmax_val, fabs(input[i])); + } + + return absmax_val; +} + +template +__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out, + scalar_t const* __restrict__ input, + float const inverted_scale, + int64_t const num_elems, + int const tid, int const step) { // Vectorized input/output to better utilize memory bandwidth. - const vec4_t* vectorized_in = - reinterpret_cast*>(input); + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); float8x4_t* vectorized_out = reinterpret_cast(out); - int num_vec_elems = num_elems >> 2; + int const num_vec_elems = num_elems >> 2; #pragma unroll 4 - for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) { + for (int i = tid; i < num_vec_elems; i += step) { vec4_t in_vec = vectorized_in[i]; float8x4_t out_vec; @@ -118,17 +143,74 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, } // Handle the remaining elements if num_elems is not divisible by 4 - for (int i = num_vec_elems * 4 + tid; i < num_elems; - i += blockDim.x * gridDim.x) { + for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) { out[i] = scaled_fp8_conversion(input[i], inverted_scale); } } +template +__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + int64_t num_elems) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + + // Invert the scale so that we can use multiplications to avoid expensive + // division. + const float inverted_scale = 1.0f / (*scale); + + scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid, + blockDim.x * gridDim.x); +} + +template +__global__ void dynamic_per_token_scaled_fp8_quant_kernel( + c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale, + scalar_t const* __restrict__ input, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + + scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size]; + c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size]; + + // For vectorization, token_input and token_output pointers need to be + // aligned at 8-byte and 4-byte addresses respectively. + bool const can_vectorize = hidden_size % 4 == 0; + + float absmax_val = 0.0f; + if (can_vectorize) { + absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x); + } else { + for (int i = tid; i < hidden_size; i += blockDim.x) { + float const x = static_cast(token_input[i]); + absmax_val = max(absmax_val, fabs(x)); + } + } + + float const block_absmax_val_maybe = blockReduceMax(absmax_val); + __shared__ float block_absmax_val; + if (tid == 0) { + block_absmax_val = block_absmax_val_maybe; + scale[token_idx] = block_absmax_val / FP8_E4M3_MAX; + } + __syncthreads(); + + float const inverted_scale = FP8_E4M3_MAX / block_absmax_val; + if (can_vectorize) { + scaled_fp8_conversion_vec(token_output, token_input, inverted_scale, + hidden_size, tid, blockDim.x); + } else { + for (int i = tid; i < hidden_size; i += blockDim.x) { + token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale); + } + } +} + } // namespace vllm -void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor const& input, // [..., d] + torch::Tensor const& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -144,9 +226,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] }); } -void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor const& input, // [..., d] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -163,3 +245,25 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] scale.data_ptr(), num_elems); }); } + +void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor const& input, // [..., d] + torch::Tensor& scales) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] { + vllm::dynamic_per_token_scaled_fp8_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), hidden_size); + }); +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index ff9875e0e17a3..55ccc6f53b455 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -179,12 +179,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); - // Compute FP8 quantized tensor and scaling factor. + // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. ops.def( "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> " "()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); + // Compute dynamic-per-token FP8 quantized tensor and scaling factor. + ops.def( + "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! " + "scale) -> " + "()"); + ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, + &dynamic_per_token_scaled_fp8_quant); + // Aligning the number of tokens to be processed by each expert such // that it is divisible by the block size. ops.def( diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py new file mode 100644 index 0000000000000..a1513bdffe768 --- /dev/null +++ b/tests/kernels/quant_utils.py @@ -0,0 +1,56 @@ +from typing import Tuple, Union + +import torch + + +def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: + return torch.as_tensor(x, dtype=torch.float32, device='cuda') + +def ref_dynamic_per_token_quant(x: torch.tensor, + quant_dtype: torch.dtype) \ + -> Tuple[torch.tensor, torch.tensor]: + + assert quant_dtype in [torch.int8, torch.float8_e4m3fn] + qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ + else torch.finfo(quant_dtype) + qtype_max = as_float32_tensor(qtype_traits.max) + + # For fp8, in order to match the cuda kernel output, we have to do exactly + # the same operations as in the corresponding fp8 kernel to prevent + # rounding errors. + + # Compute scales + x_token_max, _ = x.abs().max(dim=-1) + x_token_max = as_float32_tensor(x_token_max) + scales = (x_token_max / qtype_max)[:, None] + + # Quant + iscales = (qtype_max / x_token_max)[:, None] + torch_out = as_float32_tensor(x) * iscales + torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out + torch_out = torch_out.clamp(qtype_traits.min, + qtype_traits.max).to(quant_dtype) + + return torch_out, scales + + +# The int8 version is very similar. Incorporate the int8 version, like in +# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant +# kernel +def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ + -> Tuple[torch.tensor, torch.tensor]: + + fp8_traits = torch.finfo(torch.float8_e4m3fn) + fp8_max = as_float32_tensor(fp8_traits.max) + one = as_float32_tensor(1.0) + + # For fp8, in order to match the cuda kernel output, we have to do exactly + # the same operations as in the corresponding fp8 kernel to prevent + # rounding errors. + + x_max = as_float32_tensor(x.abs().max()) + ref_scale = x_max / fp8_max + ref_iscale = one / ref_scale + ref_out = (as_float32_tensor(x) * ref_iscale).clamp( + fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) + return ref_out, ref_scale diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py new file mode 100644 index 0000000000000..6b555c8e242ad --- /dev/null +++ b/tests/kernels/test_fp8_quant.py @@ -0,0 +1,54 @@ +import pytest +import torch + +import vllm._custom_ops as ops +from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant, + ref_dynamic_per_token_quant) + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, + 8193] # Arbitrary values for testing +HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") + 1e-6 # avoid nans + + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) + ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) + + assert torch.allclose(ref_scales, ops_scales) + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + + ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x) + ops_out, ops_scale = ops.scaled_fp8_quant(x) + + assert torch.allclose(ref_scale, ops_scale) + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 0daf7439468aa..03acbf7968ff1 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -3,6 +3,8 @@ # ruff: noqa: F401 import vllm._C +from tests.kernels.quant_utils import ref_dynamic_per_token_quant +from vllm._custom_ops import scaled_int8_quant DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -21,23 +23,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - x_token_max, _ = x.max(dim=1) - x_token_max = x_token_max.to(dtype=torch.float32) - scales = (x_token_max / float(127.0))[:, None].to(device="cuda", - dtype=torch.float32) - torch_out = (x / scales).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - - ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") - scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") - torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out) + # reference + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) + # kernel + ops_out, ops_scales = scaled_int8_quant(x) - assert torch.allclose(scales_out, scales) - assert torch.allclose(torch_out, ops_out, + assert torch.allclose(ops_scales, ref_scales) + assert torch.allclose(ops_out, ref_out, atol=1) # big atol to account for rounding errors @@ -55,12 +50,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + scale = torch.tensor([scale], dtype=torch.float32, device="cuda") out1 = (x / scale).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) - out2 = torch.empty_like(x, dtype=torch.int8) - scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") + out2, _ = scaled_int8_quant(x, scale) - torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument) assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 143957f7b65f0..07646ae582a28 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -335,6 +335,17 @@ def scaled_fp8_quant( return output, scale +def dynamic_per_token_scaled_fp8_quant( + input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales) + return output, scales + + # int8 def scaled_int8_quant( input: torch.Tensor,