[Kernel] Implement fallback for FP8 channelwise using torch._scaled_mm #6552

Merged
@@ -23,16 +23,6 @@ def __init__(self, strategy: str, is_static_input_scheme: bool):
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()

# On Lovelace, fail for now if channelwise.
# TODO: (@tms) fallback
if (not self.cutlass_fp8_supported
and self.strategy == QuantizationStrategy.CHANNEL):
raise ValueError(
"Channelwise fp8 quantization requires vLLM's custom "
"cutlass kernels, which are not supported on your device."
"Consider quantizing with per tensor scales or upgrading "
"to Hopper.")

def get_min_capability(self) -> int:
# lovelace and up
return 89
@@ -53,7 +43,6 @@ def process_weights_after_loading(self, layer) -> None:

# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
assert self.cutlass_fp8_supported
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
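A minimal sketch of what the channelwise branch above ends up doing at load time (the Layer class, shapes, and tensor names here are placeholders, not the real vLLM objects): the per-output-channel scales are left untouched and the (N, K) weight is simply re-stored transposed, so the GEMM in apply_fp8_linear receives a (K, N) operand whose columns line up with those scales.

import torch
from torch.nn import Parameter

N, K = 16, 8

class Layer:  # stand-in for the real quantized linear layer
    pass

layer = Layer()
layer.weight = torch.randn(N, K)        # (out_features, in_features)
layer.weight_scale = torch.rand(N, 1)   # one scale per output channel

# Channelwise case: scales are already lined up, so just transpose the weight.
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
assert layer.weight.shape == (K, N)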

vllm/model_executor/layers/quantization/utils/w8a8_utils.py (40 additions, 10 deletions)
@@ -124,20 +124,50 @@ def apply_fp8_linear(
bias=bias)

else:
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
qinput, x_scale = ops.scaled_fp8_quant(input,
input_scale,
batch_dim_padding=17)

# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
if weight_scale.numel() == 1:
# Fused GEMM_DQ
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
else:
# Fallback for channelwise case, where the weight scales are
# applied separately.

# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.

# This computes C = sx * (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=torch.float32,
scale_a=x_scale)

# C = sw * sx * (X * W)
output = output * weight_scale.t()
if bias is not None:
# C = sw * sx * (X * W) + bias
output = output + bias
output = output.to(dtype=input.dtype)

return torch.narrow(output, 0, 0, input.shape[0])
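To sanity-check the algebra behind the fallback, here is a small, self-contained sketch. It deliberately replaces torch._scaled_mm with a plain fp32 matmul (the private _scaled_mm signature has varied across PyTorch releases), and all tensor names and shapes are illustrative stand-ins rather than the real vLLM objects.

import torch

torch.manual_seed(0)
M, K, N = 4, 8, 16

x_scale = torch.tensor(0.02)             # per-tensor activation scale s_x
weight_scale = torch.rand(N, 1) * 0.01   # per-channel weight scales s_w, shape (N, 1)

qinput = torch.randint(-128, 128, (M, K)).float()      # stand-in for quantized activations
weight = torch.randint(-128, 128, (N, K)).float().t()  # (K, N), as handed to the kernel
bias = torch.rand(N)

# Reference: dequantize both operands first, then GEMM.
# C = (s_x * X) @ (s_w * W) + bias
ref = (x_scale * qinput) @ (weight * weight_scale.t()) + bias

# Fallback path: GEMM with only the activation scale fused (what the
# scale_a-only _scaled_mm call computes, in fp32), then apply the
# per-channel weight scales and the bias afterwards.
out = x_scale * (qinput @ weight)   # s_x * (X @ W)
out = out * weight_scale.t()        # s_w * s_x * (X @ W)
out = out + bias                    # + bias

# If the batch dimension had been padded (batch_dim_padding=17), the extra rows
# would be sliced off here, mirroring the final torch.narrow in the diff.
out = torch.narrow(out, 0, 0, M)

assert torch.allclose(ref, out, rtol=1e-4, atol=1e-4)

As the diff's comment notes, the intermediate is kept in fp32 so the weight-scale multiply and bias add can be applied before the single cast back to the activation dtype.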
