diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 57e9772b0b600..b200d3ebf9c29 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -7,7 +7,7 @@
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
@@ -58,8 +58,8 @@ def __init__(self,
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

-        self.in_proj = ColumnParallelLinear(hidden_size,
-                                            intermediate_size * 2,
+        self.in_proj = MergedColumnParallelLinear(hidden_size,
+                                                  [intermediate_size] * 2,
                                             bias=use_bias)
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
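
For context (not part of the patch): the in_proj weight is one fused matrix of shape [2 * intermediate_size, hidden_size], whose first half feeds the SSM branch and whose second half feeds the gate that the mixer later splits off from the projected states. Declaring the layer as MergedColumnParallelLinear with output sizes [intermediate_size] * 2 treats those two halves as separate shards, so under tensor parallelism each rank holds matching slices of both halves instead of one contiguous slice of the fused matrix. The sketch below is a plain-PyTorch illustration of that difference under these assumptions; the helper names (`naive_shard`, `merged_shard`) and toy sizes are illustrative, not vLLM APIs.

```python
import torch

intermediate_size, tp_size = 4, 2

# Track row identity instead of real values: row i of the fused in_proj
# weight is labelled i. Rows [0, I) feed the SSM branch, rows [I, 2I) the gate.
row_ids = torch.arange(2 * intermediate_size)


def naive_shard(rows: torch.Tensor, rank: int) -> torch.Tensor:
    # Contiguous slice of the fused output dimension, as a plain
    # column-parallel split of `intermediate_size * 2` would produce.
    shard = rows.shape[0] // tp_size
    return rows[rank * shard:(rank + 1) * shard]


def merged_shard(rows: torch.Tensor, rank: int) -> torch.Tensor:
    # Slice each sub-projection separately and concatenate the slices,
    # which is the effect of declaring [intermediate_size] * 2 output sizes.
    ssm, gate = rows.chunk(2, dim=0)
    shard = intermediate_size // tp_size
    return torch.cat([ssm[rank * shard:(rank + 1) * shard],
                      gate[rank * shard:(rank + 1) * shard]])


for rank in range(tp_size):
    print(f"rank {rank}: naive={naive_shard(row_ids, rank).tolist()} "
          f"merged={merged_shard(row_ids, rank).tolist()}")
# rank 0: naive=[0, 1, 2, 3] merged=[0, 1, 4, 5]
# rank 1: naive=[4, 5, 6, 7] merged=[2, 3, 6, 7]
```

With the naive split, rank 0 ends up holding only SSM rows and rank 1 only gate rows, so a per-rank chunk-in-two of the projected states no longer lines up with the SSM/gate halves; the merged split keeps both halves represented on every rank.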