From 80fb04d1b1ecee96da7ffdc9a9567490e34fc17b Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 18 Jul 2024 22:39:18 -0400 Subject: [PATCH] [ Misc ] non-uniform quantization via `compressed-tensors` for `Llama` (#6515) --- ...nstruct-nonuniform-compressed-tensors.yaml | 11 ++ .../lm-eval-harness/configs/models-small.txt | 1 + vllm/model_executor/layers/fused_moe/layer.py | 1 + vllm/model_executor/layers/linear.py | 44 ++++-- .../compressed_tensors/compressed_tensors.py | 92 +++++++----- .../schemes/compressed_tensors_unquantized.py | 1 - .../quantization/compressed_tensors/utils.py | 139 +++++++++++++++--- vllm/model_executor/models/gpt2.py | 27 +++- vllm/model_executor/models/llama.py | 25 +++- vllm/model_executor/models/mixtral.py | 30 +++- vllm/model_executor/models/utils.py | 19 ++- 11 files changed, 300 insertions(+), 90 deletions(-) create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml new file mode 100644 index 0000000000000..3964f3be5e874 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1 +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.758 + - name: "exact_match,flexible-extract" + value: 0.759 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 3d1306f6bc4f1..869fc9cef3778 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml Meta-Llama-3-8B-Instruct-FP8.yaml Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml +Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index bb2be3f3eb56f..a6fa8ffe5111c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -158,6 +158,7 @@ def __init__( topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, + prefix: str = "", ): super().__init__() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 684e1abf7bcf7..86d15207fb6bd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase): skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) """ def __init__(self, @@ -179,15 +181,19 @@ def __init__(self, bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) # All the linear layer supports quant method. assert self.quant_method is not None - self.quant_method.create_weights(self, self.input_size, - [self.output_size], self.input_size, - self.output_size, self.params_dtype) + self.quant_method.create_weights(self, + self.input_size, [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + prefix=prefix) if bias: self.bias = Parameter( @@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase): quant_config: Quantization configure. output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) """ def __init__(self, @@ -249,7 +257,8 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None): + output_sizes: Optional[List[int]] = None, + prefix: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -276,7 +285,8 @@ def __init__(self, input_size=self.input_size, output_size=self.output_size, params_dtype=self.params_dtype, - weight_loader=self.weight_loader) + weight_loader=self.weight_loader, + prefix=prefix) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) """ def __init__(self, @@ -357,7 +369,8 @@ def __init__(self, gather_output: bool = False, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: Optional[str] = None): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) @@ -367,7 +380,8 @@ def __init__(self, gather_output=gather_output, skip_bias_add=skip_bias_add, params_dtype=params_dtype, - quant_config=quant_config) + quant_config=quant_config, + prefix=prefix) def weight_loader(self, param: Parameter, @@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear): skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) """ def __init__(self, @@ -497,7 +513,8 @@ def __init__(self, bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: Optional[str] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -529,7 +546,8 @@ def __init__(self, gather_output=False, skip_bias_add=skip_bias_add, params_dtype=params_dtype, - quant_config=quant_config) + quant_config=quant_config, + prefix=prefix) def weight_loader(self, param: Parameter, @@ -688,7 +706,8 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: Optional[str] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -706,7 +725,8 @@ def __init__(self, input_size=self.input_size, output_size=self.output_size, params_dtype=self.params_dtype, - weight_loader=self.weight_loader) + weight_loader=self.weight_loader, + prefix=prefix) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 659f5a599dc14..28c552b3654f3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -8,23 +8,25 @@ QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, - CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, - CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, - CompressedTensorsWNA16) + CompressedTensorsScheme, CompressedTensorsUnquantized, + CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Int8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, - QuantizationType, find_first_name_or_class_match, - is_activation_quantization_format) + QuantizationType, find_matched_target, is_activation_quantization_format, + should_ignore_layer) from vllm.platforms import current_platform class CompressedTensorsConfig(QuantizationConfig): - def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str], + def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str], quant_format: str): + self.ignore = ignore - self.layer_quant_details = layer_quant_details self.quant_format = quant_format + # Map from [target -> scheme] + self.target_scheme_map = target_scheme_map def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -51,7 +53,7 @@ def get_quant_method( @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": - layer_quant_details: Dict[str, Any] = dict() + target_scheme_map: Dict[str, Any] = dict() ignore: List[str] = config.get("ignore", None) quant_format: str = config.get("format", None) @@ -63,21 +65,21 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": # details follow the structure defined by the QuantizationArgs # pydantic model, which is used to verify the structure of the # quant_config and also store the details for later use. - for key, quant_config in config["config_groups"].items(): + for _, quant_config in config["config_groups"].items(): targets = quant_config.get("targets") for target in targets: - layer_quant_details[target] = {} - layer_quant_details[target][ + target_scheme_map[target] = {} + target_scheme_map[target][ "weights"] = QuantizationArgs.parse_obj( quant_config.get("weights")) try: - layer_quant_details[target][ + target_scheme_map[target][ "input_activations"] = QuantizationArgs.parse_obj( quant_config.get("input_activations")) except Exception: - layer_quant_details[target]["input_activations"] = None + target_scheme_map[target]["input_activations"] = None - return cls(layer_quant_details=layer_quant_details, + return cls(target_scheme_map=target_scheme_map, ignore=ignore, quant_format=quant_format) @@ -167,8 +169,9 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel, return (is_channel_group and input_quant_none and is_symmetric and is_static) - def _get_schema(self, weight_quant: BaseModel, - input_quant: BaseModel) -> "CompressedTensorsScheme": + def _get_scheme_from_parts( + self, weight_quant: BaseModel, + input_quant: BaseModel) -> "CompressedTensorsScheme": # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): @@ -205,26 +208,47 @@ def _get_schema(self, weight_quant: BaseModel, raise NotImplementedError( "No compressed-tensors compatible scheme was found.") - def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": + def get_scheme( + self, + layer: torch.nn.Module, + layer_name: Optional[str] = None) -> "CompressedTensorsScheme": + """ + compressed-tensors supports non uniform in the following way: + + ignore: List of layer_names or nn.Module names to be ignored. + targets of config_groups: There can be N config_groups which each + have a quantization scheme. Each config_group has a list of targets + which can be a full layer_name, a regex for a layer_name, or + an nn.Module name. - layer_type_name = find_first_name_or_class_match( - name="", - module=layer, - targets=self.layer_quant_details.keys(), - check_contains=True) + We first check whether a layer is in the ignore group and use + CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer - if layer_type_name is None: - raise ValueError(f"Could not matching target for layer {layer}") + We then detect whether a layer_name is found in any target and + use the quantization scheme corresponding to the matched target + to select the CompressedTensorsScheme used for infernece. + """ + + # Check if the layer is skipped for quantization. + # TODO (@robertgshaw2): support module names + if should_ignore_layer(layer_name, ignore=self.ignore): + return CompressedTensorsUnquantized() + + # Find the "target" in the compressed-tensors config + # that our layer conforms to. + # TODO (@robertgshaw): add compressed-tensors as dep + # so we do not have to re-write these functions + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=self.target_scheme_map.keys()) - layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( - layer_type_name, None) - if layer_quant_details is None: - raise ValueError( - f"Could not find quantization details for {layer}.") + # Find the quant_scheme + scheme = self.target_scheme_map[matched_target] - scheme = self._get_schema( - weight_quant=layer_quant_details["weights"], - input_quant=layer_quant_details["input_activations"]) + return self._get_scheme_from_parts( + weight_quant=scheme["weights"], + input_quant=scheme["input_activations"]) # Raise error if device does not support the scheme # (e.g. fp8 needs ada lovelace) @@ -250,11 +274,11 @@ def create_weights(self, layer: torch.nn.Module, Use the CompressedTensorsScheme associated with each layer to create the necessary parameters for the layer. See LinearMethodBase for param details - """ weight_loader = extra_weight_attrs.get("weight_loader") + layer_name = extra_weight_attrs.get("prefix") - scheme = self.quantization_config.get_scheme(layer=layer) + scheme = self.quantization_config.get_scheme(layer, layer_name) scheme.create_weights( layer=layer, input_size=input_size, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py index 4350ff4e90ae8..6203f02d25e90 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -33,7 +33,6 @@ def create_weights(self, layer: torch.nn.Module, weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, - device="cuda", dtype=params_dtype), requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 25db308753eee..b3110ce653308 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -86,25 +86,106 @@ def is_activation_quantization_format(format: str) -> bool: return format in _ACTIVATION_QUANTIZATION_FORMATS -def find_first_name_or_class_match( - name: str, - module: Module, - targets: Iterable[str], - check_contains: bool = False) -> Optional[str]: +# fused_name: List[shard_name] +_FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] +} + + +def should_ignore_layer(layer_name: Optional[str], + ignore: Iterable[str]) -> bool: + if layer_name is None: + return False + + # layer_name = model.layers.0.self_attn.qkv_proj + # proj_name = qkv_proj + proj_name = layer_name.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in _FUSED_LAYER_NAME_MAPPING: + shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + + # Layer should be ignored if shards are ignored. + should_ignore_layer = None + for shard_name in shard_names: + should_ignore_shard = check_equal_or_regex_match( + layer_name=shard_name, targets=ignore) + + # If shard_idx=0, set layer ignore to match shard. + if should_ignore_layer is None: + should_ignore_layer = should_ignore_shard + + # If shard_idx=1+ confirm scheme matches prior shards. + elif should_ignore_shard != should_ignore_layer: + raise ValueError(f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + + # Unfused layers like down_proj and o_proj will match + # the safetensors checkpoint already. + else: + should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, + targets=ignore) + + assert should_ignore_layer is not None + return should_ignore_layer + + +def check_equal_or_regex_match(layer_name: str, + targets: Iterable[str]) -> bool: """ - Helper function to map the quantization details listed in the config - for a given list of targets against each model layer. First uses the - layer name to try and find a match. If no name match is found, uses - the layer class name. Returns None otherwise. + Checks whether a layer_name is exactly equal or a regex match for + if target starts with 're:' to any target in list. + """ + for target in targets: + if _is_equal_or_regex_match(layer_name, target): + return True + return False + + +def find_matched_target(layer_name: Optional[str], module: Module, + targets: Iterable[str]) -> str: + """ + Helper function to look up which "target" in the compressed-tensors + config that a layer corresponds to. - :param name: layer name + Recall that a compressed-tensors configs has a concept of + config_groups, where each layer can be quantized with with a different + scheme. + + targets in each config_group will be a list of either layer names + (or regexes corresponding to layer names) or names of torch Modules. + + First, we try to match the layer_name with a target + Second, we try to match the module's name with a target + + :param layer_name: layer name :param module: torch.nn.Module :param targets: list of targets to match the layer against - :param check_contains: whether or not to do a substring match """ - return _find_first_match(name, targets) or _find_first_match( - module.__class__.__name__, targets, check_contains) + if layer_name is None: + layer_name = "" + + matched_target = (_find_first_match(layer_name, targets) + or _find_first_match(module.__class__.__name__, targets, + True)) + + if matched_target is None: + raise ValueError(f"Unable to find matching target for {module} in the " + "compressed-tensors config.") + + return matched_target def _find_first_match(value: str, @@ -121,13 +202,29 @@ def _find_first_match(value: str, """ for target in targets: - if target.startswith("re:"): - pattern = target[3:] - if re.match(pattern, value): - return target - elif check_contains: - if target.lower() in value.lower(): - return target - elif target == value: + if _is_equal_or_regex_match(value, + target, + check_contains=check_contains): return target return None + + +def _is_equal_or_regex_match(value: str, + target: str, + check_contains: bool = False) -> bool: + """ + Checks whether a value is exactly equal or a regex match for target + if target starts with 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + """ + + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return True + elif check_contains: + if target.lower() in value.lower(): + return True + elif target == value: + return True + return False diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index d309a2b27f5dd..94cd67e75336a 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -51,6 +51,7 @@ def __init__( config: GPT2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -68,12 +69,14 @@ def __init__( total_num_heads, bias=True, quant_config=quant_config, + prefix=f"{prefix}.c_attn", ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, quant_config=quant_config, + prefix=f"{prefix}.c_proj", ) self.attn = Attention(self.num_heads, self.head_dim, @@ -101,6 +104,7 @@ def __init__( intermediate_size: int, config: GPT2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size @@ -109,12 +113,14 @@ def __init__( intermediate_size, bias=True, quant_config=quant_config, + prefix=f"{prefix}.c_fc", ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, quant_config=quant_config, + prefix=f"{prefix}.c_proj", ) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -133,6 +139,7 @@ def __init__( config: GPT2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size @@ -140,9 +147,15 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, cache_config, quant_config) + self.attn = GPT2Attention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, config, quant_config) + self.mlp = GPT2MLP(inner_dim, + config, + quant_config, + prefix=f"{prefix}.mlp") def forward( self, @@ -175,6 +188,7 @@ def __init__( config: GPT2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -186,7 +200,9 @@ def __init__( self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda: GPT2Block(config, cache_config, quant_config)) + lambda prefix: GPT2Block( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h") self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -229,7 +245,10 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPT2Model(config, cache_config, quant_config) + self.transformer = GPT2Model(config, + cache_config, + quant_config, + prefix="transformer") self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 08f449f20305a..d052113e79892 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -62,17 +62,20 @@ def __init__( hidden_act: str, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( input_size=hidden_size, output_sizes=[intermediate_size] * 2, bias=bias, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") self.down_proj = RowParallelLinear(input_size=intermediate_size, output_size=hidden_size, bias=bias, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.down_proj") if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -99,6 +102,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, cache_config: Optional[CacheConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -132,12 +136,14 @@ def __init__( total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, bias=bias, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( @@ -176,6 +182,7 @@ def __init__( config: LlamaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -203,6 +210,7 @@ def __init__( quant_config=quant_config, bias=attention_bias, cache_config=cache_config, + prefix=f"{prefix}.self_attn", ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -210,6 +218,7 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -253,6 +262,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -272,9 +282,11 @@ def __init__( self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda: LlamaDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config)) + lambda prefix: LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -370,7 +382,8 @@ def __init__( self.model = LlamaModel(config, cache_config, quant_config, - lora_config=lora_config) + lora_config=lora_config, + prefix="model") if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 28dbcb30bdf55..8fbd537a2c031 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -67,7 +67,8 @@ def __init__(self, intermediate_size: int, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None): + tp_size: Optional[int] = None, + prefix: str = ""): super().__init__() self.hidden_size = hidden_size @@ -76,7 +77,8 @@ def __init__(self, num_experts, bias=False, params_dtype=params_dtype, - quant_config=None) + quant_config=None, + prefix=f"{prefix}.gate") self.experts = FusedMoE(num_experts=num_experts, top_k=top_k, @@ -86,7 +88,8 @@ def __init__(self, reduce_results=True, renormalize=True, quant_config=quant_config, - tp_size=tp_size) + tp_size=tp_size, + prefix=f"{prefix}.experts") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -109,6 +112,7 @@ def __init__( rope_theta: float = 10000, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -139,12 +143,14 @@ def __init__( self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( self.head_dim, @@ -182,6 +188,7 @@ def __init__( config: MixtralConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -194,13 +201,15 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -243,6 +252,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -258,8 +268,11 @@ def __init__( ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, lambda: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config)) + config.num_hidden_layers, + lambda prefix: MixtralDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -331,7 +344,8 @@ def __init__( self.model = MixtralModel(config, cache_config, quant_config, - lora_config=lora_config) + lora_config=lora_config, + prefix="model") self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index b505d32db5985..197d3839a766a 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Tuple +from typing import Dict, List, Protocol, Tuple import torch from torch.func import functional_call @@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor, return inputs_embeds +class LayerFn(Protocol): + + def __call__( + self, + prefix="", + ) -> torch.nn.Module: + ... + + class PPMissingLayer(torch.nn.Identity): """ A placeholder layer for missing layers in a pipeline parallel model. @@ -119,7 +128,9 @@ def forward(*args, **kwargs): def make_layers( - num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] + num_hidden_layers: int, + layer_fn: LayerFn, + prefix: str, ) -> Tuple[int, int, torch.nn.ModuleList]: """Make a list of layers with the given layer function, taking pipeline parallelism into account. @@ -131,8 +142,8 @@ def make_layers( get_pp_group().world_size) modules = torch.nn.ModuleList( [PPMissingLayer() for _ in range(start_layer)] + [ - maybe_offload_to_cpu(layer_fn()) - for _ in range(start_layer, end_layer) + maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) + for idx in range(start_layer, end_layer) ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) return start_layer, end_layer, modules