Commit

To merged column tp
mzusman committed Oct 30, 2024
1 parent 11141de commit aa6bbcc
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -7,7 +7,7 @@
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
@@ -58,8 +58,8 @@ def __init__(self,
         # doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
 
-        self.in_proj = ColumnParallelLinear(hidden_size,
-                                            intermediate_size * 2,
+        self.in_proj = MergedColumnParallelLinear(hidden_size,
+                                                  [intermediate_size] * 2,
                                             bias=use_bias)
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
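
Background on the change: the Mamba in_proj weight in the checkpoint is one stacked matrix whose output the mixer later splits into the SSM input and the gate. A plain ColumnParallelLinear of width intermediate_size * 2 shards that stacked output as a single contiguous block per tensor-parallel rank, whereas MergedColumnParallelLinear with [intermediate_size] * 2 shards each sub-projection separately, so every rank holds matching x and gate slices. The NumPy toy below is a minimal sketch of that difference; the sizes, fake weights, and variable names are illustrative only and none of it is vLLM code.

    import numpy as np

    hidden_size, intermediate_size, tp_size = 4, 6, 2

    # Checkpoint-style stacked in_proj weight: [x_proj; gate_proj],
    # shape (2 * intermediate_size, hidden_size). Values are made up.
    w_x = np.arange(intermediate_size * hidden_size,
                    dtype=float).reshape(intermediate_size, hidden_size)
    w_gate = -w_x
    w_stacked = np.concatenate([w_x, w_gate], axis=0)

    # Plain column-parallel sharding: each rank takes one contiguous block
    # of the 2 * intermediate_size output rows.
    naive_shards = np.split(w_stacked, tp_size, axis=0)

    # Merged column-parallel sharding: shard w_x and w_gate independently and
    # stack the per-rank pieces, so every rank holds [x shard; gate shard].
    merged_shards = [
        np.concatenate([np.split(w_x, tp_size, axis=0)[r],
                        np.split(w_gate, tp_size, axis=0)[r]], axis=0)
        for r in range(tp_size)
    ]

    hidden = np.ones(hidden_size)
    for r in range(tp_size):
        # The mixer chunks its local in_proj output in half to get x and gate.
        naive_x, naive_gate = np.split(naive_shards[r] @ hidden, 2)
        merged_x, merged_gate = np.split(merged_shards[r] @ hidden, 2)
        # Naive layout: rank 0 holds only x rows and rank 1 only gate rows, so
        # the local halves are not an (x, gate) pair. Merged layout: they are.
        print(f"rank {r} naive gate: {naive_gate}")
        print(f"rank {r} merged gate: {merged_gate}")

As far as I understand, passing the per-projection sizes as a list also lets the layer load the stacked checkpoint tensor shard by shard, which is what makes the [intermediate_size] * 2 form tensor-parallel safe.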

