From b5f3e3233fadbed93cb0183249eed57bc0156d19 Mon Sep 17 00:00:00 2001
From: mzusman
Date: Sun, 3 Nov 2024 19:34:59 +0200
Subject: [PATCH] Add mixer to jamba

Signed-off-by: mzusman
---
 vllm/model_executor/models/jamba.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 715ccc74620c5..6f7949c880e61 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -16,10 +16,7 @@
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
-    causal_conv1d_fn, causal_conv1d_update)
-from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
-    selective_scan_fn, selective_state_update)
+from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -28,7 +25,6 @@
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                      MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                       _get_graph_batch_size)
@@ -108,9 +104,18 @@ def __init__(self,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
-        self.layer_idx = layer_idx
         self.config = config
-        self.mamba = JambaMambaMixer(config)
+        self.mamba = MambaMixer(hidden_size=config.hidden_size,
+                                ssm_state_size=config.mamba_d_state,
+                                conv_kernel_size=config.mamba_d_conv,
+                                intermediate_size=config.mamba_expand *
+                                config.hidden_size,
+                                time_step_rank=config.mamba_dt_rank,
+                                use_conv_bias=config.mamba_conv_bias,
+                                use_bias=config.mamba_proj_bias,
+                                use_rms_norm=True,
+                                rms_norm_eps=config.rms_norm_eps,
+                                activation=config.hidden_act)
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
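
For reference, the new constructor call maps the Hugging Face Jamba config fields onto the shared MambaMixer hyperparameters. The sketch below is not part of the patch; mamba_mixer_kwargs is a hypothetical helper that only restates the mapping shown in the diff, with a comment on what each field controls, assuming a transformers JambaConfig-like object that exposes the mamba_* attributes used above.

    # Hypothetical helper (not in the patch): collects the MambaMixer
    # arguments derived from a JambaConfig-like object.
    def mamba_mixer_kwargs(config):
        return dict(
            hidden_size=config.hidden_size,        # model (residual stream) width
            ssm_state_size=config.mamba_d_state,   # per-channel SSM state size
            conv_kernel_size=config.mamba_d_conv,  # depthwise causal conv width
            # inner width of the mixer = expansion factor * hidden size
            intermediate_size=config.mamba_expand * config.hidden_size,
            time_step_rank=config.mamba_dt_rank,   # rank of the dt (time-step) projection
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,       # bias on the in/out projections
            use_rms_norm=True,                     # Jamba adds RMSNorm inside the mixer
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,          # gating activation, typically "silu"
        )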