[Bugfix] Fixup Mamba
Signed-off-by: Tyler Michael Smith <[email protected]>
tlrmchlsmth committed Nov 4, 2024
1 parent 04cef2c commit 5170b48
Showing 1 changed file with 3 additions and 4 deletions.
vllm/model_executor/models/mamba.py (7 changes: 3 additions & 4 deletions)
@@ -39,16 +39,16 @@ def __init__(self,
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
-        mixer_rms_rps = config.mixer_rms_rps if self.is_falcon_mamba else None
-        self.mamba = MambaMixer(hidden_size=config.hidden_size,
+        mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
+        self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
                                 conv_kernel_size=config.conv_kernel,
                                 intermediate_size=config.intermediate_size,
                                 time_step_rank=config.time_step_rank,
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.use_bias,
                                 use_rms_norm=self.is_falcon_mamba,
-                                rms_norm_eps=mixer_rms_rps,
+                                rms_norm_eps=mixer_rms_eps,
                                 activation=config.hidden_act)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
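
Why the first hunk matters, as an illustrative sketch that is not part of the commit: Hugging Face's FalconMambaConfig exposes `mixer_rms_eps`, not `mixer_rms_rps`, so the old attribute lookup raised `AttributeError` whenever a falcon_mamba model was constructed. A minimal reproduction, using a `SimpleNamespace` as a stand-in for the real config object:

```python
from types import SimpleNamespace

# Stand-in config; `mixer_rms_eps` mirrors the attribute defined by
# Hugging Face's FalconMambaConfig.
config = SimpleNamespace(model_type="falcon_mamba", mixer_rms_eps=1e-6)
is_falcon_mamba = config.model_type == "falcon_mamba"

# Old code: `mixer_rms_rps` does not exist on the config, so this raises.
try:
    eps = config.mixer_rms_rps if is_falcon_mamba else None
except AttributeError:
    print("old code fails: no attribute 'mixer_rms_rps'")

# Fixed code: reads the attribute that actually exists.
eps = config.mixer_rms_eps if is_falcon_mamba else None
print(eps)  # 1e-6
```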
@@ -99,7 +99,6 @@ def __init__(
         for i in range(config.num_hidden_layers):
             decoder_layers.append(
                 MambaDecoderLayer(config,
-                                  layer_idx=i,
                                   cache_config=cache_config,
                                   quant_config=quant_config))
         self.layers = nn.ModuleList(decoder_layers)
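
The `self.mamba` to `self.mixer` rename also matters for weight loading: vLLM resolves checkpoint keys against the module tree via `named_parameters()`, and Mamba checkpoints name the block `mixer` (e.g. `backbone.layers.0.mixer.*` in the Hugging Face export), so the attribute name has to match. A hypothetical illustration of that constraint; `TinyMixer` and `TinyDecoderLayer` are made-up names, not vLLM classes:

```python
import torch.nn as nn

class TinyMixer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.in_proj = nn.Linear(hidden_size, hidden_size)

class TinyDecoderLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        # The attribute name becomes part of the parameter name, so it must
        # match the checkpoint prefix ("...mixer.in_proj.weight"). This is
        # why the submodule is `self.mixer`, not `self.mamba`.
        self.mixer = TinyMixer(hidden_size)

layer = TinyDecoderLayer(8)
names = {name for name, _ in layer.named_parameters()}
assert "mixer.in_proj.weight" in names   # lines up with checkpoint keys
assert "mamba.in_proj.weight" not in names
```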
