[ Kernel ] Enable Dynamic Per Token fp8 #6547

Merged
3 changes: 2 additions & 1 deletion tests/kernels/test_fp8_quant.py
@@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                    device="cuda") + 1e-6  # avoid nans
 
     ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
-    ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
+    ops_out, ops_scales = ops.scaled_fp8_quant(x,
+                                               use_per_token_if_dynamic=True)
 
     assert torch.allclose(ref_scales, ops_scales)
     assert torch.allclose(ref_out.to(dtype=torch.float32),
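For context, a minimal sketch of what a per-token FP8 reference quantizer computes. The test imports `ref_dynamic_per_token_quant` from the test utilities; the helper below is an illustrative stand-in under assumed names, not that implementation.

```python
import torch


def ref_per_token_fp8_quant(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Illustrative stand-in: one scale per token (row), chosen so the row's
    # largest magnitude maps onto the FP8 representable maximum.
    finfo = torch.finfo(dtype)
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    # Guard against all-zero rows (the test avoids this with its "+ 1e-6").
    scales = amax.clamp(min=1e-12) / finfo.max
    out = (x.to(torch.float32) / scales).clamp(finfo.min, finfo.max).to(dtype)
    return out, scales
```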
26 changes: 13 additions & 13 deletions vllm/_custom_ops.py
@@ -300,6 +300,7 @@ def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     batch_dim_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -315,6 +316,8 @@ def scaled_fp8_quant(
         scale: Optional scaling factor for the FP8 quantization
         batch_dim_padding: If specified, pad the first dimension
             of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per_tensor or per_token
+            in the dynamic quantization case.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
@@ -328,22 +331,19 @@ def scaled_fp8_quant(
     else:
         output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
-        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+        if use_per_token_if_dynamic:
+            scale = torch.empty((input.numel() // input.shape[-1], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                output, input, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
-    return output, scale
-
-
-def dynamic_per_token_scaled_fp8_quant(
-        input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
-    scales = torch.empty((input.numel() // input.shape[-1], 1),
-                         device=input.device,
-                         dtype=torch.float32)
-    torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
-    return output, scales
+    return output, scale
 
 
 # int8
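A usage sketch of the updated entry point, assuming a CUDA build of vLLM that exposes `vllm._custom_ops`; only the argument names shown in this PR's diff are relied on. The key difference between the modes is the shape of the returned scale.

```python
import torch
from vllm import _custom_ops as ops  # assumes a CUDA build of vLLM

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-tensor (default): the kernel computes a single scale, shape (1,).
q_pt, scale_pt = ops.scaled_fp8_quant(x)

# Dynamic per-token: one scale per row of the flattened input, shape (16, 1).
q_tok, scale_tok = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static: the caller supplies the scale; use_per_token_if_dynamic has no effect.
static_scale = torch.tensor([0.5], device="cuda", dtype=torch.float32)
q_st, _ = ops.scaled_fp8_quant(x, scale=static_scale)
```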
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -113,7 +113,9 @@ def apply_fp8_linear(
     # If static, layer.input_scale is scalar and x_scale is input_scale.
 
     if cutlass_fp8_supported:
-        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
+        qinput, x_scale = ops.scaled_fp8_quant(input,
+                                               input_scale,
+                                               use_per_token_if_dynamic=True)
Member:
Shouldn't use_per_token_if_dynamic be set by a scheme or something? I don't see why it should always be true for this case

Collaborator Author:
In what case would we not want to use dynamic per token?

Collaborator:
@mgoin the hypothesis here is that dynamic per token is an overall win over dynamic per tensor when supported. Higher accuracy, but also easier to fuse RMSNorm + Quant, and fewer dependencies so better parallelization for the quantize kernels overall. Downside is more scales.

We'll need to benchmark, but IMO ok for this PR
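
To illustrate the accuracy argument above, a plain-PyTorch sketch (illustrative only, not code from this PR): with a single dynamic per-tensor scale, one outlier token sets the scale for every token, while per-token scales adapt row by row.

```python
import torch

x = torch.randn(8, 4096)
x[0] *= 100.0  # one outlier token inflates the global max
fp8_max = torch.finfo(torch.float8_e4m3fn).max

per_tensor_scale = x.abs().max() / fp8_max                      # scalar
per_token_scale = x.abs().amax(dim=-1, keepdim=True) / fp8_max  # shape (8, 1)

# With the single per-tensor scale, the seven "normal" rows are quantized with
# a scale roughly 100x larger than they need, wasting most of the FP8 range;
# the per-token scales stay matched to each row.
print(per_tensor_scale.item(), per_token_scale.squeeze())
```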

Member:
I assume there is a decent performance hit compared to producing/using a per-tensor scale; is this not the case?

Collaborator:
I wouldn't assume a decent performance hit -- the overheads from the CUTLASS epilogues are quite small (like 3%), and there are advantages to doing per-token when quantizing as well. We'll measure.

Collaborator:
I changed my mind -- as I am implementing the per-token/per-channel wrapper for torch._scaled_mm, I think it would be nicer to grab this from a config somewhere

Collaborator Author:
I tried to pass this in from the scheme. However, that became very awkward because I needed to check whether CUTLASS is supported in multiple places. I think deciding it here is the right place.


         # Fused GEMM_DQ
         output = ops.cutlass_scaled_mm(qinput,
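
A hypothetical sketch of the alternative raised in the review thread above: letting a quantization scheme or config choose the dynamic-activation granularity instead of hard-coding `use_per_token_if_dynamic=True` at this call site. The `Fp8ActivationScheme` class and `quantize_input` helper below are illustrative names only; they do not exist in vLLM or in this PR.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from vllm import _custom_ops as ops  # assumes a CUDA build of vLLM


@dataclass
class Fp8ActivationScheme:
    # Hypothetical config knob: granularity used when the activation scale
    # is computed dynamically (per-token vs. per-tensor).
    dynamic_per_token: bool = True


def quantize_input(x: torch.Tensor,
                   input_scale: Optional[torch.Tensor],
                   scheme: Fp8ActivationScheme):
    # Only the quantization call changes; the fused GEMM
    # (ops.cutlass_scaled_mm) would follow exactly as in the diff above.
    return ops.scaled_fp8_quant(
        x, input_scale,
        use_per_token_if_dynamic=scheme.dynamic_per_token)
```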