diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
new file mode 100644
index 0000000000000..39b6f20805bdc
--- /dev/null
+++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.769
+  - name: "exact_match,flexible-extract"
+    value: 0.769
+limit: 1000
+num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt
index 869fc9cef3778..1d1b0ed38671d 100644
--- a/.buildkite/lm-eval-harness/configs/models-small.txt
+++ b/.buildkite/lm-eval-harness/configs/models-small.txt
@@ -3,4 +3,5 @@ Meta-Llama-3-8B-Instruct-FP8.yaml
 Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py
index 6b555c8e242ad..9077976f44bc9 100644
--- a/tests/kernels/test_fp8_quant.py
+++ b/tests/kernels/test_fp8_quant.py
@@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                    device="cuda") + 1e-6  # avoid nans
 
     ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
-    ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
+    ops_out, ops_scales = ops.scaled_fp8_quant(x,
+                                               use_per_token_if_dynamic=True)
 
     assert torch.allclose(ref_scales, ops_scales)
     assert torch.allclose(ref_out.to(dtype=torch.float32),
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 07646ae582a28..666db29568952 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -300,6 +300,7 @@ def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     batch_dim_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -315,6 +316,8 @@ def scaled_fp8_quant(
         scale: Optional scaling factor for the FP8 quantization
         batch_dim_padding: If specified, pad the first dimension
             of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per_tensor or per_token
+            in the dynamic quantization case.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
@@ -328,22 +331,19 @@
     else:
         output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
-        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+        if use_per_token_if_dynamic:
+            scale = torch.empty((input.numel() // input.shape[-1], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                output, input, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
-    return output, scale
-
 
-def dynamic_per_token_scaled_fp8_quant(
-        input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
-    scales = torch.empty((input.numel() // input.shape[-1], 1),
-                         device=input.device,
-                         dtype=torch.float32)
-    torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
-    return output, scales
+    return output, scale
 
 
 # int8
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 686aff4917d21..51156a3bc07af 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -103,4 +103,5 @@ def apply_weights(self,
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=True)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index cfef914ed6cf7..820c066aad28a 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -214,7 +214,8 @@ def apply(self,
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False)
 
 
 class Fp8MoEMethod(FusedMoEMethodBase):
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index bee2dc659b0f3..0729a2d7f8ddd 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -107,31 +107,43 @@ def apply_fp8_linear(
     input_scale: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
     cutlass_fp8_supported: bool = True,
+    use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
     #   If dynamic, layer.input_scale is None and x_scale computed from x.
     #   If static, layer.input_scale is scalar and x_scale is input_scale.
 
+    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
-        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)
 
         # Fused GEMM_DQ
-        output = ops.cutlass_scaled_mm(qinput,
-                                       weight,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale,
-                                       bias=bias)
-
+        return ops.cutlass_scaled_mm(qinput,
+                                     weight,
+                                     out_dtype=input.dtype,
+                                     scale_a=x_scale,
+                                     scale_b=weight_scale,
+                                     bias=bias)
+
+    # torch.scaled_mm supports per tensor weights + activations only
+    # so fallback to naive if per channel or per token
     else:
         # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16.
         # This could change in the future.
-        qinput, x_scale = ops.scaled_fp8_quant(input,
-                                               input_scale,
-                                               batch_dim_padding=17)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            batch_dim_padding=17,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)
+
+        per_tensor_weights = (weight_scale.numel() == 1)
+        per_tensor_activations = (x_scale.numel() == 1)
 
-        if weight_scale.numel() == 1:
+        if per_tensor_weights and per_tensor_activations:
             # Fused GEMM_DQ
             output, _ = torch._scaled_mm(qinput,
                                          weight,
@@ -139,9 +151,11 @@ def apply_fp8_linear(
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          bias=bias)
+            return torch.narrow(output, 0, 0, input.shape[0])
+
         else:
-            # Fallback for channelwise case, where the weight scales are
-            # applied separately.
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
 
             # Symmetric quantized GEMM by definition computes the following:
             #   C = (s_x * X) (s_w * W) + bias
@@ -155,21 +169,21 @@ def apply_fp8_linear(
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.
 
-            # This computes C = sx * (X * W).
+            # GEMM
+            # This computes C = (X * W).
             #   Output in fp32 to allow subsequent ops to happen in-place
             output, _ = torch._scaled_mm(qinput,
                                          weight,
-                                         out_dtype=torch.float32,
-                                         scale_a=x_scale)
+                                         out_dtype=torch.float32)
+            # Unpad (undo batch_dim_padding)
+            output = torch.narrow(output, 0, 0, input.shape[0])
 
-            # C = sw * sx * (X * W)
-            output = output * weight_scale.t()
+            # DQ
+            # C = sw * sx * (X * W) + bias
+            output = output * x_scale * weight_scale.t()
             if bias is not None:
-                # C = sw * sx * (X * W) + bias
                 output = output + bias
-            output = output.to(dtype=input.dtype)
-
-            return torch.narrow(output, 0, 0, input.shape[0])
+            return output.to(dtype=input.dtype)
 
 
 def apply_int8_linear(
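Below is a minimal, self-contained sketch (plain PyTorch, not vLLM's custom kernels) of the two ideas this diff wires together: computing one dynamic FP8 scale per token, and the unfused GEMM + dequant fallback used when torch._scaled_mm cannot apply a vector weight scale. The helper names (per_token_fp8_quant, unfused_fp8_linear) and the eps clamp are illustrative assumptions, not part of the change; the scale shape ([num_tokens, 1]) and the dequant expression (output * x_scale * weight_scale.t()) mirror the code added above.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def per_token_fp8_quant(x: torch.Tensor):
    # Dynamic per-token quantization: one scale per row,
    # s_x[i] = max(|x[i, :]|) / FP8_MAX, so each token uses the full FP8 range.
    row_max = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    scale = row_max.clamp(min=1e-12) / FP8_MAX  # eps clamp (assumption) avoids div-by-zero
    q = (x.to(torch.float32) / scale).clamp(-FP8_MAX, FP8_MAX)
    return q.to(torch.float8_e4m3fn), scale  # scale shape: [num_tokens, 1]


def unfused_fp8_linear(q_x, x_scale, q_w, w_scale, bias=None):
    # Unfused GEMM + DQ for per-token activations and per-channel weights:
    #   C = s_x * s_w^T * (X_q @ W_q^T) + bias
    # q_w: [out_features, in_features], w_scale: [out_features, 1]
    out = q_x.to(torch.float32) @ q_w.to(torch.float32).t()
    out = out * x_scale * w_scale.t()  # row scales broadcast down, channel scales across
    if bias is not None:
        out = out + bias
    return out

The CUTLASS path in the diff fuses this dequantization into the GEMM epilogue; the torch._scaled_mm fallback does not support s_w being a vector, which is why the patch splits that case into a plain fp32 GEMM followed by the elementwise scaling shown here.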