From 2426e29a7abd0214a63a7975624568342a0e7715 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 18 Jul 2024 16:24:49 +0000 Subject: [PATCH 01/40] stash --- vllm/_custom_ops.py | 24 +++++++++---------- .../layers/quantization/utils/w8a8_utils.py | 3 ++- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 07646ae582a28..fabbc517904cc 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -300,6 +300,7 @@ def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, batch_dim_padding: Optional[int] = None, + per_token: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -315,6 +316,7 @@ def scaled_fp8_quant( scale: Optional scaling factor for the FP8 quantization batch_dim_padding: If specified, pad the first dimension of the output to at least this value. + per_token: Whether to do per_tensor or per_token quant. Returns: Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and @@ -328,24 +330,20 @@ def scaled_fp8_quant( else: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + if per_token: + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: torch.ops._C.static_scaled_fp8_quant(output, input, scale) + return output, scale -def dynamic_per_token_scaled_fp8_quant( - input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - - output = torch.empty_like(input, dtype=torch.float8_e4m3fn) - scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales) - return output, scales - - # int8 def scaled_int8_quant( input: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index f290a6830c91b..e2167ffa1d6df 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -113,7 +113,8 @@ def apply_fp8_linear( # If static, layer.input_scale is scalar and x_scale is input_scale. 
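# --- Editor's aside (not part of the patch): a minimal pure-PyTorch sketch of the
# per-token dynamic quantization semantics selected by the new `per_token` flag
# (renamed `use_per_token_if_dynamic` later in this series) -- one fp32 scale per
# row, shaped (num_tokens, 1), instead of a single scalar scale for the whole
# tensor. The helper name below is illustrative, not a vLLM API.
import torch

def ref_per_token_fp8_quant(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Row-wise absmax determines one scale per token.
    scales = x.abs().amax(dim=-1, keepdim=True).float() / finfo.max
    scales = scales.clamp(min=1e-12)  # guard against all-zero rows
    out = (x.float() / scales).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return out, scales
# --- end of editor's aside ---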
if cutlass_fp8_supported: - qinput, x_scale = ops.scaled_fp8_quant(input, input_scale) + qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, + per_token=True) # Fused GEMM_DQ output = ops.cutlass_scaled_mm(qinput, From 7665b7bb80a39ce168d4baa905caed6c391d519f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 18 Jul 2024 16:48:15 +0000 Subject: [PATCH 02/40] format --- vllm/_custom_ops.py | 13 +++++++------ .../layers/quantization/utils/w8a8_utils.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fabbc517904cc..c931003e02435 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -331,16 +331,17 @@ def scaled_fp8_quant( output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: if per_token: - scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales) - else: + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant( + output, input, scale) + else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: torch.ops._C.static_scaled_fp8_quant(output, input, scale) - + return output, scale diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index e2167ffa1d6df..57a8f572d6265 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -113,7 +113,8 @@ def apply_fp8_linear( # If static, layer.input_scale is scalar and x_scale is input_scale. if cutlass_fp8_supported: - qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, + qinput, x_scale = ops.scaled_fp8_quant(input, + input_scale, per_token=True) # Fused GEMM_DQ From ef276134bb5568c91c0a72b80bf19783cdbbc281 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 18 Jul 2024 16:57:56 +0000 Subject: [PATCH 03/40] tweak arg name --- vllm/_custom_ops.py | 7 ++++--- .../model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c931003e02435..666db29568952 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -300,7 +300,7 @@ def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, batch_dim_padding: Optional[int] = None, - per_token: bool = False, + use_per_token_if_dynamic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -316,7 +316,8 @@ def scaled_fp8_quant( scale: Optional scaling factor for the FP8 quantization batch_dim_padding: If specified, pad the first dimension of the output to at least this value. - per_token: Whether to do per_tensor or per_token quant. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. 
Returns: Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and @@ -330,7 +331,7 @@ def scaled_fp8_quant( else: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: - if per_token: + if use_per_token_if_dynamic: scale = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 57a8f572d6265..1cfb632099ae1 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -115,7 +115,7 @@ def apply_fp8_linear( if cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, - per_token=True) + use_per_token_if_dynamic=True) # Fused GEMM_DQ output = ops.cutlass_scaled_mm(qinput, From 2f961575a61d86687ddaa9dd48200799da692314 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 18 Jul 2024 18:10:34 +0000 Subject: [PATCH 04/40] fix test --- tests/kernels/test_fp8_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 6b555c8e242ad..0f6ee1d054c95 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -27,7 +27,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, device="cuda") + 1e-6 # avoid nans ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) - ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) + ops_out, ops_scales = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True) assert torch.allclose(ref_scales, ops_scales) assert torch.allclose(ref_out.to(dtype=torch.float32), From e748554604a8b100d1771fe8b2b5213dd1554107 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 18 Jul 2024 18:15:42 +0000 Subject: [PATCH 05/40] format --- tests/kernels/test_fp8_quant.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 0f6ee1d054c95..9077976f44bc9 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, device="cuda") + 1e-6 # avoid nans ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) - ops_out, ops_scales = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True) + ops_out, ops_scales = ops.scaled_fp8_quant(x, + use_per_token_if_dynamic=True) assert torch.allclose(ref_scales, ops_scales) assert torch.allclose(ref_out.to(dtype=torch.float32), From 3ef571bed6afbd2d5945c0582309091be63beb08 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 02:43:23 +0000 Subject: [PATCH 06/40] working e2e with our cutlass kernels --- vllm/model_executor/layers/quantization/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 40b0df75a69a6..af565b17ce188 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from 
vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) @@ -24,6 +25,7 @@ "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, "fp8": Fp8Config, + "fbgemm_fp8": FBGEMMFp8Config, # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, From ad83666203db70da53e244ac0d581400efa4f760 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 02:46:20 +0000 Subject: [PATCH 07/40] added fp8 gemm --- vllm/model_executor/layers/fused_moe/awq.py | 78 ++++++++++++ .../layers/quantization/fbgemm_fp8.py | 112 ++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/awq.py create mode 100644 vllm/model_executor/layers/quantization/fbgemm_fp8.py diff --git a/vllm/model_executor/layers/fused_moe/awq.py b/vllm/model_executor/layers/fused_moe/awq.py new file mode 100644 index 0000000000000..2ffb1c4ce1dee --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/awq.py @@ -0,0 +1,78 @@ +"""Fused MoE utilities for AWQ.""" +import torch + +from vllm import _custom_ops as ops +from vllm.logger import init_logger + +from .fused_moe import fused_moe, fused_topk, moe_align_block_size + +logger = init_logger(__name__) + + +def fused_moe_awq( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_qzero: torch.Tensor, + w2_qzero: torch.Tensor, + pack_factor: int, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - pack_factor (int): Weight packing factor (int4 in int32 == 8) + - w1_scale (torch.Tensor): scale to be used for w1. + - w2_scale (torch.Tensor): scale to be used for w2. + - w1_qzero (torch.Tensor): zero point to be used for w1. + - w2_qzero (torch.Tensor): zero point to be used for w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + + # If large seq_len prefill, dequantize and use the fp16 MoE kernel. 
+ do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024 + if do_naive_dequant: + dequant_w1 = ops.awq_dequantize(w1, w1_scale, w1_qzero, 0, 0, + 0).permute(0, 2, 1) + dequant_w2 = ops.awq_dequantize(w2, w2_scale, w2_qzero, 0, 0, + 0).permute(0, 2, 1) + + return fused_moe(hidden_states, dequant_w1, dequant_w2, gating_output, + topk, renormalize) + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + (sorted_token_ids, expert_ids, + num_tokens_post_padded) = moe_align_block_size(topk_ids, 16, w1.shape[0]) + + x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:]) + + gate_up = ops.awq_group_gemm(x, w1, w1_scale, w1_qzero, topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, False, pack_factor) + + out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), + dtype=hidden_states.dtype, + device=hidden_states.device) + ops.silu_and_mul(out, gate_up) + + out = ops.awq_group_gemm(out, w2, w2_scale, w2_qzero, topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, True, pack_factor) + + return torch.sum(out, dim=1) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py new file mode 100644 index 0000000000000..d11134d612d99 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -0,0 +1,112 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + create_per_channel_scale_param, apply_fp8_linear) +from vllm.model_executor.utils import set_weight_attrs + + +logger = init_logger(__name__) + +class FBGEMMFp8Config(QuantizationConfig): + """Config class for FBGEMM Fp8.""" + + @classmethod + def get_name(cls) -> str: + return "fbgemm_fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config": + return cls() + + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return FBGEMMFp8LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class FBGEMMFp8LinearMethod(LinearMethodBase): + + def __init__(self, quant_config: FBGEMMFp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + requires_grad=False) + 
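# --- Editor's aside (not part of the patch): how the checkpoint layout created here
# is interpreted. `weight` is [out_features, in_features] in float8_e4m3fn and the
# per-channel `weight_scale` registered just below holds one fp32 scale per output
# row, so dequantization is a row-wise broadcast multiply. Shapes are illustrative.
import torch

out_features, in_features = 8, 16
w_fp8 = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)
w_scale = torch.rand(out_features, 1)          # one scale per output channel
w_dequant = w_fp8.to(torch.float32) * w_scale  # broadcasts across in_features
# --- end of editor's aside ---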
layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + "input_dim": 1, "output_dim": 0, + **extra_weight_attrs, + }) + + # WEIGHT SCALE + weight_scale = create_per_channel_scale_param(output_partition_sizes, + **extra_weight_attrs) + layer.register_parameter("weight_scale", weight_scale) + + # NOT USED FOR INFERNECE + input_scale_ub = torch.nn.Parameter(torch.zeros((1), dtype=torch.float8_e4m3fn)) + layer.register_parameter("input_scale_ub", input_scale_ub) + set_weight_attrs(input_scale_ub, extra_weight_attrs) + + def process_weights_after_loading(self, layer: Module) -> None: + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + bias=bias, + cutlass_fp8_supported=True) + From eb7d48cd56690e653bfbd437c66fbea5e06813e1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 02:47:21 +0000 Subject: [PATCH 08/40] remove --- vllm/model_executor/layers/fused_moe/awq.py | 78 --------------------- 1 file changed, 78 deletions(-) delete mode 100644 vllm/model_executor/layers/fused_moe/awq.py diff --git a/vllm/model_executor/layers/fused_moe/awq.py b/vllm/model_executor/layers/fused_moe/awq.py deleted file mode 100644 index 2ffb1c4ce1dee..0000000000000 --- a/vllm/model_executor/layers/fused_moe/awq.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Fused MoE utilities for AWQ.""" -import torch - -from vllm import _custom_ops as ops -from vllm.logger import init_logger - -from .fused_moe import fused_moe, fused_topk, moe_align_block_size - -logger = init_logger(__name__) - - -def fused_moe_awq( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - w1_qzero: torch.Tensor, - w2_qzero: torch.Tensor, - pack_factor: int, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - pack_factor (int): Weight packing factor (int4 in int32 == 8) - - w1_scale (torch.Tensor): scale to be used for w1. - - w2_scale (torch.Tensor): scale to be used for w2. - - w1_qzero (torch.Tensor): zero point to be used for w1. - - w2_qzero (torch.Tensor): zero point to be used for w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - - # If large seq_len prefill, dequantize and use the fp16 MoE kernel. 
- do_naive_dequant = hidden_states.shape[:-1].numel() >= 1024 - if do_naive_dequant: - dequant_w1 = ops.awq_dequantize(w1, w1_scale, w1_qzero, 0, 0, - 0).permute(0, 2, 1) - dequant_w2 = ops.awq_dequantize(w2, w2_scale, w2_qzero, 0, 0, - 0).permute(0, 2, 1) - - return fused_moe(hidden_states, dequant_w1, dequant_w2, gating_output, - topk, renormalize) - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - (sorted_token_ids, expert_ids, - num_tokens_post_padded) = moe_align_block_size(topk_ids, 16, w1.shape[0]) - - x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:]) - - gate_up = ops.awq_group_gemm(x, w1, w1_scale, w1_qzero, topk_weights, - sorted_token_ids, expert_ids, - num_tokens_post_padded, False, pack_factor) - - out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), - dtype=hidden_states.dtype, - device=hidden_states.device) - ops.silu_and_mul(out, gate_up) - - out = ops.awq_group_gemm(out, w2, w2_scale, w2_qzero, topk_weights, - sorted_token_ids, expert_ids, - num_tokens_post_padded, True, pack_factor) - - return torch.sum(out, dim=1) From 90bd839be799f48c9a321a54a2f53588fabf2ceb Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 02:58:42 +0000 Subject: [PATCH 09/40] format --- .../layers/quantization/__init__.py | 2 +- .../layers/quantization/fbgemm_fp8.py | 29 ++++++++++--------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index af565b17ce188..c1bb45224fcc1 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -10,8 +10,8 @@ CompressedTensorsConfig) from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig) -from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index d11134d612d99..53eae9b31474f 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -9,12 +9,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - create_per_channel_scale_param, apply_fp8_linear) + apply_fp8_linear, create_per_channel_scale_param) from vllm.model_executor.utils import set_weight_attrs - logger = init_logger(__name__) + class FBGEMMFp8Config(QuantizationConfig): """Config class for FBGEMM Fp8.""" @@ -79,7 +79,8 @@ def create_weights( requires_grad=False) layer.register_parameter("weight", weight) set_weight_attrs(weight, { - "input_dim": 1, "output_dim": 0, + "input_dim": 1, + "output_dim": 0, **extra_weight_attrs, }) @@ -89,9 +90,13 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) # NOT USED FOR INFERNECE - input_scale_ub = torch.nn.Parameter(torch.zeros((1), dtype=torch.float8_e4m3fn)) + input_scale_ub = torch.nn.Parameter( + torch.zeros((1), dtype=torch.float8_e4m3fn)) layer.register_parameter("input_scale_ub", 
input_scale_ub) - set_weight_attrs(input_scale_ub, extra_weight_attrs) + set_weight_attrs(input_scale_ub, { + "ignore_warning": True, + **extra_weight_attrs + }) def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight @@ -102,11 +107,9 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return apply_fp8_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - bias=bias, - cutlass_fp8_supported=True) - + return apply_fp8_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + bias=bias, + cutlass_fp8_supported=True) From d064dd71a6263707ec7ba5fdd9b1052acefd2cd1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 16:01:41 +0000 Subject: [PATCH 10/40] stash --- .../layers/quantization/utils/w8a8_utils.py | 67 +++++++++++++++++-- 1 file changed, 61 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 2e9bb2cb0a637..dc2ec30febb6d 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -111,7 +111,9 @@ def apply_fp8_linear( # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. # If static, layer.input_scale is scalar and x_scale is input_scale. - + + cutlass_fp8_supported = False + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, @@ -125,15 +127,22 @@ def apply_fp8_linear( scale_b=weight_scale, bias=bias) + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token else: # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. # This could change in the future. qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, - batch_dim_padding=17) + batch_dim_padding=17, + use_per_token_if_dynamic=False) + + per_tensor_weights = (weight_scale.numel() == 1) + per_tensor_activations = (x_scale.numel() == 1) - if weight_scale.numel() == 1: + # Per tensor for both weights + if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ output, _ = torch._scaled_mm(qinput, weight, @@ -141,9 +150,41 @@ def apply_fp8_linear( scale_a=x_scale, scale_b=weight_scale, bias=bias) - else: - # Fallback for channelwise case, where the weight scales are - # applied separately. + + # Per tensor weights, per t + elif per_tensor_weights: + # Fallback for per token activations case, where the activation + # scales are applied separately. + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # This computes C = sw * (X * W). 
+ # Output in fp32 to allow subsequent ops to happen in-place + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=torch.float32, + scale_b=x_scale) + + # C = sw * sx * (X * W) + output = output * weight_scale.t() + if bias is not None: + # C = sw * sx * (X * W) + bias + output = output + bias + output = output.to(dtype=input.dtype) + + elif per_tensor_activations: + # Fallback for channelwise weights case, where the weight scales + # are applied separately. # Symmetric quantized GEMM by definition computes the following: # C = (s_x * X) (s_w * W) + bias @@ -170,6 +211,20 @@ def apply_fp8_linear( # C = sw * sx * (X * W) + bias output = output + bias output = output.to(dtype=input.dtype) + + else: + # X * W + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=torch.float32) + + # C = sw * sx * (X * W) + output = output * x_scale.t() * weight_scale.t() + if bias is not None: + # C = sw * sx * (X * W) + bias + output = output + bias + output = output.to(dtype=input.dtype) + return torch.narrow(output, 0, 0, input.shape[0]) From 6aa37e54a65a58d9012e0e5cb4db2d3dc46001f9 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 18:06:29 +0000 Subject: [PATCH 11/40] dynamic per token --- .../layers/quantization/utils/w8a8_utils.py | 90 +++++-------------- 1 file changed, 21 insertions(+), 69 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index dc2ec30febb6d..c97f444b67894 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -107,6 +107,7 @@ def apply_fp8_linear( input_scale: torch.Tensor, bias: Optional[torch.Tensor] = None, cutlass_fp8_supported: bool = True, + use_per_token_if_dynamic: bool = False, ) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. @@ -115,17 +116,16 @@ def apply_fp8_linear( cutlass_fp8_supported = False # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: - qinput, x_scale = ops.scaled_fp8_quant(input, - input_scale, - use_per_token_if_dynamic=True) + qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, + use_per_token_if_dynamic=use_per_token_if_dynamic) # Fused GEMM_DQ - output = ops.cutlass_scaled_mm(qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) + return ops.cutlass_scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) # torch.scaled_mm supports per tensor weights + activations only # so fallback to naive if per channel or per token @@ -133,15 +133,12 @@ def apply_fp8_linear( # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. # This could change in the future. 
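# --- Editor's aside (not part of the patch): the pad-then-unpad bookkeeping referred
# to above, shown with a plain matmul since torch._scaled_mm is a private API. Rows
# are padded so the batch dimension is at least 17 before the GEMM, then sliced off.
import torch

x = torch.randn(3, 8)                                   # batch dimension < 16
pad_rows = max(17 - x.shape[0], 0)
x_padded = torch.nn.functional.pad(x, (0, 0, 0, pad_rows))
y_padded = x_padded @ torch.randn(8, 4)
y = torch.narrow(y_padded, 0, 0, x.shape[0])            # drop the padded rows
# --- end of editor's aside ---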
- qinput, x_scale = ops.scaled_fp8_quant(input, - input_scale, - batch_dim_padding=17, - use_per_token_if_dynamic=False) + qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, batch_dim_padding=17, + use_per_token_if_dynamic=use_per_token_if_dynamic) per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) - # Per tensor for both weights if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ output, _ = torch._scaled_mm(qinput, @@ -150,41 +147,11 @@ def apply_fp8_linear( scale_a=x_scale, scale_b=weight_scale, bias=bias) + return torch.narrow(output, 0, 0, input.shape[0]) - # Per tensor weights, per t - elif per_tensor_weights: - # Fallback for per token activations case, where the activation - # scales are applied separately. - - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # This computes C = sw * (X * W). - # Output in fp32 to allow subsequent ops to happen in-place - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=torch.float32, - scale_b=x_scale) - - # C = sw * sx * (X * W) - output = output * weight_scale.t() - if bias is not None: - # C = sw * sx * (X * W) + bias - output = output + bias - output = output.to(dtype=input.dtype) - - elif per_tensor_activations: - # Fallback for channelwise weights case, where the weight scales - # are applied separately. + else: + # Fallback for channelwise case, where we use unfused DQ + # due to limitations with scaled_mm # Symmetric quantized GEMM by definition computes the following: # C = (s_x * X) (s_w * W) + bias @@ -198,35 +165,20 @@ def apply_fp8_linear( # For the scaled_mm fallback case, we break this down, since it # does not support s_w being a vector. - # This computes C = sx * (X * W). + # This computes C = (X * W). 
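# --- Editor's aside (not part of the patch): a tiny numeric check of the algebra the
# fallback relies on -- applying per-token and per-channel scales after an unscaled
# matmul equals dequantizing first. Plain fp32 tensors here, no FP8 involved.
import torch

X, W = torch.randn(4, 8), torch.randn(8, 3)
s_x = torch.rand(4, 1)      # per-token activation scales
s_w = torch.rand(1, 3)      # per-channel weight scales (row vector, i.e. transposed)
scaled_after = (X @ W) * s_x * s_w        # C = s_x * s_w * (X @ W)
dequant_first = (X * s_x) @ (W * s_w)     # C = (s_x * X) @ (s_w * W)
assert torch.allclose(scaled_after, dequant_first, atol=1e-5)
# --- end of editor's aside ---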
# Output in fp32 to allow subsequent ops to happen in-place - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=torch.float32, - scale_a=x_scale) - - # C = sw * sx * (X * W) - output = output * weight_scale.t() - if bias is not None: - # C = sw * sx * (X * W) + bias - output = output + bias - output = output.to(dtype=input.dtype) - - else: - # X * W output, _ = torch._scaled_mm(qinput, weight, out_dtype=torch.float32) + # Unpad (undo batch_dim_padding) + output = torch.narrow(output, 0, 0, input.shape[0]) - # C = sw * sx * (X * W) - output = output * x_scale.t() * weight_scale.t() + # C = sw * sx * (X * W) + bias + output = output * x_scale * weight_scale.t() if bias is not None: # C = sw * sx * (X * W) + bias output = output + bias - output = output.to(dtype=input.dtype) - - - return torch.narrow(output, 0, 0, input.shape[0]) + return output.to(dtype=input.dtype) def apply_int8_linear( From c9d819acdb270af35370036de87e47fe3ce9b117 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 18:06:43 +0000 Subject: [PATCH 12/40] format --- .../layers/quantization/utils/w8a8_utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index c97f444b67894..e67405295b033 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -112,12 +112,14 @@ def apply_fp8_linear( # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. # If static, layer.input_scale is scalar and x_scale is input_scale. - + cutlass_fp8_supported = False # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: - qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, - use_per_token_if_dynamic=use_per_token_if_dynamic) + qinput, x_scale = ops.scaled_fp8_quant( + input, + input_scale, + use_per_token_if_dynamic=use_per_token_if_dynamic) # Fused GEMM_DQ return ops.cutlass_scaled_mm(qinput, @@ -133,8 +135,11 @@ def apply_fp8_linear( # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. # This could change in the future. 
- qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, batch_dim_padding=17, - use_per_token_if_dynamic=use_per_token_if_dynamic) + qinput, x_scale = ops.scaled_fp8_quant( + input, + input_scale, + batch_dim_padding=17, + use_per_token_if_dynamic=use_per_token_if_dynamic) per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) @@ -148,7 +153,7 @@ def apply_fp8_linear( scale_b=weight_scale, bias=bias) return torch.narrow(output, 0, 0, input.shape[0]) - + else: # Fallback for channelwise case, where we use unfused DQ # due to limitations with scaled_mm From 08cbaf72b1b5ce4cd86bcea6934380c30e8ea627 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 18:11:16 +0000 Subject: [PATCH 13/40] reenable cutlass --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index e67405295b033..dd8f9741f563f 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -113,7 +113,6 @@ def apply_fp8_linear( # If dynamic, layer.input_scale is None and x_scale computed from x. # If static, layer.input_scale is scalar and x_scale is input_scale. - cutlass_fp8_supported = False # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant( @@ -170,6 +169,7 @@ def apply_fp8_linear( # For the scaled_mm fallback case, we break this down, since it # does not support s_w being a vector. + # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place output, _ = torch._scaled_mm(qinput, @@ -178,6 +178,7 @@ def apply_fp8_linear( # Unpad (undo batch_dim_padding) output = torch.narrow(output, 0, 0, input.shape[0]) + # DQ # C = sw * sx * (X * W) + bias output = output * x_scale * weight_scale.t() if bias is not None: From f4cdda16f377d8287d99b3e95ff90b568dc739bc Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 18:12:04 +0000 Subject: [PATCH 14/40] cleanup comment --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index dd8f9741f563f..0729a2d7f8ddd 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -182,7 +182,6 @@ def apply_fp8_linear( # C = sw * sx * (X * W) + bias output = output * x_scale * weight_scale.t() if bias is not None: - # C = sw * sx * (X * W) + bias output = output + bias return output.to(dtype=input.dtype) From 2971f4d4897249ab2f537bff48dca4a757c013b5 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 18:26:21 +0000 Subject: [PATCH 15/40] format --- .../compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py | 3 ++- vllm/model_executor/layers/quantization/fp8.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 686aff4917d21..51156a3bc07af 100644 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -103,4 +103,5 @@ def apply_weights(self, weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, - cutlass_fp8_supported=self.cutlass_fp8_supported) + cutlass_fp8_supported=self.cutlass_fp8_supported, + use_per_token_if_dynamic=True) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index cfef914ed6cf7..820c066aad28a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -214,7 +214,8 @@ def apply(self, weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, - cutlass_fp8_supported=self.cutlass_fp8_supported) + cutlass_fp8_supported=self.cutlass_fp8_supported, + use_per_token_if_dynamic=False) class Fp8MoEMethod(FusedMoEMethodBase): From b601033989cc84f616eb38382967a38d98f83ac5 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 18:33:51 +0000 Subject: [PATCH 16/40] added dynamic per token test case --- ...-3-8B-Instruct-Channelwise-compressed-tensors.yaml | 11 +++++++++++ .buildkite/lm-eval-harness/configs/models-small.txt | 1 + 2 files changed, 12 insertions(+) create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml new file mode 100644 index 0000000000000..1ab9954a47071 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1 +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.753 + - name: "exact_match,flexible-extract" + value: 0.756 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 869fc9cef3778..a0ca7caa2bbd3 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -3,4 +3,5 @@ Meta-Llama-3-8B-Instruct-FP8.yaml Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml +Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml From 8b5d638cdb7263cb261fa1b04ef78c3827e34841 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 18:36:51 +0000 Subject: [PATCH 17/40] added use per token --- vllm/model_executor/layers/quantization/fbgemm_fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 53eae9b31474f..7924446279310 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -112,4 +112,5 @@ def apply(self, 
weight_scale=layer.weight_scale, input_scale=None, bias=bias, - cutlass_fp8_supported=True) + cutlass_fp8_supported=True, + use_per_token_if_dynamic=True) From 006ccf0ea309e73ee3b7dc3bbb7017fa2bf7b721 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranth Date: Fri, 19 Jul 2024 20:19:49 +0000 Subject: [PATCH 18/40] format --- csrc/ops.h | 3 ++- csrc/quantization/fp8/common.cu | 27 ++++++++++++++++++++------- csrc/torch_bindings.cpp | 2 +- tests/kernels/quant_utils.py | 17 ++++++++++++++--- tests/kernels/test_fp8_quant.py | 7 ++++--- vllm/_custom_ops.py | 4 ++-- 6 files changed, 43 insertions(+), 17 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index c0f924c09b515..994412d97f7f9 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -136,7 +136,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scale); + torch::Tensor& scale, + c10::optional const& scale_ub); void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 0938c0707679f..d8ba7c1b62a65 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -166,7 +166,11 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, template __global__ void dynamic_per_token_scaled_fp8_quant_kernel( c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale, - scalar_t const* __restrict__ input, const int hidden_size) { + scalar_t const* __restrict__ input, + float const* __restrict__ scale_ub, + const int hidden_size) { + float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); + int const tid = threadIdx.x; int const token_idx = blockIdx.x; @@ -188,14 +192,20 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( } float const block_absmax_val_maybe = blockReduceMax(absmax_val); - __shared__ float block_absmax_val; + __shared__ float token_scale; if (tid == 0) { - block_absmax_val = block_absmax_val_maybe; - scale[token_idx] = block_absmax_val / FP8_E4M3_MAX; + if (scale_ub) { + token_scale = min(block_absmax_val_maybe, *scale_ub); + } else { + token_scale = block_absmax_val_maybe; + } + // token scale computation + token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor); + scale[token_idx] = token_scale; } __syncthreads(); - float const inverted_scale = FP8_E4M3_MAX / block_absmax_val; + float const inverted_scale = 1.0f / token_scale; if (can_vectorize) { scaled_fp8_conversion_vec(token_output, token_input, inverted_scale, hidden_size, tid, blockDim.x); @@ -248,7 +258,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] - torch::Tensor& scales) { + torch::Tensor& scales, + std::optional const& scale_ub) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); @@ -264,6 +275,8 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] vllm::dynamic_per_token_scaled_fp8_quant_kernel <<>>( out.data_ptr(), scales.data_ptr(), - input.data_ptr(), hidden_size); + input.data_ptr(), + scale_ub.has_value() ? 
scale_ub->data_ptr() : nullptr, + hidden_size); }); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 55ccc6f53b455..d5136e45e781e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -188,7 +188,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute dynamic-per-token FP8 quantized tensor and scaling factor. ops.def( "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! " - "scale) -> " + "scale, Tensor? scale_ub) -> " "()"); ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, &dynamic_per_token_scaled_fp8_quant); diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index a1513bdffe768..7fa7e03a8185d 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Tuple, Union, Optional import torch @@ -7,13 +7,19 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: return torch.as_tensor(x, dtype=torch.float32, device='cuda') def ref_dynamic_per_token_quant(x: torch.tensor, + scale_ub: Optional[float], quant_dtype: torch.dtype) \ -> Tuple[torch.tensor, torch.tensor]: assert quant_dtype in [torch.int8, torch.float8_e4m3fn] + if scale_ub is not None: + assert quant_dtype == torch.float8_e4m3fn + qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ else torch.finfo(quant_dtype) qtype_max = as_float32_tensor(qtype_traits.max) + s_1 = as_float32_tensor(1.0) + s_512 = as_float32_tensor(512.0) # For fp8, in order to match the cuda kernel output, we have to do exactly # the same operations as in the corresponding fp8 kernel to prevent @@ -22,10 +28,16 @@ def ref_dynamic_per_token_quant(x: torch.tensor, # Compute scales x_token_max, _ = x.abs().max(dim=-1) x_token_max = as_float32_tensor(x_token_max) + if scale_ub is not None: + x_token_max = x_token_max.clamp(max=scale_ub) + scales = (x_token_max / qtype_max)[:, None] + if quant_dtype == torch.float8_e4m3fn: + min_scaling_factor = s_1 / (qtype_max * s_512) + scales = scales.clamp(min=min_scaling_factor) # Quant - iscales = (qtype_max / x_token_max)[:, None] + iscales = as_float32_tensor(s_1 / scales) torch_out = as_float32_tensor(x) * iscales torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out torch_out = torch_out.clamp(qtype_traits.min, @@ -33,7 +45,6 @@ def ref_dynamic_per_token_quant(x: torch.tensor, return torch_out, scales - # The int8 version is very similar. 
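# --- Editor's aside (not part of the patch): the per-token scale computation added
# above, with the optional upper bound and the 1 / (FP8_MAX * 512) floor, condensed
# into one helper. It restates ref_dynamic_per_token_quant; the name is illustrative.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def token_scales(x: torch.Tensor, scale_ub=None) -> torch.Tensor:
    amax = x.abs().amax(dim=-1, keepdim=True).float()
    if scale_ub is not None:
        amax = amax.clamp(max=scale_ub)    # cap outliers at the upper bound
    return (amax / FP8_MAX).clamp(min=1.0 / (FP8_MAX * 512.0))
# --- end of editor's aside ---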
Incorporate the int8 version, like in # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant # kernel diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 6b555c8e242ad..0accd7624f29d 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -26,14 +26,15 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 # avoid nans - ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) - ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) + scale_ub = None + + ref_out, ref_scales = ref_dynamic_per_token_quant(x, scale_ub, torch.float8_e4m3fn) + ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x, scale_ub) assert torch.allclose(ref_scales, ops_scales) assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) - @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 07646ae582a28..0457498e3ab3c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -336,13 +336,13 @@ def scaled_fp8_quant( def dynamic_per_token_scaled_fp8_quant( - input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + input: torch.Tensor, scale_ub: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales) + torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales, scale_ub) return output, scales From 1884acf5b8928eb2f6c5ef24b40ce6c6f6fb025c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranth Date: Fri, 19 Jul 2024 20:29:30 +0000 Subject: [PATCH 19/40] format --- csrc/ops.h | 7 +++---- csrc/quantization/fp8/common.cu | 13 ++++++------- tests/kernels/quant_utils.py | 5 +++-- tests/kernels/test_fp8_quant.py | 13 +++++++++---- tests/kernels/test_int8_quant.py | 2 +- vllm/_custom_ops.py | 6 ++++-- 6 files changed, 26 insertions(+), 20 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 994412d97f7f9..6541b4d46d7f6 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -134,10 +134,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale); -void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, - torch::Tensor const& input, - torch::Tensor& scale, - c10::optional const& scale_ub); +void dynamic_per_token_scaled_fp8_quant( + torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, + c10::optional const& scale_ub); void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index d8ba7c1b62a65..aa6727bdfc3e2 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -166,8 +166,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, template __global__ void dynamic_per_token_scaled_fp8_quant_kernel( c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale, - scalar_t const* __restrict__ input, - float const* __restrict__ scale_ub, + scalar_t const* 
__restrict__ input, float const* __restrict__ scale_ub, const int hidden_size) { float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); @@ -200,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( token_scale = block_absmax_val_maybe; } // token scale computation - token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor); + token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor); scale[token_idx] = token_scale; } __syncthreads(); @@ -256,10 +255,10 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] }); } -void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor const& input, // [..., d] - torch::Tensor& scales, - std::optional const& scale_ub) { +void dynamic_per_token_scaled_fp8_quant( + torch::Tensor& out, // [..., d] + torch::Tensor const& input, // [..., d] + torch::Tensor& scales, std::optional const& scale_ub) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 7fa7e03a8185d..a67af34b79758 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, Optional +from typing import Optional, Tuple, Union import torch @@ -28,7 +28,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor, # Compute scales x_token_max, _ = x.abs().max(dim=-1) x_token_max = as_float32_tensor(x_token_max) - if scale_ub is not None: + if scale_ub is not None: x_token_max = x_token_max.clamp(max=scale_ub) scales = (x_token_max / qtype_max)[:, None] @@ -45,6 +45,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor, return torch_out, scales + # The int8 version is very similar. Incorporate the int8 version, like in # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant # kernel diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 0accd7624f29d..1aaf7008ce4c6 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -10,31 +10,36 @@ 8193] # Arbitrary values for testing HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SCALE_UBS = [True, False] SEEDS = [0] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("scale_ub", SCALE_UBS) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: + dtype: torch.dtype, scale_ub: bool, + seed: int) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 # avoid nans - scale_ub = None - - ref_out, ref_scales = ref_dynamic_per_token_quant(x, scale_ub, torch.float8_e4m3fn) + scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \ + if scale_ub else None + ref_out, ref_scales = ref_dynamic_per_token_quant(x, scale_ub, + torch.float8_e4m3fn) ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x, scale_ub) assert torch.allclose(ref_scales, ops_scales) assert torch.allclose(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) diff --git a/tests/kernels/test_int8_quant.py 
b/tests/kernels/test_int8_quant.py index 03acbf7968ff1..4b95435e43cdd 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -27,7 +27,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 # reference - ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) + ref_out, ref_scales = ref_dynamic_per_token_quant(x, None, torch.int8) # kernel ops_out, ops_scales = scaled_int8_quant(x) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0457498e3ab3c..5e2bfbf203f3e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -336,13 +336,15 @@ def scaled_fp8_quant( def dynamic_per_token_scaled_fp8_quant( - input: torch.Tensor, scale_ub: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + input: torch.Tensor, + scale_ub: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales, scale_ub) + torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales, + scale_ub) return output, scales From fe14072796295a1a39b9c06a7e394555609f279e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranth Date: Fri, 19 Jul 2024 20:47:38 +0000 Subject: [PATCH 20/40] Make optional ubs none --- tests/kernels/quant_utils.py | 4 ++-- tests/kernels/test_fp8_quant.py | 5 +++-- tests/kernels/test_int8_quant.py | 2 +- vllm/_custom_ops.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index a67af34b79758..c0a5f222fa484 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -7,8 +7,8 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: return torch.as_tensor(x, dtype=torch.float32, device='cuda') def ref_dynamic_per_token_quant(x: torch.tensor, - scale_ub: Optional[float], - quant_dtype: torch.dtype) \ + quant_dtype: torch.dtype, + scale_ub: Optional[torch.tensor] = None) \ -> Tuple[torch.tensor, torch.tensor]: assert quant_dtype in [torch.int8, torch.float8_e4m3fn] diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 1aaf7008ce4c6..89f89fb240532 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -31,8 +31,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \ if scale_ub else None - ref_out, ref_scales = ref_dynamic_per_token_quant(x, scale_ub, - torch.float8_e4m3fn) + ref_out, ref_scales = ref_dynamic_per_token_quant(x, + torch.float8_e4m3fn, + scale_ub) ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x, scale_ub) assert torch.allclose(ref_scales, ops_scales) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 4b95435e43cdd..03acbf7968ff1 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -27,7 +27,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 # reference - ref_out, ref_scales = ref_dynamic_per_token_quant(x, None, torch.int8) + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) # kernel ops_out, ops_scales = scaled_int8_quant(x) diff --git 
a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 5e2bfbf203f3e..590d5f896631d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -337,7 +337,7 @@ def scaled_fp8_quant( def dynamic_per_token_scaled_fp8_quant( input: torch.Tensor, - scale_ub: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + scale_ub: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) scales = torch.empty((input.numel() // input.shape[-1], 1), From 254dcff46dfc490449df6f0a57b247d481a1597d Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranth Date: Fri, 19 Jul 2024 20:48:12 +0000 Subject: [PATCH 21/40] format --- tests/kernels/test_fp8_quant.py | 3 +-- vllm/_custom_ops.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 89f89fb240532..904dc07adf938 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -31,8 +31,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \ if scale_ub else None - ref_out, ref_scales = ref_dynamic_per_token_quant(x, - torch.float8_e4m3fn, + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn, scale_ub) ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x, scale_ub) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 590d5f896631d..674736c9eb86e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -336,8 +336,9 @@ def scaled_fp8_quant( def dynamic_per_token_scaled_fp8_quant( - input: torch.Tensor, - scale_ub: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + input: torch.Tensor, + scale_ub: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) scales = torch.empty((input.numel() // input.shape[-1], 1), From 227a277152bae9242cf6d5916d58bfbf679c9c9b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 21:52:04 +0000 Subject: [PATCH 22/40] hook up end to end with varun's ub quant kernel --- vllm/_custom_ops.py | 2 +- vllm/model_executor/layers/quantization/fbgemm_fp8.py | 7 ++++--- .../model_executor/layers/quantization/utils/w8a8_utils.py | 2 ++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1a8c08d5c58f5..54be5a33edc07 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -299,7 +299,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, - scale_ub: Optional[float] = None, + scale_ub: Optional[torch.Tensor] = None, batch_dim_padding: Optional[int] = None, use_per_token_if_dynamic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 7924446279310..0ebe58d4a9490 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -89,9 +89,9 @@ def create_weights( **extra_weight_attrs) layer.register_parameter("weight_scale", weight_scale) - # NOT USED FOR INFERNECE - input_scale_ub = torch.nn.Parameter( - torch.zeros((1), dtype=torch.float8_e4m3fn)) + # INPUT SCALE UPPER BOUND + input_scale_ub = torch.nn.Parameter(torch.zeros((1), dtype=torch.float32), + 
requires_grad=False) layer.register_parameter("input_scale_ub", input_scale_ub) set_weight_attrs(input_scale_ub, { "ignore_warning": True, @@ -111,6 +111,7 @@ def apply(self, weight=layer.weight, weight_scale=layer.weight_scale, input_scale=None, + input_scale_ub=layer.input_scale_ub, bias=bias, cutlass_fp8_supported=True, use_per_token_if_dynamic=True) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 0729a2d7f8ddd..4fbf75b2ff090 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -105,6 +105,7 @@ def apply_fp8_linear( weight: torch.Tensor, weight_scale: torch.Tensor, input_scale: torch.Tensor, + input_scale_ub: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, cutlass_fp8_supported: bool = True, use_per_token_if_dynamic: bool = False, @@ -118,6 +119,7 @@ def apply_fp8_linear( qinput, x_scale = ops.scaled_fp8_quant( input, input_scale, + scale_ub=input_scale_ub, use_per_token_if_dynamic=use_per_token_if_dynamic) # Fused GEMM_DQ From 951834aaf68b278bfdd41e73dcf9e43427bb4069 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 21:57:17 +0000 Subject: [PATCH 23/40] formatted --- tests/kernels/test_fp8_quant.py | 3 ++- vllm/model_executor/layers/quantization/fbgemm_fp8.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index ca2920a8fe086..28a186473f678 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -33,7 +33,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, if scale_ub else None ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn, scale_ub) - ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x, scale_ub, use_per_token_if_dynamic=True) + ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant( + x, scale_ub, use_per_token_if_dynamic=True) assert torch.allclose(ref_scales, ops_scales) assert torch.allclose(ref_out.to(dtype=torch.float32), diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 0ebe58d4a9490..81eca51914f96 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -90,7 +90,8 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE UPPER BOUND - input_scale_ub = torch.nn.Parameter(torch.zeros((1), dtype=torch.float32), + input_scale_ub = torch.nn.Parameter(torch.zeros((1), + dtype=torch.float32), requires_grad=False) layer.register_parameter("input_scale_ub", input_scale_ub) set_weight_attrs(input_scale_ub, { From 9aa66d3565daceafc9fad7407f5339c5a27a5ce2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 23:23:20 +0000 Subject: [PATCH 24/40] updated for nonuniform --- vllm/attention/layer.py | 3 +- vllm/model_executor/layers/linear.py | 9 ++-- .../layers/quantization/base_config.py | 3 +- .../layers/quantization/fbgemm_fp8.py | 53 +++++++++++++++++-- .../layers/vocab_parallel_embedding.py | 11 ++-- 5 files changed, 64 insertions(+), 15 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 0619bda90a2a7..643a845899c37 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -34,6 +34,7 @@ def __init__( cache_config: 
Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, + prefix: str = "", ) -> None: super().__init__() if cache_config is not None: @@ -56,7 +57,7 @@ def __init__( self._k_scale = 1.0 self._v_scale = 1.0 quant_method = quant_config.get_quant_method( - self) if quant_config else None + self, prefix=prefix) if quant_config else None if quant_method is not None: assert isinstance(quant_method, Fp8KVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1106d7985ecc6..44617365cc859 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -141,6 +141,7 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -155,7 +156,7 @@ def __init__( self.quant_method: Optional[ QuantizeMethodBase] = UnquantizedLinearMethod() else: - self.quant_method = quant_config.get_quant_method(self) + self.quant_method = quant_config.get_quant_method(self, prefix=prefix) def forward(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -184,7 +185,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, prefix: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, prefix=prefix) # All the linear layer supports quant method. assert self.quant_method is not None @@ -260,7 +261,7 @@ def __init__(self, output_sizes: Optional[List[int]] = None, prefix: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, prefix) self.gather_output = gather_output @@ -709,7 +710,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, prefix: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, prefix) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 1607470cb76f6..2a5eaf154f147 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -98,11 +98,12 @@ def get_from_keys_or(config: Dict[str, Any], keys: List[str], @abstractmethod def get_quant_method( - self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + self, layer: torch.nn.Module, prefix: str) -> Optional[QuantizeMethodBase]: """Get the quantize method to use for the quantized layer. Args: layer: The layer for the quant method. + prefix: The full name of the layer in the state dict Returns: The quantize method. None if the given layer doesn't support quant method. 
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 81eca51914f96..0e9075f80fec9 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -5,7 +5,7 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -15,16 +15,28 @@ logger = init_logger(__name__) +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. +# fused_name: List[shard_name] +_FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] +} + + class FBGEMMFp8Config(QuantizationConfig): """Config class for FBGEMM Fp8.""" + def __init__(self, ignore_list: List[str]): + self.ignore_list = ignore_list + @classmethod def get_name(cls) -> str: return "fbgemm_fp8" @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.bfloat16] + return [torch.bfloat16, torch.float16] @classmethod def get_min_capability(cls) -> int: @@ -36,11 +48,42 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config": - return cls() - + ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"]) + return cls(ignore_list=ignore_list) + + def _is_layer_skipped(self, prefix: str) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in _FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) for + shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in self.ignore_list + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." + ) + else: + is_skipped = prefix in self.ignore_list + + assert is_skipped is not None + return is_skipped + def get_quant_method( - self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): + if self._is_layer_skipped(prefix): + return UnquantizedLinearMethod() return FBGEMMFp8LinearMethod(self) return None diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index d70eb1c2704b4..37e156418c51f 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -161,6 +161,7 @@ class VocabParallelEmbedding(torch.nn.Module): org_num_embeddings: original vocabulary size (without LoRA). padding_size: padding size for the vocabulary. 
quant_config: quant config for the layer + prefix: full name of the layer in the state dict """ # noqa: E501 def __init__(self, @@ -169,7 +170,8 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() # Keep the input dimensions. @@ -195,7 +197,7 @@ def __init__(self, linear_method = None if quant_config is not None: - linear_method = quant_config.get_quant_method(self) + linear_method = quant_config.get_quant_method(self, prefix=prefix) if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method: QuantizeMethodBase = linear_method @@ -382,9 +384,10 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings, padding_size, quant_config) + org_num_embeddings, padding_size, quant_config, prefix) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, From 458a4105c41aa0ab9ef741cb4bde96aaca5eb3b4 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 23:37:44 +0000 Subject: [PATCH 25/40] formatting after passing prefix around --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/linear.py | 21 ++++++++++++------- .../layers/quantization/aqlm.py | 4 ++-- .../model_executor/layers/quantization/awq.py | 4 ++-- .../layers/quantization/base_config.py | 4 ++-- .../layers/quantization/bitsandbytes.py | 5 ++--- .../compressed_tensors/compressed_tensors.py | 6 +++++- .../layers/quantization/deepspeedfp.py | 5 ++--- .../layers/quantization/fbgemm_fp8.py | 21 +++++++++---------- .../model_executor/layers/quantization/fp8.py | 4 ++-- .../layers/quantization/gptq.py | 4 ++-- .../layers/quantization/gptq_marlin.py | 5 ++--- .../layers/quantization/gptq_marlin_24.py | 5 ++--- .../layers/quantization/marlin.py | 4 ++-- .../layers/quantization/squeezellm.py | 4 ++-- .../layers/vocab_parallel_embedding.py | 3 ++- 16 files changed, 53 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a6fa8ffe5111c..a0dc4c94744a8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -182,7 +182,7 @@ def __init__( self.quant_method: Optional[QuantizeMethodBase] = ( UnquantizedFusedMoEMethod()) else: - self.quant_method = quant_config.get_quant_method(self) + self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None self.quant_method.create_weights( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 44617365cc859..0e0a2b72f93d4 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -156,7 +156,8 @@ def __init__( self.quant_method: Optional[ QuantizeMethodBase] = UnquantizedLinearMethod() else: - self.quant_method = quant_config.get_quant_method(self, prefix=prefix) + self.quant_method = quant_config.get_quant_method(self, + prefix=prefix) def forward(self, x: torch.Tensor) -> torch.Tensor: raise 
NotImplementedError @@ -183,9 +184,13 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = None): - super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config, prefix=prefix) + prefix: str = ""): + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix) # All the linear layer supports quant method. assert self.quant_method is not None @@ -259,7 +264,7 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, - prefix: Optional[str] = None): + prefix: str = ""): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix) @@ -371,7 +376,7 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = None): + prefix: str = ""): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) @@ -515,7 +520,7 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = None): + prefix: str = ""): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -708,7 +713,7 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = None): + prefix: str = ""): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 730595c3d36d1..95ff05b986ab4 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -207,8 +207,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": return cls(in_group_size, nbits_per_codebook, num_code_books, out_group_size) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AQLMLinearMethod"]: if isinstance(layer, LinearBase): return AQLMLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index a3854f70bb4fa..ce2fa62ef565f 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -63,8 +63,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": zero_point = cls.get_from_keys(config, ["zero_point"]) return cls(weight_bits, group_size, zero_point) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AWQLinearMethod"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 2a5eaf154f147..f5ff27b9f14b7 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -97,8 +97,8 @@ def 
get_from_keys_or(config: Dict[str, Any], keys: List[str], return default @abstractmethod - def get_quant_method( - self, layer: torch.nn.Module, prefix: str) -> Optional[QuantizeMethodBase]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional[QuantizeMethodBase]: """Get the quantize method to use for the quantized layer. Args: diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index e76714a7b460c..4a68da5a2323e 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -60,9 +60,8 @@ def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": target_modules = cls.get_from_keys(config, ["target_modules"]) return cls(adapter_name, target_modules) - def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["BitsAndBytesLinearMethod"]: if isinstance(layer, LinearBase): return BitsAndBytesLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 28c552b3654f3..0accc94231b9c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -44,8 +44,12 @@ def get_min_capability(cls) -> int: def get_name(self) -> str: return "compressed_tensors" + # TODO (@robertgshaw2-neuralmagic): do layer skipping though here + # rather than though create_weights to match other methods def get_quant_method( - self, layer: torch.nn.Module + self, + layer: torch.nn.Module, + prefix: str, ) -> Optional["CompressedTensorsLinearMethod"]: if isinstance(layer, LinearBase): return CompressedTensorsLinearMethod(self) diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 31cdffbcf0ab9..29484801dc380 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -69,9 +69,8 @@ def get_config_filenames() -> List[str]: "quantize_config.json", ] - def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["DeepSpeedFPLinearMethod"]: if isinstance(layer, LinearBase): return DeepSpeedFPLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 0e9075f80fec9..aed7dd194795c 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -5,7 +5,8 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, UnquantizedLinearMethod +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -14,8 +15,7 @@ logger = init_logger(__name__) - -# Note: this is a hack. 
We should update each model to register the +# Note: this is a hack. We should update each model to register the # stacked params and get it from there instead in a future PR. # fused_name: List[shard_name] _FUSED_LAYER_NAME_MAPPING = { @@ -57,30 +57,29 @@ def _is_layer_skipped(self, prefix: str) -> bool: proj_name = prefix.split(".")[-1] if proj_name in _FUSED_LAYER_NAME_MAPPING: shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) for - shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name] + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name] ] is_skipped = None for shard_prefix in shard_prefixes: is_shard_skipped = shard_prefix in self.ignore_list - + if is_skipped is None: is_skipped = is_shard_skipped elif is_shard_skipped != is_skipped: raise ValueError( f"Detected some but not all shards of {prefix} " "are quantized. All shards of fused layers " - "to have the same precision." - ) + "to have the same precision.") else: is_skipped = prefix in self.ignore_list assert is_skipped is not None return is_skipped - - def get_quant_method( - self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): if self._is_layer_skipped(prefix): return UnquantizedLinearMethod() diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 820c066aad28a..d4498f452cc06 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -66,8 +66,8 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 595d6ab96b1b9..510c9dd49ef03 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -69,8 +69,8 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": default=False) return cls(weight_bits, group_size, desc_act, lm_head_quantized) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQLinearMethod"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return GPTQLinearMethod(self) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 07a73d06e0596..bb9644dbc9947 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -94,9 +94,8 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQMarlinLinearMethod"]: if (isinstance(layer, LinearBase) or (isinstance(layer, 
ParallelLMHead) and self.lm_head_quantized)): return GPTQMarlinLinearMethod(self) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 6bcfc405afe71..e708c4da95af3 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -109,9 +109,8 @@ def override_quantization_method(cls, hf_quant_cfg, return None - def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQMarlin24LinearMethod"]: if isinstance(layer, LinearBase): return GPTQMarlin24LinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index f0a9cf5520bdd..cdc5129a93b15 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -100,8 +100,8 @@ def override_quantization_method(cls, hf_quant_cfg, return None - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["MarlinLinearMethod"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return MarlinLinearMethod(self) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 72ba55eb1740d..afb3c04976737 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -52,8 +52,8 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": weight_bits = cls.get_from_keys(config, ["wbits"]) return cls(weight_bits) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional[QuantizeMethodBase]: if isinstance(layer, LinearBase): return SqueezeLLMLinearMethod(self) return None diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 37e156418c51f..74aeb964274b0 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -387,7 +387,8 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings, padding_size, quant_config, prefix) + org_num_embeddings, padding_size, quant_config, + prefix) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, From 3e4aaadcd63f83deba06a3949a29a3bab9a5911f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 19 Jul 2024 23:42:36 +0000 Subject: [PATCH 26/40] fixed bad merge --- vllm/_custom_ops.py | 1 - vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 1 - 2 files changed, 2 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7b940319f355a..54be5a33edc07 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -302,7 +302,6 @@ def scaled_fp8_quant( scale_ub: Optional[torch.Tensor] = None, batch_dim_padding: Optional[int] = None, use_per_token_if_dynamic: bool = False, - use_per_token_if_dynamic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and 
return quantized tensor and scale. diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index e47a5ab359368..4fbf75b2ff090 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -109,7 +109,6 @@ def apply_fp8_linear( bias: Optional[torch.Tensor] = None, cutlass_fp8_supported: bool = True, use_per_token_if_dynamic: bool = False, - use_per_token_if_dynamic: bool = False, ) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. From de2a764b3fcf9fafd9ca48410e9b870ef774a80c Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 00:58:03 +0000 Subject: [PATCH 27/40] updated message --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 9902a152e551a..02300d7407315 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -248,7 +248,7 @@ def _verify_quantization(self) -> None: f"supported in ROCm.") if (self.quantization not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", - "compressed_tensors")): + "fpgemm_fp8", "compressed_tensors")): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " From c88fe3464e722612f2bb1872ec73bb32cb5d37a5 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 01:21:13 +0000 Subject: [PATCH 28/40] merged varun's pr --- vllm/_custom_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1f3a2d8428f95..54be5a33edc07 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -301,7 +301,6 @@ def scaled_fp8_quant( scale: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, batch_dim_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ From bb02a3f5a92337836492b12ba573a3bc2a5ea56b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 01:30:42 +0000 Subject: [PATCH 29/40] fixed --- tests/kernels/quant_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index b6c0ab4148c34..cec2b05bafd21 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -31,9 +31,6 @@ def ref_dynamic_per_token_quant(x: torch.tensor, if scale_ub is not None: x_token_max = x_token_max.clamp(max=scale_ub) scales = (x_token_max / qtype_max)[:, None] - if quant_dtype == torch.float8_e4m3fn: - min_scaling_factor = s_1 / (qtype_max * s_512) - scales = scales.clamp(min=min_scaling_factor) # Quant if quant_dtype == torch.int8: From 1c8f71ca5d4e998a6b7b4e7fd115ed3698c2f2f6 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 01:32:00 +0000 Subject: [PATCH 30/40] cleanup pr --- vllm/_custom_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 54be5a33edc07..80ca357e8b293 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -299,8 +299,8 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, - scale_ub: Optional[torch.Tensor] = None, batch_dim_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, 
use_per_token_if_dynamic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ From 6970e50153ed9c52c59b24e40fa0be324597b332 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Fri, 19 Jul 2024 22:17:58 -0400 Subject: [PATCH 31/40] Update config.py --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 02300d7407315..1bb2cadd9019f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -248,7 +248,7 @@ def _verify_quantization(self) -> None: f"supported in ROCm.") if (self.quantization not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", - "fpgemm_fp8", "compressed_tensors")): + "fbgemm_fp8", "compressed_tensors")): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " From 94617f0b345521dbdc0b2c3d7c4534506b1e46c8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 13:22:59 +0000 Subject: [PATCH 32/40] fixed config --- ...ta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml | 4 ++-- .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml index 39b6f20805bdc..c513159c6fa0d 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml @@ -4,8 +4,8 @@ tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.769 + value: 0.752 - name: "exact_match,flexible-extract" - value: 0.769 + value: 0.754 limit: 1000 num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index 2f04cc1283df3..de841d959a4e4 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do done lm_eval --model vllm \ - --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \ + --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \ --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ --batch_size $BATCH_SIZE From f9d569cb9e2fbc5ce43a1285515f86707658a44f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 13:47:01 +0000 Subject: [PATCH 33/40] updated for new ckpt format, turned on ada lovelace, and added test case --- ...Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml | 11 ++++++++++ .../layers/quantization/fbgemm_fp8.py | 20 +++++++++---------- 2 files changed, 20 insertions(+), 11 deletions(-) create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml new file mode 100644 index 0000000000000..5e57fcbcf7d9b --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml @@ -0,0 +1,11 @@ +# bash 
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1 +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.753 + - name: "exact_match,flexible-extract" + value: 0.753 +limit: 1000 +num_fewshot: 5 diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index aed7dd194795c..837898b5f0edc 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -27,8 +27,9 @@ class FBGEMMFp8Config(QuantizationConfig): """Config class for FBGEMM Fp8.""" - def __init__(self, ignore_list: List[str]): + def __init__(self, ignore_list: List[str], input_scale_ub: float): self.ignore_list = ignore_list + self.input_scale_ub = input_scale_ub @classmethod def get_name(cls) -> str: @@ -40,7 +41,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 90 + return 89 @classmethod def get_config_filenames(cls) -> List[str]: @@ -49,7 +50,8 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config": ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"]) - return cls(ignore_list=ignore_list) + input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"]) + return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub) def _is_layer_skipped(self, prefix: str) -> bool: # prefix: model.layers.0.self_attn.q_proj @@ -132,14 +134,10 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE UPPER BOUND - input_scale_ub = torch.nn.Parameter(torch.zeros((1), - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("input_scale_ub", input_scale_ub) - set_weight_attrs(input_scale_ub, { - "ignore_warning": True, - **extra_weight_attrs - }) + input_scale_ub = torch.nn.Parameter( + torch.tensor((self.quant_config.input_scale_ub), dtype=torch.float32), + requires_grad=False) + layer.input_scale_ub = input_scale_ub def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight From ae45615dc2f2ac8ce5002d6838ec4374cfdb2c3a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 13:48:11 +0000 Subject: [PATCH 34/40] format --- vllm/model_executor/layers/quantization/fbgemm_fp8.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 837898b5f0edc..e6e8d28e3e16a 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -134,9 +134,9 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE UPPER BOUND - input_scale_ub = torch.nn.Parameter( - torch.tensor((self.quant_config.input_scale_ub), dtype=torch.float32), - requires_grad=False) + input_scale_ub = torch.nn.Parameter(torch.tensor( + (self.quant_config.input_scale_ub), dtype=torch.float32), + requires_grad=False) layer.input_scale_ub = input_scale_ub def process_weights_after_loading(self, layer: Module) -> None: From e2a1eda9c20771d3333f7ed379ea1dd5cd461626 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 14:28:42 +0000 Subject: [PATCH 35/40] add marlin 
support to fbgemm --- .../layers/quantization/fbgemm_fp8.py | 32 ++++++++++++++++++- .../quantization/utils/marlin_utils_fp8.py | 9 ++++-- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index e6e8d28e3e16a..832be99165e20 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -11,7 +11,10 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear, create_per_channel_scale_param) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -41,7 +44,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 89 + return 80 @classmethod def get_config_filenames(cls) -> List[str]: @@ -97,6 +100,14 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + self.use_marlin = capability < 89 + + self.use_marlin = True + def create_weights( self, layer: torch.nn.Module, @@ -139,15 +150,34 @@ def create_weights( requires_grad=False) layer.input_scale_ub = input_scale_ub + if self.use_marlin: + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight layer.weight = Parameter(weight.t(), requires_grad=False) + if self.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. 
+ del layer.input_scale_ub + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + return apply_fp8_linear(input=x, weight=layer.weight, weight_scale=layer.weight_scale, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index e93eb747ba2eb..84063b23b3e99 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -76,8 +76,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Currently Marlin doesn't support per-tensor scales, so we # expand it to channelwise - scales = layer.weight_scale.repeat(1, part_size_n).to( - layer.orig_dtype).to(device) + is_channelwise = layer.weight_scale.shape[0] == part_size_n + if is_channelwise: + scales = layer.weight_scale + else: + scales = layer.weight_scale.repeat(1, part_size_n) + scales = scales.to(layer.orig_dtype).to(device) + # Permute scales marlin_scales = marlin_permute_scales(s=scales, size_k=part_size_k, From a4abc787fb62e22bbb2754d7738ae5c7c4530784 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 14:34:12 +0000 Subject: [PATCH 36/40] fix configs --- .../layers/quantization/fbgemm_fp8.py | 22 +++++++++---------- .../quantization/utils/marlin_utils_fp8.py | 2 +- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 832be99165e20..4dbd94ae4b454 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -9,10 +9,10 @@ UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, create_per_channel_scale_param) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, create_per_channel_scale_param) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -34,6 +34,12 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float): self.ignore_list = ignore_list self.input_scale_ub = input_scale_ub + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + self.use_marlin = capability < 89 + @classmethod def get_name(cls) -> str: return "fbgemm_fp8" @@ -100,14 +106,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 - - 
self.use_marlin = True - def create_weights( self, layer: torch.nn.Module, @@ -150,7 +148,7 @@ def create_weights( requires_grad=False) layer.input_scale_ub = input_scale_ub - if self.use_marlin: + if self.quant_config.use_marlin: layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition @@ -158,7 +156,7 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight layer.weight = Parameter(weight.t(), requires_grad=False) - if self.use_marlin: + if self.quant_config.use_marlin: prepare_fp8_layer_for_marlin(layer) # Activations not quantized for marlin. del layer.input_scale_ub diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 84063b23b3e99..aabd46e64536f 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -80,7 +80,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: if is_channelwise: scales = layer.weight_scale else: - scales = layer.weight_scale.repeat(1, part_size_n) + scales = layer.weight_scale.repeat(1, part_size_n) scales = scales.to(layer.orig_dtype).to(device) # Permute scales From 5008ecb324cb529d992cf7e3430dbb3099076289 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 14:37:04 +0000 Subject: [PATCH 37/40] fix configs --- vllm/model_executor/layers/quantization/fbgemm_fp8.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 4dbd94ae4b454..7261f43c3ae8e 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -40,6 +40,8 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float): capability = capability[0] * 10 + capability[1] self.use_marlin = capability < 89 + self.use_marlin = True + @classmethod def get_name(cls) -> str: return "fbgemm_fp8" @@ -148,10 +150,6 @@ def create_weights( requires_grad=False) layer.input_scale_ub = input_scale_ub - if self.quant_config.use_marlin: - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight layer.weight = Parameter(weight.t(), requires_grad=False) @@ -166,7 +164,7 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.use_marlin: + if self.quant_config.use_marlin: return apply_fp8_marlin_linear( input=x, weight=layer.weight, From 615a2ed5a291a39d224d8dc5b0208f6bd7dfd5d6 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 20 Jul 2024 15:22:01 +0000 Subject: [PATCH 38/40] added marlin nonuniform test --- .../Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml | 11 +++++++++++ .buildkite/lm-eval-harness/configs/models-large.txt | 1 + 2 files changed, 12 insertions(+) create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml new file mode 100644 index 0000000000000..8529e7ddb4bc8 --- /dev/null +++ 
b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM -b auto -l 1000 -f 5
+model_name: "meta-llama/Meta-Llama-3-70B-Instruct"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.905
+  - name: "exact_match,flexible-extract"
+    value: 0.905
+limit: 1000
+num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt
index 94b15a87235b9..37eeac85c933b 100644
--- a/.buildkite/lm-eval-harness/configs/models-large.txt
+++ b/.buildkite/lm-eval-harness/configs/models-large.txt
@@ -1,3 +1,4 @@
+Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml

From da3759834a4e20219c1d293838aa651e1633e03c Mon Sep 17 00:00:00 2001
From: "rshaw@neuralmagic.com"
Date: Sat, 20 Jul 2024 16:00:41 +0000
Subject: [PATCH 39/40] fixed

---
 .../configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
index 8529e7ddb4bc8..4397effa82cc8 100644
--- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
+++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
@@ -1,5 +1,5 @@
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM -b auto -l 1000 -f 5
-model_name: "meta-llama/Meta-Llama-3-70B-Instruct"
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
+model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
 tasks:
 - name: "gsm8k"
   metrics:

From 183bfe77b499a317f1097a8cb0d53a9545eb7b87 Mon Sep 17 00:00:00 2001
From: "rshaw@neuralmagic.com"
Date: Sat, 20 Jul 2024 16:58:43 +0000
Subject: [PATCH 40/40] use marlin remove:

---
 vllm/model_executor/layers/quantization/fbgemm_fp8.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index 7261f43c3ae8e..e84564714171a 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -40,8 +40,6 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float):
         capability = capability[0] * 10 + capability[1]
         self.use_marlin = capability < 89
 
-        self.use_marlin = True
-
     @classmethod
     def get_name(cls) -> str:
         return "fbgemm_fp8"
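
Appendix (not part of the patches above): the series ends up routing dynamic activation
quantization through ops.scaled_fp8_quant(input, scale=None, scale_ub=..., use_per_token_if_dynamic=True),
which dispatches to torch.ops._C.dynamic_per_token_scaled_fp8_quant and returns one float32
scale per token with shape (num_tokens, 1). The sketch below is a plain-PyTorch reference for
that per-token path, in the spirit of ref_dynamic_per_token_quant in tests/kernels/quant_utils.py;
the function name, the epsilon guard on the scale, and the round-trip check are illustrative
assumptions, not code from the diffs, and it requires a PyTorch build with the float8_e4m3fn dtype.

# Illustrative reference sketch only; the real kernel is the CUDA op wired up above.
from typing import Optional, Tuple

import torch


def per_token_fp8_quant_ref(
    x: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize each row (token) of a 2D activation tensor to FP8 E4M3.

    Returns the FP8 tensor and one float32 scale per token, matching the
    (num_tokens, 1) scale layout allocated in scaled_fp8_quant above.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    x = x.to(torch.float32)

    # Per-token absolute max, optionally clamped to the activation upper bound
    # (shipped by FBGEMM checkpoints as `activation_scale_ub`).
    x_max = x.abs().max(dim=-1, keepdim=True).values
    if scale_ub is not None:
        x_max = x_max.clamp(max=scale_ub.to(torch.float32))

    scales = x_max / finfo.max            # shape: (num_tokens, 1)
    scales = scales.clamp(min=1e-12)      # assumed guard against all-zero rows
    q = (x / scales).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scales


if __name__ == "__main__":
    x = torch.randn(4, 16)
    q, s = per_token_fp8_quant_ref(x, scale_ub=torch.tensor(6.0))
    # Dequantize and check the round trip stays close to the input.
    print((q.to(torch.float32) * s - x).abs().max())

Compared with a single per-tensor scale, per-token scales cost a little extra memory traffic but
track outlier tokens much more tightly, and the optional scale_ub caps those outliers so a single
extreme token cannot inflate its scale; that is the trade-off the fbgemm_fp8 path above is built around.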