From f0701ce70304817ffd02df9e2e849df3dfec6add Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranth
Date: Fri, 19 Jul 2024 20:19:49 +0000
Subject: [PATCH 1/2] Add ub_scale

---
 csrc/ops.h                      |  6 +--
 csrc/quantization/fp8/common.cu | 73 ++++++++++++++++++++++-----------
 csrc/torch_bindings.cpp         |  2 +-
 tests/kernels/quant_utils.py    | 30 ++++++++++----
 tests/kernels/test_fp8_quant.py | 13 ++++--
 vllm/_custom_ops.py             |  3 +-
 6 files changed, 86 insertions(+), 41 deletions(-)

diff --git a/csrc/ops.h b/csrc/ops.h
index c0f924c09b515..6541b4d46d7f6 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -134,9 +134,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
 void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scale);

-void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
-                                        torch::Tensor const& input,
-                                        torch::Tensor& scale);
+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
+    c10::optional<torch::Tensor> const& scale_ub);

 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index 0938c0707679f..56e3a3f43446a 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -23,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {

 #define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

-template <typename scalar_t>
+template <bool is_scale_inverted>
 __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    const scalar_t val, const float inverted_scale) {
-  float x = static_cast<float>(val) * inverted_scale;
+    float const val, float const scale) {
+  float x = 0.0f;
+  if constexpr (is_scale_inverted) {
+    x = val * scale;
+  } else {
+    x = val / scale;
+  }
+
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
   return static_cast<c10::Float8_e4m3fn>(r);
 }
@@ -117,10 +123,10 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   return absmax_val;
 }

-template <typename scalar_t>
+template <typename scalar_t, bool is_scale_inverted>
 __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
                                           scalar_t const* __restrict__ input,
-                                          float const inverted_scale,
+                                          float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
   // Vectorized input/output to better utilize memory bandwidth.
@@ -135,16 +141,21 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     float8x4_t out_vec;

-    out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
-    out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
-    out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
-    out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
+    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.x), scale);
+    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.y), scale);
+    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.z), scale);
+    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.w), scale);
     vectorized_out[i] = out_vec;
   }

   // Handle the remaining elements if num_elems is not divisible by 4
   for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
-    out[i] = scaled_fp8_conversion(input[i], inverted_scale);
+    out[i] = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(input[i]), scale);
   }
 }

@@ -158,15 +169,17 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
   // Invert the scale so that we can use multiplications to avoid expensive
   // division.
   const float inverted_scale = 1.0f / (*scale);
-
-  scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
-                            blockDim.x * gridDim.x);
+  scaled_fp8_conversion_vec<scalar_t, true>(
+      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
 }

 template <typename scalar_t>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
     c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
-    scalar_t const* __restrict__ input, const int hidden_size) {
+    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
+    const int hidden_size) {
+  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
+
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;

@@ -188,20 +201,27 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   }

   float const block_absmax_val_maybe = blockReduceMax(absmax_val);
-  __shared__ float block_absmax_val;
+  __shared__ float token_scale;
   if (tid == 0) {
-    block_absmax_val = block_absmax_val_maybe;
-    scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
+    if (scale_ub) {
+      token_scale = min(block_absmax_val_maybe, *scale_ub);
+    } else {
+      token_scale = block_absmax_val_maybe;
+    }
+    // token scale computation
+    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
+    scale[token_idx] = token_scale;
   }
   __syncthreads();

-  float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
+  // Note that we don't use inverted scales so we can match FBGemm impl.
   if (can_vectorize) {
-    scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
-                              hidden_size, tid, blockDim.x);
+    scaled_fp8_conversion_vec<scalar_t, false>(
+        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale);
+      token_output[i] = scaled_fp8_conversion<false>(
+          static_cast<float>(token_input[i]), token_scale);
     }
   }
 }
@@ -246,9 +266,10 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
       });
 }

-void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
-                                        torch::Tensor const& input,  // [..., d]
-                                        torch::Tensor& scales) {
+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out,          // [..., d]
+    torch::Tensor const& input,  // [..., d]
+    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());

@@ -264,6 +285,8 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
         vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
             <<<grid, block, 0, stream>>>(
                 out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
-                input.data_ptr<scalar_t>(), hidden_size);
+                input.data_ptr<scalar_t>(),
+                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                hidden_size);
       });
 }
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 55ccc6f53b455..d5136e45e781e 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -188,7 +188,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
   ops.def(
       "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
-      "scale) -> "
+      "scale, Tensor? scale_ub) -> "
       "()");
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py
index a1513bdffe768..cec2b05bafd21 100644
--- a/tests/kernels/quant_utils.py
+++ b/tests/kernels/quant_utils.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch

@@ -7,13 +7,19 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')


 def ref_dynamic_per_token_quant(x: torch.tensor,
-                                quant_dtype: torch.dtype) \
+                                quant_dtype: torch.dtype,
+                                scale_ub: Optional[torch.tensor] = None) \
         -> Tuple[torch.tensor, torch.tensor]:
     assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
+    if scale_ub is not None:
+        assert quant_dtype == torch.float8_e4m3fn
+
     qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
             else torch.finfo(quant_dtype)
     qtype_max = as_float32_tensor(qtype_traits.max)
+    s_1 = as_float32_tensor(1.0)
+    s_512 = as_float32_tensor(512.0)

     # For fp8, in order to match the cuda kernel output, we have to do exactly
     # the same operations as in the corresponding fp8 kernel to prevent
@@ -22,14 +28,24 @@ def ref_dynamic_per_token_quant(x: torch.tensor,

     # Compute scales
     x_token_max, _ = x.abs().max(dim=-1)
     x_token_max = as_float32_tensor(x_token_max)
+    if scale_ub is not None:
+        x_token_max = x_token_max.clamp(max=scale_ub)
     scales = (x_token_max / qtype_max)[:, None]

     # Quant
-    iscales = (qtype_max / x_token_max)[:, None]
-    torch_out = as_float32_tensor(x) * iscales
-    torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out
-    torch_out = torch_out.clamp(qtype_traits.min,
-                                qtype_traits.max).to(quant_dtype)
+    if quant_dtype == torch.int8:
+        iscales = as_float32_tensor(s_1 / scales)
+        torch_out = as_float32_tensor(x) * iscales
+        torch_out = torch_out.round()
+        torch_out = torch_out.clamp(qtype_traits.min,
+                                    qtype_traits.max).to(quant_dtype)
+    else:
+        assert quant_dtype == torch.float8_e4m3fn
+        min_scaling_factor = s_1 / (qtype_max * s_512)
+        scales = scales.clamp(min=min_scaling_factor)
+        torch_out = as_float32_tensor(x) / scales
+        torch_out = torch_out.clamp(qtype_traits.min,
+                                    qtype_traits.max).to(quant_dtype)

     return torch_out, scales
diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py
index 9077976f44bc9..ca2920a8fe086 100644
--- a/tests/kernels/test_fp8_quant.py
+++ b/tests/kernels/test_fp8_quant.py
@@ -10,25 +10,30 @@
                 8193]  # Arbitrary values for testing
 HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
 NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
+SCALE_UBS = [True, False]
 SEEDS = [0]


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("scale_ub", SCALE_UBS)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
 def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
-                                     dtype: torch.dtype, seed: int) -> None:
+                                     dtype: torch.dtype, scale_ub: bool,
+                                     seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                    device="cuda") + 1e-6  # avoid nans

-    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
-    ops_out, ops_scales = ops.scaled_fp8_quant(x,
-                                               use_per_token_if_dynamic=True)
+    scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
+        if scale_ub else None
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
+                                                      scale_ub)
+    ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x, scale_ub, use_per_token_if_dynamic=True)

     assert torch.allclose(ref_scales, ops_scales)
     assert torch.allclose(ref_out.to(dtype=torch.float32),
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 666db29568952..873c6786a85a0 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -300,6 +300,7 @@ def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     batch_dim_padding: Optional[int] = None,
+    scale_ub: Optional[torch.Tensor] = None,
     use_per_token_if_dynamic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -336,7 +337,7 @@ def scaled_fp8_quant(
                                 device=input.device,
                                 dtype=torch.float32)
             torch.ops._C.dynamic_per_token_scaled_fp8_quant(
-                output, input, scale)
+                output, input, scale, scale_ub)
         else:
             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)

From d5d2b696aa4be1d8c10932ad79cce0735befe8dc Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranth
Date: Sat, 20 Jul 2024 00:00:48 +0000
Subject: [PATCH 2/2] format

---
 tests/kernels/test_fp8_quant.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py
index ca2920a8fe086..bf1a8df649972 100644
--- a/tests/kernels/test_fp8_quant.py
+++ b/tests/kernels/test_fp8_quant.py
@@ -33,7 +33,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
         if scale_ub else None
     ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
                                                       scale_ub)
-    ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x, scale_ub, use_per_token_if_dynamic=True)
+    ops_out, ops_scales = ops.scaled_fp8_quant(x,
+                                               scale_ub=scale_ub,
+                                               use_per_token_if_dynamic=True)

     assert torch.allclose(ref_scales, ops_scales)
     assert torch.allclose(ref_out.to(dtype=torch.float32),
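
For reference, the per-token quantization rule implemented by the updated kernel, and checked by
ref_dynamic_per_token_quant in tests/kernels/quant_utils.py, can be restated in a few lines of
plain PyTorch: clamp the per-token absmax with scale_ub when one is given, floor the scale at
1 / (FP8_E4M3_MAX * 512), and quantize by dividing by the scale rather than multiplying by a
precomputed reciprocal. The sketch below is illustrative only; it is not part of either commit,
and the function name is invented for the example.

from typing import Optional, Tuple

import torch


def per_token_fp8_quant_reference(
        x: torch.Tensor,
        scale_ub: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Per-token absolute maximum, optionally clamped from above by scale_ub.
    token_max = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    if scale_ub is not None:
        token_max = token_max.clamp(max=scale_ub)
    # Floor the scale at 1 / (FP8_E4M3_MAX * 512), mirroring the kernel's
    # min_scaling_factor, so an all-zero token never produces a zero scale.
    min_scaling_factor = 1.0 / (fp8_max * 512.0)
    scales = (token_max / fp8_max).clamp(min=min_scaling_factor)
    # Quantize by dividing by the scale (no inverted scale), matching the
    # FBGemm-style behaviour noted in the kernel comment.
    out = (x.to(torch.float32) / scales).clamp(min=-fp8_max, max=fp8_max)
    return out.to(torch.float8_e4m3fn), scales

Passing scale_ub = torch.mean(x).to(dtype=torch.float32) mirrors the upper bound exercised by the
new scale_ub=True case in the test above.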