[ Kernel ] Enable Dynamic Per Token fp8 #6547

Merged

Changes from 14 commits
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.753
- name: "exact_match,flexible-extract"
value: 0.756
limit: 1000
num_fewshot: 5
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-small.txt
@@ -3,4 +3,5 @@ Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
3 changes: 2 additions & 1 deletion tests/kernels/test_fp8_quant.py
@@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
device="cuda") + 1e-6 # avoid nans

ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
ops_out, ops_scales = ops.scaled_fp8_quant(x,
use_per_token_if_dynamic=True)

assert torch.allclose(ref_scales, ops_scales)
assert torch.allclose(ref_out.to(dtype=torch.float32),
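For context, the reference the test compares against computes one FP8 scale per token (row). A minimal pure-PyTorch sketch of that behavior, assuming the usual max-abs / fp8-max scaling (an illustration only, not the test's actual ref_dynamic_per_token_quant helper):

import torch

def per_token_fp8_quant_ref(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    # One scale per token: scale = max_abs(row) / fp8_max
    row_max = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    scale = row_max.clamp(min=1e-12) / finfo.max
    # Scale into the representable range and cast to FP8
    q = (x.to(torch.float32) / scale).clamp(finfo.min, finfo.max).to(dtype)
    return q, scale
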
26 changes: 13 additions & 13 deletions vllm/_custom_ops.py
@@ -300,6 +300,7 @@ def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
batch_dim_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -315,6 +316,8 @@ def scaled_fp8_quant(
scale: Optional scaling factor for the FP8 quantization
batch_dim_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.

Returns:
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
@@ -328,22 +331,19 @@ def scaled_fp8_quant(
else:
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
if use_per_token_if_dynamic:
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale


def dynamic_per_token_scaled_fp8_quant(
input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
return output, scales
return output, scale


# int8
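With the standalone dynamic_per_token_scaled_fp8_quant wrapper removed, scaled_fp8_quant becomes the single entry point for all three quantization modes. A usage sketch, assuming a CUDA device with vLLM's FP8 kernels built (shapes are illustrative):

import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Static: the caller supplies the scale; it is returned unchanged.
static_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
q_static, s_static = ops.scaled_fp8_quant(x, scale=static_scale)

# Dynamic per-tensor (default): one scale for the whole tensor, shape [1].
q_tensor, s_tensor = ops.scaled_fp8_quant(x)

# Dynamic per-token: one scale per row, shape [num_tokens, 1].
q_token, s_token = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)
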
@@ -103,4 +103,5 @@ def apply_weights(self,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -214,7 +214,8 @@ def apply(self,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=False)


class Fp8MoEMethod(FusedMoEMethodBase):
60 changes: 37 additions & 23 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -107,41 +107,55 @@ def apply_fp8_linear(
input_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cutlass_fp8_supported: bool = True,
use_per_token_if_dynamic: bool = False,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.

# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic)

# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)

return ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)

# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else:
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
qinput, x_scale = ops.scaled_fp8_quant(input,
input_scale,
batch_dim_padding=17)
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_scale,
batch_dim_padding=17,
use_per_token_if_dynamic=use_per_token_if_dynamic)

per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)

if weight_scale.numel() == 1:
if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
return torch.narrow(output, 0, 0, input.shape[0])

else:
# Fallback for channelwise case, where the weight scales are
# applied separately.
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm

# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
@@ -155,21 +169,21 @@ def apply_fp8_linear(
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.

# This computes C = sx * (X * W).
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=torch.float32,
scale_a=x_scale)
out_dtype=torch.float32)
# Unpad (undo batch_dim_padding)
output = torch.narrow(output, 0, 0, input.shape[0])

# C = sw * sx * (X * W)
output = output * weight_scale.t()
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
if bias is not None:
# C = sw * sx * (X * W) + bias
output = output + bias
output = output.to(dtype=input.dtype)

return torch.narrow(output, 0, 0, input.shape[0])
return output.to(dtype=input.dtype)


def apply_int8_linear(
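The channelwise fallback above defers both scales to an unfused dequant after the GEMM. A plain-PyTorch sketch of the math it computes (equivalent up to FP8 rounding; names and shapes are assumptions: q_x is an [M, K] FP8 activation with per-token scale s_x of shape [M, 1], q_w is a [K, N] FP8 weight with per-channel scale s_w of shape [N]):

import torch

def channelwise_fp8_linear_ref(q_x, q_w, s_x, s_w, bias=None,
                               out_dtype=torch.float16):
    # GEMM with scales deliberately not applied yet: C = X @ W, in fp32
    c = q_x.to(torch.float32) @ q_w.to(torch.float32)
    # Unfused dequant: C = s_x * (X @ W) * s_w, broadcasting the per-token
    # scale over rows and the per-channel scale over columns
    c = c * s_x * s_w.view(1, -1)
    if bias is not None:
        c = c + bias
    return c.to(out_dtype)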