Skip to content

Commit

Permalink
[Bugfix] Clean up some cruft in mamba.py (#9343)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlrmchlsmth authored Oct 15, 2024
1 parent f0fe4fe commit 169b530
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 104 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ Text Generation
* - :code:`MambaForCausalLM`
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
- ✅︎
-
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
Expand Down
113 changes: 10 additions & 103 deletions vllm/model_executor/models/mamba.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# coding=utf-8
"""PyTorch MAMBA model."""
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple

import torch
Expand All @@ -10,7 +9,6 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
Expand Down Expand Up @@ -39,13 +37,6 @@
KVCache = Tuple[torch.Tensor, torch.Tensor]


@dataclass
class MambaCacheParams:
is_prompt: bool = False
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class MambaMixer(nn.Module):
"""
Expand Down Expand Up @@ -209,37 +200,6 @@ def forward(self, hidden_states: torch.Tensor,
return contextualized_states


class MambaMLP(nn.Module):

def __init__(
self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
hidden_act = config.hidden_act
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x


class MambaDecoderLayer(nn.Module):

def __init__(self,
Expand All @@ -252,7 +212,6 @@ def __init__(self,
self.config = config
self.mixer = MambaMixer(config, layer_idx)

self.feed_forward = MambaMLP(config, quant_config=quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
Expand All @@ -274,10 +233,6 @@ def forward(

hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
ssm_state)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual


Expand Down Expand Up @@ -319,7 +274,6 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
Expand All @@ -346,26 +300,6 @@ def forward(


class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embeddings": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]

def __init__(
self,
Expand Down Expand Up @@ -416,8 +350,8 @@ def forward(self,
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)

hidden_states = self.backbone(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_tensors[0],
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_tensors[0],
mamba_cache_tensors[1])

return hidden_states
Expand Down Expand Up @@ -457,43 +391,16 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue

if "A_log" in name:
name = name.replace("A_log", "A")

if ".self_attn." in name:
name = name.replace(".self_attn", "")

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

0 comments on commit 169b530

Please sign in to comment.