[Bugfix] Clean up some cruft in mamba.py #9343

Merged: 4 commits, Oct 15, 2024

Changes from 3 commits
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
@@ -155,7 +155,7 @@ Text Generation
* - :code:`MambaForCausalLM`
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
- ✅︎
-
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
129 changes: 16 additions & 113 deletions vllm/model_executor/models/mamba.py
@@ -1,6 +1,5 @@
# coding=utf-8
"""PyTorch MAMBA model."""
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple

import torch
@@ -10,7 +9,6 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -39,13 +37,6 @@
KVCache = Tuple[torch.Tensor, torch.Tensor]


@dataclass
class MambaCacheParams:
is_prompt: bool = False
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class MambaMixer(nn.Module):
"""
Expand Down Expand Up @@ -209,37 +200,6 @@ def forward(self, hidden_states: torch.Tensor,
return contextualized_states


class MambaMLP(nn.Module):

def __init__(
self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
hidden_act = config.hidden_act
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x


class MambaDecoderLayer(nn.Module):

def __init__(self,
@@ -252,7 +212,6 @@ def __init__(self,
self.config = config
self.mixer = MambaMixer(config, layer_idx)

self.feed_forward = MambaMLP(config, quant_config=quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
@@ -261,24 +220,16 @@ def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)

residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
ssm_state)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
hidden_states = hidden_states + residual
return hidden_states


class MambaModel(nn.Module):
@@ -319,7 +270,6 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
@@ -332,40 +282,20 @@
current_ssm_state = ssm_state[i]
current_conv_state = conv_state[i]

hidden_states, residual = layer(
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
Member:
Residual is still passed in here. I personally think it is fine to keep the previous structure of passing residual into rmsnorm, but up to you

Collaborator Author:
I saw modeling_mamba.py do it this way and like that it's super clean, but you're right -- going to revert to the previous way so that we fuse. Sidenote: this is the kind of optimization we should be relying on torch.compile to do for us IMO
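For anyone following along, here is a minimal, self-contained sketch of the two structures under discussion. `FusedAddRMSNorm` is an illustrative stand-in written for this example, not vLLM's actual `RMSNorm` kernel; only the shape of the two call patterns matters here:

```python
from typing import Optional

import torch
import torch.nn as nn


class FusedAddRMSNorm(nn.Module):
    """Illustrative RMSNorm that can fold the residual add into the norm call."""

    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is not None:
            # Fused structure: add the incoming residual before normalizing and
            # return the pre-norm sum so the next layer can reuse it.
            x = x + residual
            residual = x
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps) * self.weight
        return x if residual is None else (x, residual)


norm = FusedAddRMSNorm(hidden_size=16)
hidden_states = torch.randn(2, 16)

# Unfused structure (this revision of the PR): save the residual, normalize,
# and add it back as a separate elementwise op after the mixer.
residual = hidden_states
out = norm(hidden_states)
out = out + residual

# Fused structure (what the review suggests keeping): the norm call performs
# the residual add itself and hands back the updated residual for the next
# layer, so the separate "+ residual" op disappears.
out_fused, new_residual = norm(hidden_states, residual=residual)
```

In the fused form the elementwise add rides along with the normalization kernel, which is the fusion referred to above; the unfused form trades that for a simpler layer signature.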

conv_state=current_conv_state,
ssm_state=current_ssm_state,
)
hidden_states, _ = self.norm_f(hidden_states, residual)
hidden_states = self.norm_f(hidden_states)

return hidden_states


class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embeddings": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]

def __init__(
self,
@@ -416,8 +346,8 @@ def forward(self,
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)

hidden_states = self.backbone(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_tensors[0],
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_tensors[0],
mamba_cache_tensors[1])

return hidden_states
@@ -457,43 +387,16 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue

if "A_log" in name:
name = name.replace("A_log", "A")

if ".self_attn." in name:
name = name.replace(".self_attn", "")

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)