From c27863536b3d0f6f57fe75107e410da36527c21f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 16 Jul 2024 19:59:55 +0000 Subject: [PATCH 01/14] add empty fp8 quant test file --- tests/kernels/test_fp8_quant.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/kernels/test_fp8_quant.py diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py new file mode 100644 index 0000000000000..e69de29bb2d1d From d53a845716b23cb779b06aff7a31ca926ad08ac4 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 16 Jul 2024 20:35:26 +0000 Subject: [PATCH 02/14] add fp8 tests --- csrc/quantization/fp8/common.cu | 8 +++++ tests/kernels/quant_utils.py | 20 ++++++++++++ tests/kernels/test_fp8_quant.py | 54 ++++++++++++++++++++++++++++++++ tests/kernels/test_int8_quant.py | 26 +++++++-------- 4 files changed, 93 insertions(+), 15 deletions(-) create mode 100644 tests/kernels/quant_utils.py diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 6120086d72df2..c4b8e94fbc1eb 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -163,3 +163,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] scale.data_ptr(), num_elems); }); } + +#if 0 +void dynamic_per_token_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scales) +{ +} +#endif diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py new file mode 100644 index 0000000000000..3b645380422c7 --- /dev/null +++ b/tests/kernels/quant_utils.py @@ -0,0 +1,20 @@ +import torch +from typing import Tuple + +def ref_dynamic_per_token_quant(x: torch.tensor, + quant_dtype: torch.dtype) \ + -> Tuple[torch.tensor, torch.tensor]: + + assert quant_dtype in [torch.int8, torch.float8_e4m3fn] + qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ + else torch.finfo(quant_dtype) + + # Compute scales + x_token_max, _ = x.abs().max(dim=-1) + x_token_max = x_token_max.to(dtype=torch.float32) + scales = (x_token_max / float(qtype_traits.max))[:, None].to(device="cuda", + dtype=torch.float32) + # Quant + torch_out = (x / scales).round().clamp(qtype_traits.min, qtype_traits.max).to(quant_dtype) + + return torch_out, scales diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index e69de29bb2d1d..c278a9a1b94f7 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -0,0 +1,54 @@ +import pytest +import torch + +import vllm._custom_ops as ops +from quant_utils import ref_dynamic_per_token_quant + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, + 8193] # Arbitrary values for testing +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] + +#@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +#@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +#@pytest.mark.parametrize("dtype", DTYPES) +#@pytest.mark.parametrize("seed", SEEDS) +#@torch.inference_mode() +#def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, +# dtype: torch.dtype, seed: int) -> None: +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) +# +# x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 +# +# ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) +# ops_out, ops_scales = ops.dynamic_per_token_fp8_quant(x) +# +# assert torch.allclose(ref_scales, 
ops_scales) +# assert torch.allclose(ref_out, ops_out) + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + fp8_traits = torch.iinfo(torch.float8_e4m3fn) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + # reference + ref_scale = (x.abs().max() / float(fp8_traits.max))[:, None].to(device="cuda", + dtype=torch.float32) + ref_out = (x / ref_scale).round().clamp(fp8_traits.min, fp8_traits.max).to(torch.float8_e4m3fn) + # kernel + ops_out, ops_scale = ops.scaled_fp8_quant(x) + + assert torch.allclose(ref_scale, ops_scale) + assert torch.allclose(ref_out, ops_scale) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 0daf7439468aa..c085f42c4868c 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -3,6 +3,9 @@ # ruff: noqa: F401 import vllm._C +from vllm._custom_ops import scaled_int8_quant + +from quant_utils import ref_dynamic_per_token_quant DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -25,19 +28,13 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - x_token_max, _ = x.max(dim=1) - x_token_max = x_token_max.to(dtype=torch.float32) - scales = (x_token_max / float(127.0))[:, None].to(device="cuda", - dtype=torch.float32) - torch_out = (x / scales).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - - ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") - scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") - torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out) + # reference + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) + # kernel + ops_out, ops_scales = scaled_int8_quant(x, scales_out) - assert torch.allclose(scales_out, scales) - assert torch.allclose(torch_out, ops_out, + assert torch.allclose(ops_scales, ref_scales) + assert torch.allclose(ops_out, ref_out, atol=1) # big atol to account for rounding errors @@ -55,12 +52,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + scale = torch.tensor([scale], dtype=torch.float32, device="cuda") out1 = (x / scale).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) - out2 = torch.empty_like(x, dtype=torch.int8) - scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") + out2, _ = scaled_int8_quant(x, scale) - torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument) assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors From bec233266d5ffff6590241a46b1ed982f6438481 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 16 Jul 2024 22:30:12 +0000 Subject: [PATCH 03/14] increase tolerance --- tests/kernels/test_fp8_quant.py | 22 +++++++++++++++------- tests/kernels/test_int8_quant.py | 5 ++--- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index c278a9a1b94f7..b1ce95df10321 100644 --- 
a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -2,7 +2,7 @@ import torch import vllm._custom_ops as ops -from quant_utils import ref_dynamic_per_token_quant +from tests.kernels.quant_utils import ref_dynamic_per_token_quant DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -39,16 +39,24 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - fp8_traits = torch.iinfo(torch.float8_e4m3fn) + fp8_traits = torch.finfo(torch.float8_e4m3fn) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") # reference - ref_scale = (x.abs().max() / float(fp8_traits.max))[:, None].to(device="cuda", - dtype=torch.float32) - ref_out = (x / ref_scale).round().clamp(fp8_traits.min, fp8_traits.max).to(torch.float8_e4m3fn) + ref_scale = x.abs().max().to(dtype=torch.float32) / float(fp8_traits.max) + assert ref_scale.dtype == torch.float32 + ref_out = (x.to(dtype=torch.float32) / ref_scale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) # kernel + assert x.dtype == dtype ops_out, ops_scale = ops.scaled_fp8_quant(x) + assert ops_out.dtype == torch.float8_e4m3fn assert torch.allclose(ref_scale, ops_scale) - assert torch.allclose(ref_out, ops_scale) + # TODO (varun) : For some test cases, the computed scale in the kernel is different + # from the reference implementation in the 8th/9th digits. example, + # ref_scales : 0.002223423682153225 + # ops_scales : 0.0022234234493225813 + # This precludes an exact match in the outputs. This needs to be investigated further. + assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32), + atol=1) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index c085f42c4868c..16ad41ae16bc8 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -5,7 +5,7 @@ import vllm._C from vllm._custom_ops import scaled_int8_quant -from quant_utils import ref_dynamic_per_token_quant +from tests.kernels.quant_utils import ref_dynamic_per_token_quant DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -24,14 +24,13 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 # reference ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) # kernel - ops_out, ops_scales = scaled_int8_quant(x, scales_out) + ops_out, ops_scales = scaled_int8_quant(x) assert torch.allclose(ops_scales, ref_scales) assert torch.allclose(ops_out, ref_out, From ae6e3350340b03168d5ea40120e69a5b6f1c5543 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 13:06:52 +0000 Subject: [PATCH 04/14] fix fp8 dynamic pertoken quant tests --- csrc/ops.h | 4 ++ csrc/quantization/fp8/common.cu | 50 ++++++++++++++++-- csrc/torch_bindings.cpp | 8 ++- tests/kernels/quant_utils.py | 13 +++-- tests/kernels/test_fp8_quant.py | 92 ++++++++++++++++++++------------- vllm/_custom_ops.py | 8 +++ 6 files changed, 131 insertions(+), 44 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index fb1099e4fe0c2..d13fd19c15221 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -129,6 +129,10 
@@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); +void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale); + void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index c4b8e94fbc1eb..2bac5d04fb3ee 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -7,6 +7,8 @@ #include "cuda_compat.h" #include "dispatch_utils.h" +#include "../../reduction_utils.cuh" + namespace vllm { __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { @@ -124,6 +126,36 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, } } +template +__global__ void dynamic_per_token_scaled_fp8_quant_kernel( + c10::Float8_e4m3fn* __restrict__ out, + float* __restrict__ scale, + scalar_t const* __restrict__ input, + const int hidden_size) { + + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + float absmax_val = 0.0f; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + float const x = static_cast(input[token_idx * hidden_size + i]); + absmax_val = max(absmax_val, fabs(x)); + } + + float const block_absmax_val_maybe = blockReduceMax(absmax_val); + __shared__ float block_absmax_val; + if (tid == 0) { + block_absmax_val = block_absmax_val_maybe; + scale[token_idx] = block_absmax_val / FP8_E4M3_MAX; + } + __syncthreads(); + + float const inverted_scale = FP8_E4M3_MAX / block_absmax_val; + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = scaled_fp8_conversion(input[token_idx * hidden_size + i], inverted_scale); + } +} + } // namespace vllm void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] @@ -164,10 +196,22 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] }); } -#if 0 -void dynamic_per_token_fp8_quant(torch::Tensor& out, // [..., d] +void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] torch::Tensor& scales) { + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] { + vllm::dynamic_per_token_scaled_fp8_quant_kernel + <<>>(out.data_ptr(), + scales.data_ptr(), + input.data_ptr(), + hidden_size); + }); } -#endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 18331a674eeba..570b61430b593 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -175,12 +175,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); - // Compute FP8 quantized tensor and scaling factor. + // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. ops.def( "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! 
scale) -> " "()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); + // Compute dynamic-per-token FP8 quantized tensor and scaling factor. + ops.def( + "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "()"); + ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, &dynamic_per_token_scaled_fp8_quant); + // Aligning the number of tokens to be processed by each expert such // that it is divisible by the block size. ops.def( diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 3b645380422c7..1fb603f1e0698 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -12,9 +12,16 @@ def ref_dynamic_per_token_quant(x: torch.tensor, # Compute scales x_token_max, _ = x.abs().max(dim=-1) x_token_max = x_token_max.to(dtype=torch.float32) - scales = (x_token_max / float(qtype_traits.max))[:, None].to(device="cuda", - dtype=torch.float32) + scales = x_token_max / torch.as_tensor([qtype_traits.max], dtype=torch.float32, device='cuda') + scales = scales[:, None] + # Quant - torch_out = (x / scales).round().clamp(qtype_traits.min, qtype_traits.max).to(quant_dtype) + # For fp8, inorder to match the cuda kernel output, we have to do the same operations + # to prevent rounding errors. + iscales = torch.as_tensor([qtype_traits.max], dtype=torch.float32, device='cuda') / x_token_max + iscales = iscales[:, None] + torch_out = (x.to(dtype=torch.float32) * iscales).to(device="cuda", dtype=torch.float32) + torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out + torch_out = torch_out.clamp(qtype_traits.min, qtype_traits.max).to(quant_dtype) return torch_out, scales diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index b1ce95df10321..e522991eaecf7 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -11,52 +11,70 @@ SEEDS = [0] SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] -#@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -#@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -#@pytest.mark.parametrize("dtype", DTYPES) -#@pytest.mark.parametrize("seed", SEEDS) -#@torch.inference_mode() -#def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, -# dtype: torch.dtype, seed: int) -> None: -# torch.random.manual_seed(seed) -# torch.cuda.manual_seed(seed) -# -# x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 -# -# ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) -# ops_out, ops_scales = ops.dynamic_per_token_fp8_quant(x) -# -# assert torch.allclose(ref_scales, ops_scales) -# assert torch.allclose(ref_out, ops_out) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, +def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - fp8_traits = torch.finfo(torch.float8_e4m3fn) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") - # reference - ref_scale = x.abs().max().to(dtype=torch.float32) / float(fp8_traits.max) - assert ref_scale.dtype == torch.float32 - ref_out = (x.to(dtype=torch.float32) / ref_scale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) - # kernel - assert x.dtype == dtype - 
ops_out, ops_scale = ops.scaled_fp8_quant(x) - assert ops_out.dtype == torch.float8_e4m3fn - - assert torch.allclose(ref_scale, ops_scale) - # TODO (varun) : For some test cases, the computed scale in the kernel is different - # from the reference implementation in the 8th/9th digits. example, - # ref_scales : 0.002223423682153225 - # ops_scales : 0.0022234234493225813 - # This precludes an exact match in the outputs. This needs to be investigated further. - assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32), - atol=1) + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) + ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) + + #ref_out_flt = ref_out.to(dtype=torch.float32) + #ops_out_flt = ops_out.to(dtype=torch.float32) + #for i in range(num_tokens): + # for j in range(hidden_size): + # if not torch.allclose(ref_out_flt[i][j], ops_out_flt[i][j]): + # print (f"first error at token {i} - col {j}") + # assert False + + #torch.set_printoptions(profile="full") + #idx = 522 + #print (f"ref out {ref_out[idx].to(dtype=torch.float32)}") + #print (f"ops out {ops_out[idx].to(dtype=torch.float32)}") + #print (f"ref scales : {ref_scales[idx].item()}") + #print(f"ops scales : {ops_scales[idx].item()}") + + + assert torch.allclose(ref_scales, ops_scales) + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) + +#@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +#@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +#@pytest.mark.parametrize("dtype", DTYPES) +#@pytest.mark.parametrize("seed", SEEDS) +#@torch.inference_mode() +#def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, +# dtype: torch.dtype, seed: int) -> None: +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) +# +# fp8_traits = torch.finfo(torch.float8_e4m3fn) +# +# x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") +# +# # reference +# ref_scale = x.abs().max().to(dtype=torch.float32) / float(fp8_traits.max) +# assert ref_scale.dtype == torch.float32 +# ref_out = (x.to(dtype=torch.float32) / ref_scale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) +# # kernel +# assert x.dtype == dtype +# ops_out, ops_scale = ops.scaled_fp8_quant(x) +# assert ops_out.dtype == torch.float8_e4m3fn +# +# assert torch.allclose(ref_scale, ops_scale) +# # TODO (varun) : For some test cases, the computed scale in the kernel is different +# # from the reference implementation in the 8th/9th digits. example, +# # ref_scales : 0.002223423682153225 +# # ops_scales : 0.0022234234493225813 +# # This precludes an exact match in the outputs. This needs to be investigated further. 
+# assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32), +# atol=1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 03308d04012aa..a4e768b7c2ba3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -319,6 +319,14 @@ def scaled_fp8_quant( torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale +def dynamic_per_token_scaled_fp8_quant( + input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales) + return output, scales # int8 def scaled_int8_quant( From d86c80b6759883e88b2bc357d2634934553c918b Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 13:15:17 +0000 Subject: [PATCH 05/14] fix dynamic pertensor fp8 quant --- tests/kernels/test_fp8_quant.py | 102 +++++++++++++++----------------- 1 file changed, 48 insertions(+), 54 deletions(-) diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index e522991eaecf7..6073cc9a537fe 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -12,69 +12,63 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") - - ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) - ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) - - #ref_out_flt = ref_out.to(dtype=torch.float32) - #ops_out_flt = ops_out.to(dtype=torch.float32) - #for i in range(num_tokens): - # for j in range(hidden_size): - # if not torch.allclose(ref_out_flt[i][j], ops_out_flt[i][j]): - # print (f"first error at token {i} - col {j}") - # assert False - - #torch.set_printoptions(profile="full") - #idx = 522 - #print (f"ref out {ref_out[idx].to(dtype=torch.float32)}") - #print (f"ops out {ops_out[idx].to(dtype=torch.float32)}") - #print (f"ref scales : {ref_scales[idx].item()}") - #print(f"ops scales : {ops_scales[idx].item()}") - - - assert torch.allclose(ref_scales, ops_scales) - assert torch.allclose(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) - #@pytest.mark.parametrize("num_tokens", NUM_TOKENS) #@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) #@pytest.mark.parametrize("dtype", DTYPES) #@pytest.mark.parametrize("seed", SEEDS) #@torch.inference_mode() -#def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, +#def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, # dtype: torch.dtype, seed: int) -> None: # torch.random.manual_seed(seed) # torch.cuda.manual_seed(seed) # -# fp8_traits = torch.finfo(torch.float8_e4m3fn) -# # x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") # -# # reference -# ref_scale = x.abs().max().to(dtype=torch.float32) / float(fp8_traits.max) -# assert ref_scale.dtype == torch.float32 -# ref_out = (x.to(dtype=torch.float32) / ref_scale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) -# # kernel -# assert 
x.dtype == dtype -# ops_out, ops_scale = ops.scaled_fp8_quant(x) -# assert ops_out.dtype == torch.float8_e4m3fn +# ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) +# ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) # -# assert torch.allclose(ref_scale, ops_scale) -# # TODO (varun) : For some test cases, the computed scale in the kernel is different -# # from the reference implementation in the 8th/9th digits. example, -# # ref_scales : 0.002223423682153225 -# # ops_scales : 0.0022234234493225813 -# # This precludes an exact match in the outputs. This needs to be investigated further. -# assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32), -# atol=1) +# #ref_out_flt = ref_out.to(dtype=torch.float32) +# #ops_out_flt = ops_out.to(dtype=torch.float32) +# #for i in range(num_tokens): +# # for j in range(hidden_size): +# # if not torch.allclose(ref_out_flt[i][j], ops_out_flt[i][j]): +# # print (f"first error at token {i} - col {j}") +# # assert False +# +# #torch.set_printoptions(profile="full") +# #idx = 522 +# #print (f"ref out {ref_out[idx].to(dtype=torch.float32)}") +# #print (f"ops out {ops_out[idx].to(dtype=torch.float32)}") +# #print (f"ref scales : {ref_scales[idx].item()}") +# #print(f"ops scales : {ops_scales[idx].item()}") +# +# +# assert torch.allclose(ref_scales, ops_scales) +# assert torch.allclose(ref_out.to(dtype=torch.float32), +# ops_out.to(dtype=torch.float32)) + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + fp8_traits = torch.finfo(torch.float8_e4m3fn) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + + # reference + x_max = x.abs().max().to(dtype=torch.float32) + fp8_max = torch.as_tensor([fp8_traits.max], dtype=torch.float32, device='cuda') + ref_scale = x_max / fp8_max + ref_iscale = torch.as_tensor([1.0], dtype=torch.float32, device='cuda') / ref_scale + ref_out = (x.to(dtype=torch.float32) * ref_iscale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) + # kernel + ops_out, ops_scale = ops.scaled_fp8_quant(x) + + assert torch.allclose(ref_scale, ops_scale) + assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) From ea72e0b104395d4fd648b2ffe6e6d80827028aab Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 13:24:35 +0000 Subject: [PATCH 06/14] refactor --- tests/kernels/quant_utils.py | 24 +++++++++++- tests/kernels/test_fp8_quant.py | 65 +++++++++++---------------------- 2 files changed, 43 insertions(+), 46 deletions(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 1fb603f1e0698..d09bef847f041 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -9,6 +9,9 @@ def ref_dynamic_per_token_quant(x: torch.tensor, qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ else torch.finfo(quant_dtype) + # For fp8, inorder to match the cuda kernel output, we have to do the same operations + # to prevent rounding errors. 
+ # Compute scales x_token_max, _ = x.abs().max(dim=-1) x_token_max = x_token_max.to(dtype=torch.float32) @@ -16,8 +19,6 @@ def ref_dynamic_per_token_quant(x: torch.tensor, scales = scales[:, None] # Quant - # For fp8, inorder to match the cuda kernel output, we have to do the same operations - # to prevent rounding errors. iscales = torch.as_tensor([qtype_traits.max], dtype=torch.float32, device='cuda') / x_token_max iscales = iscales[:, None] torch_out = (x.to(dtype=torch.float32) * iscales).to(device="cuda", dtype=torch.float32) @@ -25,3 +26,22 @@ def ref_dynamic_per_token_quant(x: torch.tensor, torch_out = torch_out.clamp(qtype_traits.min, qtype_traits.max).to(quant_dtype) return torch_out, scales + +# The int8 version is very similar. Incorporate the int8 version, like in +# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant +# kernel +def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ + -> Tuple[torch.tensor, torch.tensor]: + + fp8_traits = torch.finfo(torch.float8_e4m3fn) + fp8_max = torch.as_tensor([fp8_traits.max], dtype=torch.float32, device='cuda') + one = torch.as_tensor([1.0], dtype=torch.float32, device='cuda') + + # For fp8, inorder to match the cuda kernel output, we have to do the same operations + # to prevent rounding errors. + + x_max = x.abs().max().to(dtype=torch.float32) + ref_scale = x_max / fp8_max + ref_iscale = one / ref_scale + ref_out = (x.to(dtype=torch.float32) * ref_iscale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) + return ref_out, ref_scale diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 6073cc9a537fe..9916454470f7b 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -2,50 +2,32 @@ import torch import vllm._custom_ops as ops -from tests.kernels.quant_utils import ref_dynamic_per_token_quant +from tests.kernels.quant_utils import ref_dynamic_per_token_quant, ref_dynamic_per_tensor_fp8_quant DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, 8193] # Arbitrary values for testing NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing SEEDS = [0] -SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) + ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) -#@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -#@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -#@pytest.mark.parametrize("dtype", DTYPES) -#@pytest.mark.parametrize("seed", SEEDS) -#@torch.inference_mode() -#def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, -# dtype: torch.dtype, seed: int) -> None: -# torch.random.manual_seed(seed) -# torch.cuda.manual_seed(seed) -# -# x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") -# -# ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) -# ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) -# -# #ref_out_flt = ref_out.to(dtype=torch.float32) -# #ops_out_flt = 
ops_out.to(dtype=torch.float32) -# #for i in range(num_tokens): -# # for j in range(hidden_size): -# # if not torch.allclose(ref_out_flt[i][j], ops_out_flt[i][j]): -# # print (f"first error at token {i} - col {j}") -# # assert False -# -# #torch.set_printoptions(profile="full") -# #idx = 522 -# #print (f"ref out {ref_out[idx].to(dtype=torch.float32)}") -# #print (f"ops out {ops_out[idx].to(dtype=torch.float32)}") -# #print (f"ref scales : {ref_scales[idx].item()}") -# #print(f"ops scales : {ops_scales[idx].item()}") -# -# -# assert torch.allclose(ref_scales, ops_scales) -# assert torch.allclose(ref_out.to(dtype=torch.float32), -# ops_out.to(dtype=torch.float32)) + assert torch.allclose(ref_scales, ops_scales) + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @@ -57,18 +39,13 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - fp8_traits = torch.finfo(torch.float8_e4m3fn) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") # reference - x_max = x.abs().max().to(dtype=torch.float32) - fp8_max = torch.as_tensor([fp8_traits.max], dtype=torch.float32, device='cuda') - ref_scale = x_max / fp8_max - ref_iscale = torch.as_tensor([1.0], dtype=torch.float32, device='cuda') / ref_scale - ref_out = (x.to(dtype=torch.float32) * ref_iscale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) + ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x) # kernel ops_out, ops_scale = ops.scaled_fp8_quant(x) assert torch.allclose(ref_scale, ops_scale) - assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) From 2859b62db31d23067270f091c3312bf3990c8951 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 13:34:41 +0000 Subject: [PATCH 07/14] refactor quant utils --- tests/kernels/quant_utils.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index d09bef847f041..e2fe251afb24d 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -1,5 +1,8 @@ import torch -from typing import Tuple +from typing import Tuple, Union + +def as_float32_tensor(x: [float, torch.tensor]) -> torch.tensor: + return torch.as_tensor(x, dtype=torch.float32, device='cuda') def ref_dynamic_per_token_quant(x: torch.tensor, quant_dtype: torch.dtype) \ @@ -8,20 +11,19 @@ def ref_dynamic_per_token_quant(x: torch.tensor, assert quant_dtype in [torch.int8, torch.float8_e4m3fn] qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ else torch.finfo(quant_dtype) + qtype_max = as_float32_tensor(qtype_traits.max) # For fp8, inorder to match the cuda kernel output, we have to do the same operations # to prevent rounding errors. 
# Compute scales x_token_max, _ = x.abs().max(dim=-1) - x_token_max = x_token_max.to(dtype=torch.float32) - scales = x_token_max / torch.as_tensor([qtype_traits.max], dtype=torch.float32, device='cuda') - scales = scales[:, None] + x_token_max = as_float32_tensor(x_token_max) + scales = (x_token_max / qtype_max)[:, None] # Quant - iscales = torch.as_tensor([qtype_traits.max], dtype=torch.float32, device='cuda') / x_token_max - iscales = iscales[:, None] - torch_out = (x.to(dtype=torch.float32) * iscales).to(device="cuda", dtype=torch.float32) + iscales = (qtype_max / x_token_max)[:, None] + torch_out = as_float32_tensor(x) * iscales torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out torch_out = torch_out.clamp(qtype_traits.min, qtype_traits.max).to(quant_dtype) @@ -34,14 +36,14 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ -> Tuple[torch.tensor, torch.tensor]: fp8_traits = torch.finfo(torch.float8_e4m3fn) - fp8_max = torch.as_tensor([fp8_traits.max], dtype=torch.float32, device='cuda') - one = torch.as_tensor([1.0], dtype=torch.float32, device='cuda') + fp8_max = as_float32_tensor(fp8_traits.max) + one = as_float32_tensor(1.0) # For fp8, inorder to match the cuda kernel output, we have to do the same operations # to prevent rounding errors. - x_max = x.abs().max().to(dtype=torch.float32) + x_max = as_float32_tensor(x.abs().max()) ref_scale = x_max / fp8_max ref_iscale = one / ref_scale - ref_out = (x.to(dtype=torch.float32) * ref_iscale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) + ref_out = (as_float32_tensor(x) * ref_iscale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) return ref_out, ref_scale From b164a32953e48f8edd2446b6ecd210b2cccf237a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 14:40:18 +0000 Subject: [PATCH 08/14] vectorize conversions --- csrc/quantization/fp8/common.cu | 61 ++++++++++++++++++++++----------- tests/kernels/test_fp8_quant.py | 7 ++-- 2 files changed, 44 insertions(+), 24 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 2bac5d04fb3ee..9098a361a5f67 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -89,26 +89,22 @@ typedef struct __align__(4) { } float8x4_t; -template -__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, - const scalar_t* __restrict__ input, - const float* __restrict__ scale, - int64_t num_elems) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - - // Invert the scale so that we can use multiplications to avoid expensive - // division. - const float inverted_scale = 1.0f / (*scale); +template +__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out, + scalar_t const* __restrict__ input, + float const inverted_scale, + int64_t const num_elems, + int const tid, int const step) { // Vectorized input/output to better utilize memory bandwidth. 
- const vec4_t* vectorized_in = - reinterpret_cast*>(input); + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); float8x4_t* vectorized_out = reinterpret_cast(out); - int num_vec_elems = num_elems >> 2; + int const num_vec_elems = num_elems >> 2; #pragma unroll 4 - for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) { + for (int i = tid; i < num_vec_elems; i += step) { vec4_t in_vec = vectorized_in[i]; float8x4_t out_vec; @@ -120,18 +116,33 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, } // Handle the remaining elements if num_elems is not divisible by 4 - for (int i = num_vec_elems * 4 + tid; i < num_elems; - i += blockDim.x * gridDim.x) { + for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) { out[i] = scaled_fp8_conversion(input[i], inverted_scale); } } +template +__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + int64_t num_elems) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + + // Invert the scale so that we can use multiplications to avoid expensive + // division. + const float inverted_scale = 1.0f / (*scale); + + scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, + tid, blockDim.x * gridDim.x); +} + template __global__ void dynamic_per_token_scaled_fp8_quant_kernel( c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale, scalar_t const* __restrict__ input, - const int hidden_size) { + const int hidden_size, + bool const vectorize_conversions) { int const tid = threadIdx.x; int const token_idx = blockIdx.x; @@ -151,8 +162,15 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( __syncthreads(); float const inverted_scale = FP8_E4M3_MAX / block_absmax_val; - for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = scaled_fp8_conversion(input[token_idx * hidden_size + i], inverted_scale); + if (vectorize_conversions) { + scalar_t const* token_input = &input[token_idx * hidden_size]; + c10::Float8_e4m3fn* token_output = &out[token_idx * hidden_size]; + scaled_fp8_conversion_vec(token_output, token_input, inverted_scale, + hidden_size, tid, blockDim.x); + } else { + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = scaled_fp8_conversion(input[token_idx * hidden_size + i], inverted_scale); + } } } @@ -204,6 +222,8 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 1024)); + bool const vectorize_conversions = (hidden_size % 4 == 0) && input.is_contiguous() && out.is_contiguous(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( @@ -212,6 +232,7 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] <<>>(out.data_ptr(), scales.data_ptr(), input.data_ptr(), - hidden_size); + hidden_size, + vectorize_conversions); }); } diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 9916454470f7b..cab26eee5a6ff 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -5,8 +5,9 @@ from tests.kernels.quant_utils import ref_dynamic_per_token_quant, ref_dynamic_per_tensor_fp8_quant DTYPES = [torch.half, torch.bfloat16, torch.float] -HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, 
+HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, 8193] # Arbitrary values for testing +HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing SEEDS = [0] @@ -20,7 +21,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 # avoid nans ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) @@ -41,9 +42,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") - # reference ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x) - # kernel ops_out, ops_scale = ops.scaled_fp8_quant(x) assert torch.allclose(ref_scale, ops_scale) From bcf5a84a769907f7c108c7e39c52c413fc04db9c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 14:44:28 +0000 Subject: [PATCH 09/14] format --- csrc/quantization/fp8/common.cu | 35 ++++++++++++++------------------ csrc/torch_bindings.cpp | 6 ++++-- tests/kernels/quant_utils.py | 27 +++++++++++++++--------- tests/kernels/test_fp8_quant.py | 14 ++++++++----- tests/kernels/test_int8_quant.py | 7 +++---- vllm/_custom_ops.py | 7 +++++-- 6 files changed, 53 insertions(+), 43 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 9098a361a5f67..5993e116c51c6 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -89,13 +89,12 @@ typedef struct __align__(4) { } float8x4_t; -template +template __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out, scalar_t const* __restrict__ input, float const inverted_scale, int64_t const num_elems, int const tid, int const step) { - // Vectorized input/output to better utilize memory bandwidth. vec4_t const* vectorized_in = reinterpret_cast const*>(input); @@ -132,18 +131,15 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, // division. 
const float inverted_scale = 1.0f / (*scale); - scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, - tid, blockDim.x * gridDim.x); + scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid, + blockDim.x * gridDim.x); } template __global__ void dynamic_per_token_scaled_fp8_quant_kernel( - c10::Float8_e4m3fn* __restrict__ out, - float* __restrict__ scale, - scalar_t const* __restrict__ input, - const int hidden_size, + c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale, + scalar_t const* __restrict__ input, const int hidden_size, bool const vectorize_conversions) { - int const tid = threadIdx.x; int const token_idx = blockIdx.x; float absmax_val = 0.0f; @@ -169,7 +165,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( hidden_size, tid, blockDim.x); } else { for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = scaled_fp8_conversion(input[token_idx * hidden_size + i], inverted_scale); + out[token_idx * hidden_size + i] = scaled_fp8_conversion( + input[token_idx * hidden_size + i], inverted_scale); } } } @@ -214,25 +211,23 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] }); } -void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scales) -{ +void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scales) { int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 1024)); - bool const vectorize_conversions = (hidden_size % 4 == 0) && input.is_contiguous() && out.is_contiguous(); + bool const vectorize_conversions = + (hidden_size % 4 == 0) && input.is_contiguous() && out.is_contiguous(); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] { vllm::dynamic_per_token_scaled_fp8_quant_kernel - <<>>(out.data_ptr(), - scales.data_ptr(), - input.data_ptr(), - hidden_size, - vectorize_conversions); + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), hidden_size, vectorize_conversions); }); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 570b61430b593..96781bc82d527 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -183,9 +183,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute dynamic-per-token FP8 quantized tensor and scaling factor. ops.def( - "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! " + "scale) -> " "()"); - ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, &dynamic_per_token_scaled_fp8_quant); + ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, + &dynamic_per_token_scaled_fp8_quant); // Aligning the number of tokens to be processed by each expert such // that it is divisible by the block size. 
diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index e2fe251afb24d..90fb6e9b5d297 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -1,7 +1,9 @@ -import torch from typing import Tuple, Union -def as_float32_tensor(x: [float, torch.tensor]) -> torch.tensor: +import torch + + +def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: return torch.as_tensor(x, dtype=torch.float32, device='cuda') def ref_dynamic_per_token_quant(x: torch.tensor, @@ -13,8 +15,9 @@ def ref_dynamic_per_token_quant(x: torch.tensor, else torch.finfo(quant_dtype) qtype_max = as_float32_tensor(qtype_traits.max) - # For fp8, inorder to match the cuda kernel output, we have to do the same operations - # to prevent rounding errors. + # For fp8, inorder to match the cuda kernel output, we have to do exactly + # the same operations as in the corresponding fp8 kernel to prevent rounding + # errors. # Compute scales x_token_max, _ = x.abs().max(dim=-1) @@ -25,10 +28,12 @@ def ref_dynamic_per_token_quant(x: torch.tensor, iscales = (qtype_max / x_token_max)[:, None] torch_out = as_float32_tensor(x) * iscales torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out - torch_out = torch_out.clamp(qtype_traits.min, qtype_traits.max).to(quant_dtype) + torch_out = torch_out.clamp(qtype_traits.min, + qtype_traits.max).to(quant_dtype) return torch_out, scales + # The int8 version is very similar. Incorporate the int8 version, like in # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant # kernel @@ -39,11 +44,13 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ fp8_max = as_float32_tensor(fp8_traits.max) one = as_float32_tensor(1.0) - # For fp8, inorder to match the cuda kernel output, we have to do the same operations - # to prevent rounding errors. + # For fp8, inorder to match the cuda kernel output, we have to do exactly + # the same operations as in the corresponding fp8 kernel to prevent rounding + # errors. 
x_max = as_float32_tensor(x.abs().max()) - ref_scale = x_max / fp8_max - ref_iscale = one / ref_scale - ref_out = (as_float32_tensor(x) * ref_iscale).clamp(fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) + ref_scale = x_max / fp8_max + ref_iscale = one / ref_scale + ref_out = (as_float32_tensor(x) * ref_iscale).clamp( + fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) return ref_out, ref_scale diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index cab26eee5a6ff..6b555c8e242ad 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -2,26 +2,29 @@ import torch import vllm._custom_ops as ops -from tests.kernels.quant_utils import ref_dynamic_per_token_quant, ref_dynamic_per_tensor_fp8_quant +from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant, + ref_dynamic_per_token_quant) DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, 8193] # Arbitrary values for testing -HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases +HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing SEEDS = [0] + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: + dtype: torch.dtype, seed: int) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 # avoid nans + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") + 1e-6 # avoid nans ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) @@ -30,13 +33,14 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: + dtype: torch.dtype, seed: int) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 16ad41ae16bc8..03acbf7968ff1 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -3,9 +3,8 @@ # ruff: noqa: F401 import vllm._C -from vllm._custom_ops import scaled_int8_quant - from tests.kernels.quant_utils import ref_dynamic_per_token_quant +from vllm._custom_ops import scaled_int8_quant DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -28,7 +27,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 # reference - ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) # kernel ops_out, ops_scales = scaled_int8_quant(x) @@ -55,7 +54,7 @@ def 
test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, out1 = (x / scale).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) - out2, _ = scaled_int8_quant(x, scale) + out2, _ = scaled_int8_quant(x, scale) assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a4e768b7c2ba3..9bb9ea19ed832 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -319,15 +319,18 @@ def scaled_fp8_quant( torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale + def dynamic_per_token_scaled_fp8_quant( - input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, dtype=torch.float32) + device=input.device, + dtype=torch.float32) torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales) return output, scales + # int8 def scaled_int8_quant( input: torch.Tensor, From dc0e0eb8dc66ca9c5bc2fe91e7431adb8cef1d6a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 15:10:57 +0000 Subject: [PATCH 10/14] Add torch checks --- csrc/quantization/fp8/common.cu | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 5993e116c51c6..adab13f3d9b5f 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -138,8 +138,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, template __global__ void dynamic_per_token_scaled_fp8_quant_kernel( c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale, - scalar_t const* __restrict__ input, const int hidden_size, - bool const vectorize_conversions) { + scalar_t const* __restrict__ input, const int hidden_size) { int const tid = threadIdx.x; int const token_idx = blockIdx.x; float absmax_val = 0.0f; @@ -158,6 +157,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( __syncthreads(); float const inverted_scale = FP8_E4M3_MAX / block_absmax_val; + bool const vectorize_conversions = hidden_size % 4 == 0; if (vectorize_conversions) { scalar_t const* token_input = &input[token_idx * hidden_size]; c10::Float8_e4m3fn* token_output = &out[token_idx * hidden_size]; @@ -214,12 +214,13 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] torch::Tensor& scales) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 1024)); - bool const vectorize_conversions = - (hidden_size % 4 == 0) && input.is_contiguous() && out.is_contiguous(); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -228,6 +229,6 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] vllm::dynamic_per_token_scaled_fp8_quant_kernel <<>>( out.data_ptr(), scales.data_ptr(), - input.data_ptr(), hidden_size, vectorize_conversions); + input.data_ptr(), hidden_size); }); } From 9dffe30714f08d23feb77c89a558865196c55f92 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 20:29:22 
+0000 Subject: [PATCH 11/14] nits --- tests/kernels/quant_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 90fb6e9b5d297..63b95901fb884 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -15,7 +15,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor, else torch.finfo(quant_dtype) qtype_max = as_float32_tensor(qtype_traits.max) - # For fp8, inorder to match the cuda kernel output, we have to do exactly + # For fp8, in order to match the cuda kernel output, we have to do exactly # the same operations as in the corresponding fp8 kernel to prevent rounding # errors. @@ -44,7 +44,7 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ fp8_max = as_float32_tensor(fp8_traits.max) one = as_float32_tensor(1.0) - # For fp8, inorder to match the cuda kernel output, we have to do exactly + # For fp8, in order to match the cuda kernel output, we have to do exactly # the same operations as in the corresponding fp8 kernel to prevent rounding # errors. From bceaebaf732b92fb313d94c8c8ba532fc7bd116d Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 21:18:41 +0000 Subject: [PATCH 12/14] vectorize absmax calc --- csrc/ops.h | 8 ++-- csrc/quantization/fp8/common.cu | 65 ++++++++++++++++++++++++++------- 2 files changed, 55 insertions(+), 18 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index d13fd19c15221..fd778264923ea 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -123,14 +123,14 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); -void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, - torch::Tensor& scale); +void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& scale); -void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, +void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale); void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, - torch::Tensor& input, + torch::Tensor const& input, torch::Tensor& scale); void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index adab13f3d9b5f..37d6fed10c8f1 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -89,6 +89,35 @@ typedef struct __align__(4) { } float8x4_t; +template +__device__ float thread_max_vec(scalar_t const* __restrict__ input, + int64_t const num_elems, + int const tid, + int const step) { + // Vectorized input/output to better utilize memory bandwidth. 
+ vec4_t const* vectorized_in = + reinterpret_cast const*>(input); + + int const num_vec_elems = num_elems >> 2; + float absmax_val = 0.0f; + +#pragma unroll 4 + for (int i = tid; i < num_vec_elems; i += step) { + vec4_t in_vec = vectorized_in[i]; + absmax_val = max(absmax_val, fabs(in_vec.x)); + absmax_val = max(absmax_val, fabs(in_vec.y)); + absmax_val = max(absmax_val, fabs(in_vec.z)); + absmax_val = max(absmax_val, fabs(in_vec.w)); + } + + // Handle the remaining elements if num_elems is not divisible by 4 + for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) { + absmax_val = max(absmax_val, fabs(input[i])); + } + + return absmax_val; +} + template __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out, scalar_t const* __restrict__ input, @@ -139,13 +168,25 @@ template __global__ void dynamic_per_token_scaled_fp8_quant_kernel( c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale, scalar_t const* __restrict__ input, const int hidden_size) { + int const tid = threadIdx.x; int const token_idx = blockIdx.x; - float absmax_val = 0.0f; - for (int i = tid; i < hidden_size; i += blockDim.x) { - float const x = static_cast(input[token_idx * hidden_size + i]); - absmax_val = max(absmax_val, fabs(x)); + scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size]; + c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size]; + + // For vectorization, token_input and token_output pointers need to be + // aligned at 8-byte and 4-byte addresses respectively. + bool const can_vectorize = hidden_size % 4 == 0; + + float absmax_val = 0.0f; + if (can_vectorize) { + absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x); + } else { + for (int i = tid; i < hidden_size; i += blockDim.x) { + float const x = static_cast(token_input[i]); + absmax_val = max(absmax_val, fabs(x)); + } } float const block_absmax_val_maybe = blockReduceMax(absmax_val); @@ -157,16 +198,12 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( __syncthreads(); float const inverted_scale = FP8_E4M3_MAX / block_absmax_val; - bool const vectorize_conversions = hidden_size % 4 == 0; - if (vectorize_conversions) { - scalar_t const* token_input = &input[token_idx * hidden_size]; - c10::Float8_e4m3fn* token_output = &out[token_idx * hidden_size]; + if (can_vectorize) { scaled_fp8_conversion_vec(token_output, token_input, inverted_scale, hidden_size, tid, blockDim.x); } else { for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = scaled_fp8_conversion( - input[token_idx * hidden_size + i], inverted_scale); + token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale); } } } @@ -174,8 +211,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( } // namespace vllm void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] + torch::Tensor const& input, // [..., d] + torch::Tensor const& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -192,7 +229,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] } void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] + torch::Tensor const& input, // [..., d] torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); @@ -212,7 +249,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] } void 
dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] + torch::Tensor const& input, // [..., d] torch::Tensor& scales) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); From a178ce18ffe967fea5af600abdd137862dae5be9 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 21:19:14 +0000 Subject: [PATCH 13/14] format --- csrc/quantization/fp8/common.cu | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 37d6fed10c8f1..0938c0707679f 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -91,9 +91,8 @@ float8x4_t; template __device__ float thread_max_vec(scalar_t const* __restrict__ input, - int64_t const num_elems, - int const tid, - int const step) { + int64_t const num_elems, int const tid, + int const step) { // Vectorized input/output to better utilize memory bandwidth. vec4_t const* vectorized_in = reinterpret_cast const*>(input); @@ -168,7 +167,6 @@ template __global__ void dynamic_per_token_scaled_fp8_quant_kernel( c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale, scalar_t const* __restrict__ input, const int hidden_size) { - int const tid = threadIdx.x; int const token_idx = blockIdx.x; @@ -210,7 +208,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( } // namespace vllm -void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] +void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor const& scale) // [1] { @@ -228,9 +226,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] }); } -void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] +void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] - torch::Tensor& scale) // [1] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -248,7 +246,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] }); } -void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] +void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor& scales) { TORCH_CHECK(input.is_contiguous()); From fb111f99ee6e305268dce5fe7900308b9a27de55 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 22:04:58 +0000 Subject: [PATCH 14/14] ws changes to trigger ci tests --- tests/kernels/quant_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 63b95901fb884..a1513bdffe768 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -16,8 +16,8 @@ def ref_dynamic_per_token_quant(x: torch.tensor, qtype_max = as_float32_tensor(qtype_traits.max) # For fp8, in order to match the cuda kernel output, we have to do exactly - # the same operations as in the corresponding fp8 kernel to prevent rounding - # errors. + # the same operations as in the corresponding fp8 kernel to prevent + # rounding errors. 
# Compute scales x_token_max, _ = x.abs().max(dim=-1) @@ -45,8 +45,8 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ one = as_float32_tensor(1.0) # For fp8, in order to match the cuda kernel output, we have to do exactly - # the same operations as in the corresponding fp8 kernel to prevent rounding - # errors. + # the same operations as in the corresponding fp8 kernel to prevent + # rounding errors. x_max = as_float32_tensor(x.abs().max()) ref_scale = x_max / fp8_max
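
[Editor's note — not part of the patch series] A minimal PyTorch sketch of what the per-token dynamic FP8 path above is expected to compute, assuming the series is applied so that vllm._custom_ops.dynamic_per_token_scaled_fp8_quant(input) exists with the (output, scales) return shown in the _custom_ops.py hunk. Shapes, dtypes, and tolerances are illustrative; the helper ref_per_token_fp8_quant is hypothetical and only mirrors the kernel's absmax/scale math (scale = row absmax / FP8_E4M3_MAX, quantize with the inverted scale, clamp into the fp8 range). A CUDA device is assumed.

    from typing import Tuple

    import torch

    import vllm._custom_ops as ops


    def ref_per_token_fp8_quant(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Per-token absmax in float32, matching how the kernel promotes inputs
        # before reducing. Assumes no all-zero rows (true for the random input below).
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        x_f32 = x.to(torch.float32)
        absmax = x_f32.abs().max(dim=-1, keepdim=True).values
        # scale = absmax / FP8_E4M3_MAX; quantize with the inverted scale.
        scales = absmax / fp8_max
        out = (x_f32 * (fp8_max / absmax)).clamp(-fp8_max, fp8_max)
        return out.to(torch.float8_e4m3fn), scales


    x = torch.rand(83, 5120, dtype=torch.half, device="cuda") * 1000

    ref_out, ref_scales = ref_per_token_fp8_quant(x)
    ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)

    # Scales are computed per row, so both sides have shape (num_tokens, 1),
    # the same shape the _custom_ops.py wrapper allocates.
    torch.testing.assert_close(ops_scales, ref_scales)
    torch.testing.assert_close(ops_out.to(torch.float32), ref_out.to(torch.float32),
                               atol=1e-1, rtol=1e-1)

Keeping the reference in float32 end to end follows the intent of the quant_utils.py comments above (do the same operations as the kernel to avoid rounding drift), which is why the output comparison can use a fairly tight tolerance.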
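
[Editor's note — not part of the patch series] For the per-tensor variant that ref_dynamic_per_tensor_fp8_quant covers in the quant_utils.py hunks, the reference collapses to a single global scale. The check below is a sketch under the same assumptions as the previous one (patched tree, CUDA device); ops.scaled_fp8_quant is the wrapper whose hunk appears earlier in the series, and since the exact shape of the scale it returns is not pinned down here, both sides are reshaped before comparing.

    import torch

    import vllm._custom_ops as ops

    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x = torch.rand(7, 768, dtype=torch.bfloat16, device="cuda") * 1000

    # One scale for the whole tensor: absmax over every element.
    x_f32 = x.to(torch.float32)
    ref_scale = (x_f32.abs().max() / fp8_max).reshape(1)
    ref_out = (x_f32 / ref_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

    ops_out, ops_scale = ops.scaled_fp8_quant(x)

    torch.testing.assert_close(ops_scale.reshape(1), ref_scale)
    torch.testing.assert_close(ops_out.to(torch.float32), ref_out.to(torch.float32),
                               atol=1e-1, rtol=1e-1)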