Skip to content

Commit

Permalink
[ Kernel ] Enable Dynamic Per Token fp8 (vllm-project#6547)
Browse files Browse the repository at this point in the history
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
robertgshaw2-redhat authored and Alvant committed Oct 26, 2024
1 parent 750f97c commit 93dd982
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.769
- name: "exact_match,flexible-extract"
value: 0.769
limit: 1000
num_fewshot: 5
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-small.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
3 changes: 2 additions & 1 deletion tests/kernels/test_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
device="cuda") + 1e-6 # avoid nans

ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
ops_out, ops_scales = ops.scaled_fp8_quant(x,
use_per_token_if_dynamic=True)

assert torch.allclose(ref_scales, ops_scales)
assert torch.allclose(ref_out.to(dtype=torch.float32),
Expand Down
26 changes: 13 additions & 13 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
batch_dim_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
Expand All @@ -315,6 +316,8 @@ def scaled_fp8_quant(
scale: Optional scaling factor for the FP8 quantization
batch_dim_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
Expand All @@ -328,22 +331,19 @@ def scaled_fp8_quant(
else:
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
if use_per_token_if_dynamic:
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale


def dynamic_per_token_scaled_fp8_quant(
input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
return output, scales
return output, scale


# int8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,5 @@ def apply_weights(self,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ def apply(self,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=False)


class Fp8MoEMethod(FusedMoEMethodBase):
Expand Down
60 changes: 37 additions & 23 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,41 +107,55 @@ def apply_fp8_linear(
input_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cutlass_fp8_supported: bool = True,
use_per_token_if_dynamic: bool = False,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.

# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic)

# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)

return ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)

# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else:
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
qinput, x_scale = ops.scaled_fp8_quant(input,
input_scale,
batch_dim_padding=17)
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_scale,
batch_dim_padding=17,
use_per_token_if_dynamic=use_per_token_if_dynamic)

per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)

if weight_scale.numel() == 1:
if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
return torch.narrow(output, 0, 0, input.shape[0])

else:
# Fallback for channelwise case, where the weight scales are
# applied separately.
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm

# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
Expand All @@ -155,21 +169,21 @@ def apply_fp8_linear(
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.

# This computes C = sx * (X * W).
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=torch.float32,
scale_a=x_scale)
out_dtype=torch.float32)
# Unpad (undo batch_dim_padding)
output = torch.narrow(output, 0, 0, input.shape[0])

# C = sw * sx * (X * W)
output = output * weight_scale.t()
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
if bias is not None:
# C = sw * sx * (X * W) + bias
output = output + bias
output = output.to(dtype=input.dtype)

return torch.narrow(output, 0, 0, input.shape[0])
return output.to(dtype=input.dtype)


def apply_int8_linear(
Expand Down

0 comments on commit 93dd982

Please sign in to comment.