[Model] Support quantization of PixtralHFTransformer for PixtralHF #9921

Merged
30 changes: 30 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -299,3 +299,33 @@ def get_act_fn(
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn


_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
})


def get_act_and_mul_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")

act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
if (quant_config is not None
and act_fn_name in quant_config.get_scaled_act_names()):
if intermediate_size is None:
raise ValueError("intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn
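
Illustrative sketch (not part of the diff): `get_act_and_mul_fn` returns a fused gate-activation module such as `SiluAndMul` or `GeluAndMul`, which expects the gate and up projections concatenated along the last dimension and returns `act(gate) * up`, halving that dimension. A minimal usage sketch, assuming a build that includes this change (shapes are illustrative):

```python
import torch

from vllm.model_executor.layers.activation import get_act_and_mul_fn

# Fused activation: input is [..., 2 * intermediate_size] (gate and up
# concatenated), output is [..., intermediate_size] = act(gate) * up.
act_and_mul = get_act_and_mul_fn("silu")
gate_up = torch.randn(2, 16, 2 * 4096)
out = act_and_mul(gate_up)
assert out.shape == (2, 16, 4096)
```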
100 changes: 60 additions & 40 deletions vllm/model_executor/models/pixtral.py
@@ -19,8 +19,11 @@
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
-from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -798,20 +801,24 @@ def __init__(
super().__init__()

assert config.intermediate_size is not None
-        # TODO: Use quant_config and prefix after optimizing this
-        self.gate_proj = nn.Linear(config.hidden_size,
-                                   config.intermediate_size,
-                                   bias=False)
-        self.up_proj = nn.Linear(config.hidden_size,
-                                 config.intermediate_size,
-                                 bias=False)
-        self.down_proj = nn.Linear(config.intermediate_size,
-                                   config.hidden_size,
-                                   bias=False)
-        self.act = get_act_fn(config.hidden_act)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_sizes=[config.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
+        self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
+                                           output_size=config.hidden_size,
+                                           bias=False,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj")
+        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_and_mul(gate_up)
+        x, _ = self.down_proj(x)
+        return x
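
For reference, a small equivalence sketch (illustrative, plain PyTorch, not part of the diff) showing why the merged `gate_up_proj` plus a fused act-and-mul reproduces the old `gate_proj`/`up_proj`/`act` path; `silu` is used as the example activation:

```python
import torch
import torch.nn.functional as F

hidden, inter = 8, 16
x = torch.randn(4, hidden)
w_gate = torch.randn(inter, hidden)
w_up = torch.randn(inter, hidden)

# Old path: two separate projections, then act(gate) * up.
ref = F.silu(x @ w_gate.t()) * (x @ w_up.t())

# New path: one merged projection (gate and up stacked on the output dim),
# then split and apply the fused activation, as SiluAndMul does.
w_gate_up = torch.cat([w_gate, w_up], dim=0)   # [2 * inter, hidden]
gate_up = x @ w_gate_up.t()                    # [4, 2 * inter]
gate, up = gate_up.chunk(2, dim=-1)
out = F.silu(gate) * up

assert torch.allclose(ref, out, atol=1e-5)
```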


class PixtralHFAttention(nn.Module):
@@ -830,21 +837,21 @@ def __init__(
self.n_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads

-        self.scale = self.head_dim**-0.5
-
-        # TODO: Use quant_config and prefix after optimizing this
-        self.q_proj = nn.Linear(config.hidden_size,
-                                config.hidden_size,
-                                bias=False)
-        self.k_proj = nn.Linear(config.hidden_size,
-                                config.hidden_size,
-                                bias=False)
-        self.v_proj = nn.Linear(config.hidden_size,
-                                config.hidden_size,
-                                bias=False)
-        self.o_proj = nn.Linear(config.hidden_size,
-                                config.hidden_size,
-                                bias=False)
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=config.hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.n_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            input_size=config.hidden_size,
+            output_size=config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )

def forward(
self,
@@ -854,36 +861,35 @@
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, patches, _ = hidden_states.size()

-        q = self.q_proj(hidden_states)
-        k = self.k_proj(hidden_states)
-        v = self.v_proj(hidden_states)
+        qkv_states, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv_states.chunk(3, dim=-1)

        # Transpose q and k to apply HF's Rotary Position Embedding
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
+        v = v.view(batch, patches, self.n_heads, self.head_dim)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back for attention
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()
-            v = v.reshape(batch, patches, self.n_heads, self.head_dim)

            out = xops.memory_efficient_attention(q,
                                                  k,
                                                  v,
                                                  attn_bias=attention_mask)
        else:
-            v = v.reshape(batch, patches, self.n_heads,
-                          self.head_dim).transpose(1, 2)
+            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask)
            out = out.transpose(1, 2)

-        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
+        out = out.view(batch, patches, self.n_heads * self.head_dim)
+        attn_output, _ = self.o_proj(out)

-        return self.o_proj(out)
+        return attn_output, None
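
A similar equivalence sketch for the attention change (illustrative, plain PyTorch, not part of the diff, ignoring tensor parallelism): stacking the q/k/v weights into one projection and splitting the output with `chunk(3)` matches the three separate projections, which is what the stacked-parameter mapping in `load_weights` below relies on:

```python
import torch

hidden = 64
x = torch.randn(2, 5, hidden)
w_q = torch.randn(hidden, hidden)
w_k = torch.randn(hidden, hidden)
w_v = torch.randn(hidden, hidden)

# Fused projection: weights stacked along the output dimension.
w_qkv = torch.cat([w_q, w_k, w_v], dim=0)   # [3 * hidden, hidden]
qkv = x @ w_qkv.t()
q, k, v = qkv.chunk(3, dim=-1)

assert torch.allclose(q, x @ w_q.t(), atol=1e-5)
assert torch.allclose(k, x @ w_k.t(), atol=1e-5)
assert torch.allclose(v, x @ w_v.t(), atol=1e-5)
```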


class PixtralHFTransformerBlock(nn.Module):
@@ -912,9 +918,9 @@ def forward(
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
-        r = self.attention.forward(self.attention_norm(hidden_states),
-                                   attention_mask=attention_mask,
-                                   position_embeddings=position_embeddings)
+        r, _ = self.attention.forward(self.attention_norm(hidden_states),
+                                      attention_mask=attention_mask,
+                                      position_embeddings=position_embeddings)
h = hidden_states + r
r = self.feed_forward.forward(self.ffn_norm(h))
out = h + r
@@ -1053,10 +1059,24 @@
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = []
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
        params_dict = dict(self.named_parameters())
+        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
+            # omit layers when num_hidden_layers_override is set
+            if name.startswith("transformer.layers"):
+                layer_idx = int(name.split(".")[2])
+                if layer_idx >= layer_count:
+                    continue
+
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
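
To make the stacked-parameter mapping concrete, here is a name-remapping sketch (illustrative, not part of the diff; `remap` and the example weight name are hypothetical, following the usual vLLM `load_weights` pattern): each per-shard checkpoint name is rewritten to the fused parameter created by `QKVParallelLinear`/`MergedColumnParallelLinear`, and the shard id tells the weight loader which slice of the fused weight to fill.

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

def remap(checkpoint_name: str):
    """Return (model_param_name, shard_id) for a checkpoint weight name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_name:
            return checkpoint_name.replace(weight_name, param_name), shard_id
    return checkpoint_name, None

# Hypothetical vision-tower weight name from a PixtralHF checkpoint:
print(remap("transformer.layers.0.attention.k_proj.weight"))
# -> ('transformer.layers.0.attention.qkv_proj.weight', 'k')
```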