forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ Kernel ] FP8 Dynamic-Per-Token Quant Kernel (vllm-project#6511)
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
- Loading branch information
1 parent
81614e7
commit fc7b66f
Showing
7 changed files
with
271 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from typing import Tuple, Union | ||
|
||
import torch | ||
|
||
|
||
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: | ||
return torch.as_tensor(x, dtype=torch.float32, device='cuda') | ||
|
||
def ref_dynamic_per_token_quant(x: torch.tensor, | ||
quant_dtype: torch.dtype) \ | ||
-> Tuple[torch.tensor, torch.tensor]: | ||
|
||
assert quant_dtype in [torch.int8, torch.float8_e4m3fn] | ||
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ | ||
else torch.finfo(quant_dtype) | ||
qtype_max = as_float32_tensor(qtype_traits.max) | ||
|
||
# For fp8, in order to match the cuda kernel output, we have to do exactly | ||
# the same operations as in the corresponding fp8 kernel to prevent | ||
# rounding errors. | ||
|
||
# Compute scales | ||
x_token_max, _ = x.abs().max(dim=-1) | ||
x_token_max = as_float32_tensor(x_token_max) | ||
scales = (x_token_max / qtype_max)[:, None] | ||
|
||
# Quant | ||
iscales = (qtype_max / x_token_max)[:, None] | ||
torch_out = as_float32_tensor(x) * iscales | ||
torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out | ||
torch_out = torch_out.clamp(qtype_traits.min, | ||
qtype_traits.max).to(quant_dtype) | ||
|
||
return torch_out, scales | ||
|
||
|
||
# The int8 version is very similar. Incorporate the int8 version, like in | ||
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant | ||
# kernel | ||
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ | ||
-> Tuple[torch.tensor, torch.tensor]: | ||
|
||
fp8_traits = torch.finfo(torch.float8_e4m3fn) | ||
fp8_max = as_float32_tensor(fp8_traits.max) | ||
one = as_float32_tensor(1.0) | ||
|
||
# For fp8, in order to match the cuda kernel output, we have to do exactly | ||
# the same operations as in the corresponding fp8 kernel to prevent | ||
# rounding errors. | ||
|
||
x_max = as_float32_tensor(x.abs().max()) | ||
ref_scale = x_max / fp8_max | ||
ref_iscale = one / ref_scale | ||
ref_out = (as_float32_tensor(x) * ref_iscale).clamp( | ||
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) | ||
return ref_out, ref_scale |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import pytest | ||
import torch | ||
|
||
import vllm._custom_ops as ops | ||
from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant, | ||
ref_dynamic_per_token_quant) | ||
|
||
DTYPES = [torch.half, torch.bfloat16, torch.float] | ||
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, | ||
8193] # Arbitrary values for testing | ||
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases | ||
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing | ||
SEEDS = [0] | ||
|
||
|
||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) | ||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) | ||
@pytest.mark.parametrize("dtype", DTYPES) | ||
@pytest.mark.parametrize("seed", SEEDS) | ||
@torch.inference_mode() | ||
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, | ||
dtype: torch.dtype, seed: int) -> None: | ||
torch.random.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
|
||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, | ||
device="cuda") + 1e-6 # avoid nans | ||
|
||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) | ||
ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) | ||
|
||
assert torch.allclose(ref_scales, ops_scales) | ||
assert torch.allclose(ref_out.to(dtype=torch.float32), | ||
ops_out.to(dtype=torch.float32)) | ||
|
||
|
||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) | ||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) | ||
@pytest.mark.parametrize("dtype", DTYPES) | ||
@pytest.mark.parametrize("seed", SEEDS) | ||
@torch.inference_mode() | ||
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, | ||
dtype: torch.dtype, seed: int) -> None: | ||
torch.random.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
|
||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") | ||
|
||
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x) | ||
ops_out, ops_scale = ops.scaled_fp8_quant(x) | ||
|
||
assert torch.allclose(ref_scale, ops_scale) | ||
assert torch.allclose(ref_out.to(dtype=torch.float32), | ||
ops_out.to(dtype=torch.float32)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.