From e976e031690c29132416046f357be8d27adc3241 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sat, 20 Jul 2024 12:36:57 -0400 Subject: [PATCH] [ Misc ] `fbgemm` checkpoints (#6559) --- ...struct-Channelwise-compressed-tensors.yaml | 4 +- ...Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml | 11 ++ .../run-lm-eval-gsm-vllm-baseline.sh | 2 +- vllm/_custom_ops.py | 2 + vllm/attention/layer.py | 3 +- vllm/config.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/linear.py | 26 +-- .../layers/quantization/__init__.py | 2 + .../layers/quantization/aqlm.py | 4 +- .../model_executor/layers/quantization/awq.py | 4 +- .../layers/quantization/base_config.py | 5 +- .../layers/quantization/bitsandbytes.py | 5 +- .../compressed_tensors/compressed_tensors.py | 6 +- .../layers/quantization/deepspeedfp.py | 5 +- .../layers/quantization/fbgemm_fp8.py | 158 ++++++++++++++++++ .../model_executor/layers/quantization/fp8.py | 4 +- .../layers/quantization/gptq.py | 4 +- .../layers/quantization/gptq_marlin.py | 5 +- .../layers/quantization/gptq_marlin_24.py | 5 +- .../layers/quantization/marlin.py | 4 +- .../layers/quantization/squeezellm.py | 4 +- .../layers/quantization/utils/w8a8_utils.py | 2 + .../layers/vocab_parallel_embedding.py | 12 +- 24 files changed, 234 insertions(+), 47 deletions(-) create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml create mode 100644 vllm/model_executor/layers/quantization/fbgemm_fp8.py diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml index 39b6f20805bdc..c513159c6fa0d 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml @@ -4,8 +4,8 @@ tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.769 + value: 0.752 - name: "exact_match,flexible-extract" - value: 0.769 + value: 0.754 limit: 1000 num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml new file mode 100644 index 0000000000000..5e57fcbcf7d9b --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1 +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.753 + - name: "exact_match,flexible-extract" + value: 0.753 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index 2f04cc1283df3..de841d959a4e4 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do done lm_eval --model vllm \ - --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \ + --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \ --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ --batch_size $BATCH_SIZE diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 873c6786a85a0..80ca357e8b293 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -315,6 +315,8 @@ def scaled_fp8_quant( Args: input: The input tensor to be quantized to FP8 scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case batch_dim_padding: If specified, pad the first dimension of the output to at least this value. use_per_token_if_dynamic: Whether to do per_tensor or per_token diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 0619bda90a2a7..643a845899c37 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -34,6 +34,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, + prefix: str = "", ) -> None: super().__init__() if cache_config is not None: @@ -56,7 +57,7 @@ def __init__( self._k_scale = 1.0 self._v_scale = 1.0 quant_method = quant_config.get_quant_method( - self) if quant_config else None + self, prefix=prefix) if quant_config else None if quant_method is not None: assert isinstance(quant_method, Fp8KVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 diff --git a/vllm/config.py b/vllm/config.py index 8dde171576973..81ef9526c8b9b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -251,7 +251,7 @@ def _verify_quantization(self) -> None: f"supported in ROCm.") if (self.quantization not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", - "compressed_tensors")): + "fbgemm_fp8", "compressed_tensors")): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a6fa8ffe5111c..a0dc4c94744a8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -182,7 +182,7 @@ def __init__( self.quant_method: Optional[QuantizeMethodBase] = ( UnquantizedFusedMoEMethod()) else: - self.quant_method = quant_config.get_quant_method(self) + self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None self.quant_method.create_weights( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1106d7985ecc6..0e0a2b72f93d4 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -141,6 +141,7 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -155,7 +156,8 @@ def __init__( self.quant_method: Optional[ QuantizeMethodBase] = UnquantizedLinearMethod() else: - self.quant_method = quant_config.get_quant_method(self) + self.quant_method = quant_config.get_quant_method(self, + prefix=prefix) def forward(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -182,9 +184,13 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = None): - super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + prefix: str = ""): + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix) # All the linear layer supports quant method. assert self.quant_method is not None @@ -258,9 +264,9 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, - prefix: Optional[str] = None): + prefix: str = ""): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, prefix) self.gather_output = gather_output @@ -370,7 +376,7 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = None): + prefix: str = ""): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) @@ -514,7 +520,7 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = None): + prefix: str = ""): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -707,9 +713,9 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = None): + prefix: str = ""): super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config) + quant_config, prefix) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 40b0df75a69a6..c1bb45224fcc1 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -10,6 +10,7 @@ CompressedTensorsConfig) from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig) +from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -24,6 +25,7 @@ "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, "fp8": Fp8Config, + "fbgemm_fp8": FBGEMMFp8Config, # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 730595c3d36d1..95ff05b986ab4 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -207,8 +207,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": return cls(in_group_size, nbits_per_codebook, num_code_books, out_group_size) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AQLMLinearMethod"]: if isinstance(layer, LinearBase): return AQLMLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index a3854f70bb4fa..ce2fa62ef565f 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -63,8 +63,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": zero_point = cls.get_from_keys(config, ["zero_point"]) return cls(weight_bits, group_size, zero_point) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AWQLinearMethod"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 1607470cb76f6..f5ff27b9f14b7 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -97,12 +97,13 @@ def get_from_keys_or(config: Dict[str, Any], keys: List[str], return default @abstractmethod - def get_quant_method( - self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional[QuantizeMethodBase]: """Get the quantize method to use for the quantized layer. Args: layer: The layer for the quant method. + prefix: The full name of the layer in the state dict Returns: The quantize method. None if the given layer doesn't support quant method. diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index e76714a7b460c..4a68da5a2323e 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -60,9 +60,8 @@ def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": target_modules = cls.get_from_keys(config, ["target_modules"]) return cls(adapter_name, target_modules) - def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["BitsAndBytesLinearMethod"]: if isinstance(layer, LinearBase): return BitsAndBytesLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 28c552b3654f3..0accc94231b9c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -44,8 +44,12 @@ def get_min_capability(cls) -> int: def get_name(self) -> str: return "compressed_tensors" + # TODO (@robertgshaw2-neuralmagic): do layer skipping though here + # rather than though create_weights to match other methods def get_quant_method( - self, layer: torch.nn.Module + self, + layer: torch.nn.Module, + prefix: str, ) -> Optional["CompressedTensorsLinearMethod"]: if isinstance(layer, LinearBase): return CompressedTensorsLinearMethod(self) diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 31cdffbcf0ab9..29484801dc380 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -69,9 +69,8 @@ def get_config_filenames() -> List[str]: "quantize_config.json", ] - def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["DeepSpeedFPLinearMethod"]: if isinstance(layer, LinearBase): return DeepSpeedFPLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py new file mode 100644 index 0000000000000..e6e8d28e3e16a --- /dev/null +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -0,0 +1,158 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, create_per_channel_scale_param) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. +# fused_name: List[shard_name] +_FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] +} + + +class FBGEMMFp8Config(QuantizationConfig): + """Config class for FBGEMM Fp8.""" + + def __init__(self, ignore_list: List[str], input_scale_ub: float): + self.ignore_list = ignore_list + self.input_scale_ub = input_scale_ub + + @classmethod + def get_name(cls) -> str: + return "fbgemm_fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config": + ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"]) + input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"]) + return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub) + + def _is_layer_skipped(self, prefix: str) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in _FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in self.ignore_list + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision.") + else: + is_skipped = prefix in self.ignore_list + + assert is_skipped is not None + return is_skipped + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + if self._is_layer_skipped(prefix): + return UnquantizedLinearMethod() + return FBGEMMFp8LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class FBGEMMFp8LinearMethod(LinearMethodBase): + + def __init__(self, quant_config: FBGEMMFp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + requires_grad=False) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + **extra_weight_attrs, + }) + + # WEIGHT SCALE + weight_scale = create_per_channel_scale_param(output_partition_sizes, + **extra_weight_attrs) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE UPPER BOUND + input_scale_ub = torch.nn.Parameter(torch.tensor( + (self.quant_config.input_scale_ub), dtype=torch.float32), + requires_grad=False) + layer.input_scale_ub = input_scale_ub + + def process_weights_after_loading(self, layer: Module) -> None: + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + input_scale_ub=layer.input_scale_ub, + bias=bias, + cutlass_fp8_supported=True, + use_per_token_if_dynamic=True) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 820c066aad28a..d4498f452cc06 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -66,8 +66,8 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 595d6ab96b1b9..510c9dd49ef03 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -69,8 +69,8 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": default=False) return cls(weight_bits, group_size, desc_act, lm_head_quantized) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQLinearMethod"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return GPTQLinearMethod(self) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 07a73d06e0596..bb9644dbc9947 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -94,9 +94,8 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQMarlinLinearMethod"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return GPTQMarlinLinearMethod(self) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 6bcfc405afe71..e708c4da95af3 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -109,9 +109,8 @@ def override_quantization_method(cls, hf_quant_cfg, return None - def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQMarlin24LinearMethod"]: if isinstance(layer, LinearBase): return GPTQMarlin24LinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index f0a9cf5520bdd..cdc5129a93b15 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -100,8 +100,8 @@ def override_quantization_method(cls, hf_quant_cfg, return None - def get_quant_method( - self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["MarlinLinearMethod"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return MarlinLinearMethod(self) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 72ba55eb1740d..afb3c04976737 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -52,8 +52,8 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": weight_bits = cls.get_from_keys(config, ["wbits"]) return cls(weight_bits) - def get_quant_method( - self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional[QuantizeMethodBase]: if isinstance(layer, LinearBase): return SqueezeLLMLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 0729a2d7f8ddd..4fbf75b2ff090 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -105,6 +105,7 @@ def apply_fp8_linear( weight: torch.Tensor, weight_scale: torch.Tensor, input_scale: torch.Tensor, + input_scale_ub: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, cutlass_fp8_supported: bool = True, use_per_token_if_dynamic: bool = False, @@ -118,6 +119,7 @@ def apply_fp8_linear( qinput, x_scale = ops.scaled_fp8_quant( input, input_scale, + scale_ub=input_scale_ub, use_per_token_if_dynamic=use_per_token_if_dynamic) # Fused GEMM_DQ diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index d70eb1c2704b4..74aeb964274b0 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -161,6 +161,7 @@ class VocabParallelEmbedding(torch.nn.Module): org_num_embeddings: original vocabulary size (without LoRA). padding_size: padding size for the vocabulary. quant_config: quant config for the layer + prefix: full name of the layer in the state dict """ # noqa: E501 def __init__(self, @@ -169,7 +170,8 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() # Keep the input dimensions. @@ -195,7 +197,7 @@ def __init__(self, linear_method = None if quant_config is not None: - linear_method = quant_config.get_quant_method(self) + linear_method = quant_config.get_quant_method(self, prefix=prefix) if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method: QuantizeMethodBase = linear_method @@ -382,9 +384,11 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings, padding_size, quant_config) + org_num_embeddings, padding_size, quant_config, + prefix) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition,