From bb39bc411c0d5c420f78c1525cd829295347d6a7 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 18 Jul 2024 19:52:22 -0400 Subject: [PATCH] [Kernel] Implement fallback for FP8 channelwise using torch._scaled_mm (#6552) --- .../schemes/compressed_tensors_w8a8_fp8.py | 11 ---- .../layers/quantization/utils/w8a8_utils.py | 50 +++++++++++++++---- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index e842475e4f34b..686aff4917d21 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -23,16 +23,6 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() - # On Lovelace, fail for now if channelwise. - # TODO: (@tms) fallback - if (not self.cutlass_fp8_supported - and self.strategy == QuantizationStrategy.CHANNEL): - raise ValueError( - "Channelwise fp8 quantization requires vLLM's custom " - "cutlass kernels, which are not supported on your device." - "Consider quantizing with per tensor scales or upgrading " - "to Hopper.") - def get_min_capability(self) -> int: # lovelace and up return 89 @@ -53,7 +43,6 @@ def process_weights_after_loading(self, layer) -> None: # If channelwise, scales are already lined up, so just transpose. elif self.strategy == QuantizationStrategy.CHANNEL: - assert self.cutlass_fp8_supported weight = layer.weight layer.weight = Parameter(weight.t(), requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index f290a6830c91b..bee2dc659b0f3 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -124,20 +124,50 @@ def apply_fp8_linear( bias=bias) else: + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, batch_dim_padding=17) - # Fused GEMM_DQ -- note we padded the input above because - # torch._scaled_mm is more performant for matrices with - # batch dimension > 16. Note that this could change - # in the future. - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) + if weight_scale.numel() == 1: + # Fused GEMM_DQ + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + else: + # Fallback for channelwise case, where the weight scales are + # applied separately. + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # This computes C = sx * (X * W). + # Output in fp32 to allow subsequent ops to happen in-place + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=torch.float32, + scale_a=x_scale) + + # C = sw * sx * (X * W) + output = output * weight_scale.t() + if bias is not None: + # C = sw * sx * (X * W) + bias + output = output + bias + output = output.to(dtype=input.dtype) return torch.narrow(output, 0, 0, input.shape[0])