From cd7a46869362630faa6505079e208b88d668806c Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 5 Nov 2024 13:42:20 -0500 Subject: [PATCH] [Model] Support quantization of PixtralHFTransformer for PixtralHF (#9921) Signed-off-by: mgoin Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> --- vllm/model_executor/layers/activation.py | 30 +++++++ vllm/model_executor/models/pixtral.py | 100 ++++++++++++++--------- 2 files changed, 90 insertions(+), 40 deletions(-) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 658a3700f33d6..e347ca80ff765 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -299,3 +299,33 @@ def get_act_fn( return ScaledActivation(act_fn, intermediate_size, input_is_parallel, params_dtype) return act_fn + + +_ACTIVATION_AND_MUL_REGISTRY = LazyDict({ + "gelu": lambda: GeluAndMul(), + "silu": lambda: SiluAndMul(), +}) + + +def get_act_and_mul_fn( + act_fn_name: str, + quant_config: Optional[QuantizationConfig] = None, + intermediate_size: Optional[int] = None, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """Get an activation-and-mul (i.e. SiluAndMul) function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY: + raise ValueError( + f"Activation function {act_fn_name!r} is not supported.") + + act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name] + if (quant_config is not None + and act_fn_name in quant_config.get_scaled_act_names()): + if intermediate_size is None: + raise ValueError("intermediate_size must be specified for scaled " + "activation functions.") + return ScaledActivation(act_fn, intermediate_size, input_is_parallel, + params_dtype) + return act_fn diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 051454c49bff8..ee9f150b17cfc 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -19,8 +19,11 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) -from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -798,20 +801,24 @@ def __init__( super().__init__() assert config.intermediate_size is not None - # TODO: Use quant_config and prefix after optimizing this - self.gate_proj = nn.Linear(config.hidden_size, - config.intermediate_size, - bias=False) - self.up_proj = nn.Linear(config.hidden_size, - config.intermediate_size, - bias=False) - self.down_proj = nn.Linear(config.intermediate_size, - config.hidden_size, - bias=False) - self.act = get_act_fn(config.hidden_act) + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.act_and_mul = get_act_and_mul_fn(config.hidden_act) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) + gate_up, _ = self.gate_up_proj(x) + x = self.act_and_mul(gate_up) + x, _ = self.down_proj(x) + return x class PixtralHFAttention(nn.Module): @@ -830,21 +837,21 @@ def __init__( self.n_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads - self.scale = self.head_dim**-0.5 - - # TODO: Use quant_config and prefix after optimizing this - self.q_proj = nn.Linear(config.hidden_size, - config.hidden_size, - bias=False) - self.k_proj = nn.Linear(config.hidden_size, - config.hidden_size, - bias=False) - self.v_proj = nn.Linear(config.hidden_size, - config.hidden_size, - bias=False) - self.o_proj = nn.Linear(config.hidden_size, - config.hidden_size, - bias=False) + self.qkv_proj = QKVParallelLinear( + hidden_size=config.hidden_size, + head_size=self.head_dim, + total_num_heads=self.n_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=config.hidden_size, + output_size=config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) def forward( self, @@ -854,13 +861,13 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: batch, patches, _ = hidden_states.size() - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) + qkv_states, _ = self.qkv_proj(hidden_states) + q, k, v = qkv_states.chunk(3, dim=-1) # Transpose q and k to apply HF's Rotary Position Embedding q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) + v = v.view(batch, patches, self.n_heads, self.head_dim) cos, sin = position_embeddings q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0) @@ -868,22 +875,21 @@ def forward( # Transpose q and k back for attention q = q.transpose(1, 2).contiguous() k = k.transpose(1, 2).contiguous() - v = v.reshape(batch, patches, self.n_heads, self.head_dim) out = xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask) else: - v = v.reshape(batch, patches, self.n_heads, - self.head_dim).transpose(1, 2) + v = v.transpose(1, 2) out = nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=attention_mask) out = out.transpose(1, 2) - out = out.reshape(batch, patches, self.n_heads * self.head_dim) + out = out.view(batch, patches, self.n_heads * self.head_dim) + attn_output, _ = self.o_proj(out) - return self.o_proj(out) + return attn_output, None class PixtralHFTransformerBlock(nn.Module): @@ -912,9 +918,9 @@ def forward( attention_mask: torch.Tensor, position_embeddings: torch.Tensor, ) -> torch.Tensor: - r = self.attention.forward(self.attention_norm(hidden_states), - attention_mask=attention_mask, - position_embeddings=position_embeddings) + r, _ = self.attention.forward(self.attention_norm(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings) h = hidden_states + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r @@ -1053,10 +1059,24 @@ def forward( # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [] + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] params_dict = dict(self.named_parameters()) + layer_count = len(self.transformer.layers) for name, loaded_weight in weights: + # omit layers when num_hidden_layers_override is set + if name.startswith("transformer.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue