From 90d0ff3fe7581e2c0266d4062a2f3c2dfc67b469 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 16:17:36 +0200 Subject: [PATCH 01/29] Add PP support to Jamba Signed-off-by: mzusman --- vllm/model_executor/models/jamba.py | 91 ++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5d5e8ae1ee532..098b31cb7cef6 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -9,6 +9,7 @@ from vllm.attention.layer import Attention from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -26,8 +27,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import HasInnerState, SupportsLoRA -from .utils import maybe_prefix +from .interfaces import HasInnerState, SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -281,16 +284,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) - decoder_layers = [] - for i in range(config.num_hidden_layers): - layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]] - decoder_layers.append( - layer_class(config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{i}")) - self.layers = nn.ModuleList(decoder_layers) + def get_layer(prefix: str): + layer_idx = int(prefix.split(".")[-1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layers_block_type[layer_idx]] + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -304,26 +315,34 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + kv_cache_index = 0 + mamba_cache_index = 0 + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] kv_cache = None layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): - kv_cache = kv_caches[(i - self.config.attn_layer_offset) // - 
self.config.attn_layer_period] + kv_cache = kv_caches[kv_cache_index] + kv_cache_index += 1 if isinstance(layer, JambaMambaDecoderLayer): - current_state_layer = i - (1 + - (i - self.config.attn_layer_offset) - // self.config.attn_layer_period) + current_state_layer = mamba_cache_index layer_mamba_cache_params = mamba_cache_params.at_layer_idx( current_state_layer) + mamba_cache_index += 1 hidden_states, residual = layer( positions=positions, @@ -332,11 +351,16 @@ def forward( attn_metadata=attn_metadata, residual=residual, mamba_cache_params=layer_mamba_cache_params) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states -class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): +class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -368,6 +392,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config self.scheduler_config = scheduler_config self.model = JambaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) @@ -390,6 +416,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size) self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -406,10 +435,8 @@ def forward(self, self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) - layers_type = self.config.layers_block_type - num_mamba_layers = sum( - [layer_type == "mamba" for layer_type in layers_type]) - + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, "mamba") self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, *self._get_mamba_cache_shape()) @@ -423,7 +450,7 @@ def forward(self, state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, mamba_cache_params, - inputs_embeds) + intermediate_tensors, inputs_embeds) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): @@ -504,8 +531,12 @@ def load_weights(self, weights: Iterable[Tuple[str, continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -520,6 +551,8 @@ def load_weights(self, weights: Iterable[Tuple[str, if weight_name not in name: continue + if is_pp_missing_parameter(name, self): + continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader @@ -533,6 +566,8 @@ def load_weights(self, weights: Iterable[Tuple[str, # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", From d30f741b01a30fdfb585944c5f09df1d65af9ab3 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 16:17:48 +0200 Subject: [PATCH 02/29] GPU Runner adaptations Signed-off-by: mzusman --- vllm/v1/worker/gpu_model_runner.py | 2 +- vllm/v1/worker/gpu_worker.py | 2 +- vllm/worker/cache_engine.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e8d964a722f60..8b407f183de3f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -69,7 +69,7 @@ def __init__( self.max_num_tokens = scheduler_config.max_num_batched_tokens # Model-related. - self.num_attn_layers = model_config.get_num_attention_layers( + self.num_attn_layers = model_config.get_num_layers_by_block_type( parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.head_size = model_config.get_head_size() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d33b55a8a9f9a..77bfa50274514 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -253,7 +253,7 @@ def _get_cache_block_size( ) -> int: head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) - num_attention_layers = model_config.get_num_attention_layers( + num_attention_layers = model_config.get_num_layers_by_block_type( parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index ac3270d1c9909..6c58ade83c8af 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -34,7 +34,7 @@ def __init__( self.head_size = model_config.get_head_size() # Models like Jamba, have mixed typed layers, E.g Mamba - self.num_attention_layers = model_config.get_num_attention_layers( + self.num_attention_layers = model_config.get_num_layers_by_block_type( parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) @@ -105,7 +105,7 @@ def get_cache_block_size( ) -> int: head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) - num_attention_layers = model_config.get_num_attention_layers( + num_attention_layers = model_config.get_num_layers_by_block_type( parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size From 47e961b3756de07c11ca4dab77493d3891f2435a Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 16:18:14 +0200 Subject: [PATCH 03/29] Add tests and adaptations Signed-off-by: mzusman --- tests/distributed/test_pipeline_parallel.py | 1 + vllm/config.py | 24 ++++++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index b818ca921fcb0..ca7d0a106afdd 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -181,6 +181,7 @@ def iter_params(self, model_name: str): "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), "bigcode/starcoder2-3b": PPTestSettings.fast(), "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2), + "ai21labs/Jamba-tiny-dev": PPTestSettings.fast() # FIXME: Cannot load tokenizer in latest transformers version. 
# Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf` # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True), diff --git a/vllm/config.py b/vllm/config.py index 164622b5af34e..4adefab7d30d9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -680,26 +680,30 @@ def get_num_attention_heads(self, num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) return num_heads // parallel_config.tensor_parallel_size - def get_num_layers(self, parallel_config: "ParallelConfig") -> int: + def get_layers_start_end_indices( + self, parallel_config: "ParallelConfig") -> Tuple[int, int]: from vllm.distributed.utils import get_pp_indices total_num_hidden_layers = getattr(self.hf_text_config, "num_hidden_layers", 0) pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size pp_size = parallel_config.pipeline_parallel_size start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) - return end - start + return start, end - def get_num_attention_layers(self, - parallel_config: "ParallelConfig") -> int: - if self.is_attention_free: - return 0 + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: + start, end = self.get_layers_start_end_indices(parallel_config) + return end - start - num_layers = self.get_num_layers(parallel_config) + def get_num_layers_by_block_type(self, + parallel_config: "ParallelConfig", + block_type="attention") -> int: + start, end = self.get_layers_start_end_indices(parallel_config) # Transformers supports layers_block_type @property - layers = getattr(self.hf_config, "layers_block_type", - ["attention"] * num_layers) - return len([t for t in layers if t == "attention"]) + layers_block_type = getattr(self.hf_config, "layers_block_type", + ["attention"] * (end - start)) + return len( + [t for t in layers_block_type[start:end] if t == block_type]) def get_multimodal_config(self) -> "MultiModalConfig": """ From e7e3ea484cef7654e875dd646686c73f074ee78e Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 17:17:03 +0200 Subject: [PATCH 04/29] Add PP to mamba Signed-off-by: mzusman --- tests/distributed/test_pipeline_parallel.py | 3 +- vllm/model_executor/models/mamba.py | 55 ++++++++++++++------- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index ca7d0a106afdd..1317a6f6ae41e 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -181,7 +181,8 @@ def iter_params(self, model_name: str): "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), "bigcode/starcoder2-3b": PPTestSettings.fast(), "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2), - "ai21labs/Jamba-tiny-dev": PPTestSettings.fast() + "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), + "state-spaces/mamba-130m-hf": PPTestSettings.fast() # FIXME: Cannot load tokenizer in latest transformers version. 
# Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf` # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True), diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index b32032e411b0a..d5bda01638421 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -8,6 +8,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer @@ -18,13 +19,14 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree) + IsAttentionFree, SupportsPP) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .utils import maybe_prefix +from .utils import (make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -94,15 +96,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) - decoder_layers = [] - for i in range(config.num_hidden_layers): - decoder_layers.append( - MambaDecoderLayer(config, - cache_config=cache_config, - quant_config=quant_config)) - self.layers = nn.ModuleList(decoder_layers) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MambaDecoderLayer( + config, cache_config=cache_config, quant_config=quant_config), + prefix=f"{prefix}.layers") + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings(input_ids) @@ -113,29 +117,40 @@ def forward( positions: torch.Tensor, attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - if inputs_embeds is not None: - hidden_states = inputs_embeds + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] - for i in range(len(self.layers)): + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, attn_metadata=attn_metadata, residual=residual, - mamba_cache_params=mamba_cache_params.at_layer_idx(i)) + mamba_cache_params=mamba_cache_params.at_layer_idx( + i - self.start_layer)) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + 
"hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states -class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): +class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -173,6 +188,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size) self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.backbone.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) @@ -203,7 +221,8 @@ def forward(self, state_indices_tensor) hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_params, inputs_embeds) + mamba_cache_params, intermediate_tensors, + inputs_embeds) return hidden_states From ebbf7e339b0b77dc700abec0211b1a915aeb2f5f Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 17:31:12 +0200 Subject: [PATCH 05/29] Introduce LayerBlockType Signed-off-by: mzusman --- vllm/config.py | 23 +++++++++++++++++------ vllm/utils.py | 3 +++ vllm/v1/worker/gpu_model_runner.py | 4 ++-- vllm/v1/worker/gpu_worker.py | 4 ++-- vllm/worker/cache_engine.py | 6 +++--- 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4adefab7d30d9..4f05c54f1eb3b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -27,9 +27,14 @@ ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) +<<<<<<< HEAD from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, print_warning_once, random_uuid, resolve_obj_by_qualname) +======= +from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, get_cpu_memory, + print_warning_once, resolve_obj_by_qualname) +>>>>>>> 121764ea9 (Introduce LayerBlockType) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -694,16 +699,22 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int: start, end = self.get_layers_start_end_indices(parallel_config) return end - start - def get_num_layers_by_block_type(self, - parallel_config: "ParallelConfig", - block_type="attention") -> int: + def get_num_layers_by_block_type( + self, + parallel_config: "ParallelConfig", + block_type: LayerBlockType = LayerBlockType.attention, + ) -> int: start, end = self.get_layers_start_end_indices(parallel_config) # Transformers supports layers_block_type @property - layers_block_type = getattr(self.hf_config, "layers_block_type", - ["attention"] * (end - start)) + layers_block_type_value = getattr(self.hf_config, "layers_block_type", + [block_type.value] * (end - start)) return len( - [t for t in layers_block_type[start:end] if t == block_type]) + [ + t + for t in layers_block_type_value[start:end] + if t == block_type.value + ]) def get_multimodal_config(self) -> "MultiModalConfig": """ diff --git a/vllm/utils.py b/vllm/utils.py index 1f19d9eacd16d..822056ae92945 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -168,6 +168,9 @@ class Device(enum.Enum): GPU = enum.auto() CPU = enum.auto() +class LayerBlockType(enum.Enum): + attention = "attention" + mamba = "mamba" class Counter: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 
8b407f183de3f..2fec505050832 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -16,7 +16,7 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) @@ -70,7 +70,7 @@ def __init__( # Model-related. self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config) + parallel_config, LayerBlockType.attention) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 77bfa50274514..c3362841f47ce 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,7 +14,7 @@ from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -254,7 +254,7 @@ def _get_cache_block_size( head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config) + parallel_config, LayerBlockType.attention) key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 6c58ade83c8af..17863d1c178eb 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -6,7 +6,7 @@ from vllm.attention import get_attn_backend from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size, is_pin_memory_available) logger = init_logger(__name__) @@ -35,7 +35,7 @@ def __init__( self.head_size = model_config.get_head_size() # Models like Jamba, have mixed typed layers, E.g Mamba self.num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config) + parallel_config, LayerBlockType.attention) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size @@ -106,7 +106,7 @@ def get_cache_block_size( head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config) + parallel_config, LayerBlockType.attention) key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block From 5514aa6f085f71f07fb0eb3c05e39d208249de47 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 17:33:03 +0200 Subject: [PATCH 06/29] Fix jamba Signed-off-by: mzusman --- vllm/config.py | 5 ----- vllm/model_executor/models/jamba.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/config.py 
b/vllm/config.py index 4f05c54f1eb3b..5c7cc256a4932 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -27,14 +27,9 @@ ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) -<<<<<<< HEAD from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, print_warning_once, random_uuid, resolve_obj_by_qualname) -======= -from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, get_cpu_memory, - print_warning_once, resolve_obj_by_qualname) ->>>>>>> 121764ea9 (Introduce LayerBlockType) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 098b31cb7cef6..0f7739351eb64 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -436,7 +436,7 @@ def forward(self, else max(_BATCH_SIZES_TO_CAPTURE) + 2) num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, "mamba") + self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, *self._get_mamba_cache_shape()) From f9fcff41873f0ef5a17c10ff26fee6d1eb660ce3 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 17:45:14 +0200 Subject: [PATCH 07/29] Fix MS for Mamba Signed-off-by: mzusman --- vllm/attention/backends/placeholder_attn.py | 4 ++++ vllm/worker/multi_step_model_runner.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 888adbffb8578..b444349d152a4 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -114,6 +114,10 @@ class PlaceholderAttentionMetadata(AttentionMetadata): _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None + def advance_step(self, *args, **kwargs): + # No need to do anything here + pass + @property def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self.num_prefills == 0: diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 3ca0d88a42183..e08a61e31fe42 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -29,7 +29,9 @@ logger = init_logger(__name__) -MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"] +MULTI_STEP_ATTENTION_BACKENDS = [ + "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION" +] MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"] def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ From 2b5b52510ffe35b842a198da0eac07536ba0fbba Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 17:45:43 +0200 Subject: [PATCH 08/29] Format Signed-off-by: mzusman --- vllm/config.py | 12 +++++------- vllm/utils.py | 2 ++ vllm/v1/worker/gpu_model_runner.py | 4 ++-- vllm/worker/cache_engine.py | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 5c7cc256a4932..5f2c8c3e997fe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -703,13 +703,11 @@ def get_num_layers_by_block_type( # Transformers supports layers_block_type @property layers_block_type_value = getattr(self.hf_config, 
"layers_block_type", - [block_type.value] * (end - start)) - return len( - [ - t - for t in layers_block_type_value[start:end] - if t == block_type.value - ]) + [block_type.value] * (end - start)) + return len([ + t for t in layers_block_type_value[start:end] + if t == block_type.value + ]) def get_multimodal_config(self) -> "MultiModalConfig": """ diff --git a/vllm/utils.py b/vllm/utils.py index 822056ae92945..025a99455433e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -168,10 +168,12 @@ class Device(enum.Enum): GPU = enum.auto() CPU = enum.auto() + class LayerBlockType(enum.Enum): attention = "attention" mamba = "mamba" + class Counter: def __init__(self, start: int = 0) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2fec505050832..9f1b37d3b8e8c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -16,8 +16,8 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, - is_pin_memory_available) +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 17863d1c178eb..7ccd4571b19df 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -6,8 +6,8 @@ from vllm.attention import get_attn_backend from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size, - is_pin_memory_available) +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, + get_dtype_size, is_pin_memory_available) logger = init_logger(__name__) From 01c7d9318d20891302988ee117602e8d5a3a059f Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 17:58:40 +0200 Subject: [PATCH 09/29] Format Signed-off-by: mzusman --- vllm/model_executor/models/jamba.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 0f7739351eb64..dc1dabe8e478e 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -26,6 +26,7 @@ MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType from .interfaces import HasInnerState, SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, From 50f2a8c01660a7b2749149643acd17b81bcb0b3d Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 18:03:05 +0200 Subject: [PATCH 10/29] Format Signed-off-by: mzusman --- vllm/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 5f2c8c3e997fe..8b4d15a7055b9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -27,8 +27,8 @@ ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) -from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - print_warning_once, random_uuid, +from vllm.utils import 
(GiB_bytes, LayerBlockType, cuda_device_count_stateless, + get_cpu_memory, print_warning_once, random_uuid, resolve_obj_by_qualname) if TYPE_CHECKING: From 314c9affe19e1a8be1ba4e613f7dc718c6856096 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 18:06:46 +0200 Subject: [PATCH 11/29] Remove comment Signed-off-by: mzusman --- tests/distributed/test_pipeline_parallel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 1317a6f6ae41e..e54b8c909f359 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -156,8 +156,6 @@ def iter_params(self, model_name: str): # "internlm/internlm-chat-7b": PPTestSettings.fast(), "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True), "inceptionai/jais-13b-chat": PPTestSettings.fast(), - # TODO: Implement PP - # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(), "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True), "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True), From 3347f3f3d0f7faca26c9d942dd99e92b9224ea73 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 18:11:30 +0200 Subject: [PATCH 12/29] Add `is_pp_missing_parameter` to mamba weight loading Signed-off-by: mzusman --- vllm/model_executor/models/mamba.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index d5bda01638421..fb0d3fe281abf 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -25,7 +25,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .utils import (make_empty_intermediate_tensors_factory, make_layers, +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -268,6 +269,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", From f21ff9ae58176369c075050f676701d3dbf66609 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 8 Dec 2024 18:30:23 +0200 Subject: [PATCH 13/29] Revert Mamba MS fix, will fix it in future PR Signed-off-by: mzusman --- vllm/attention/backends/placeholder_attn.py | 4 ---- vllm/worker/multi_step_model_runner.py | 4 +--- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index b444349d152a4..888adbffb8578 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -114,10 +114,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None - def advance_step(self, *args, **kwargs): - # No need to do anything here - pass - @property def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self.num_prefills == 0: diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index e08a61e31fe42..3ca0d88a42183 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -29,9 +29,7 @@ logger = init_logger(__name__) -MULTI_STEP_ATTENTION_BACKENDS = [ - "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION" -] +MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"] MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"] def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ From ff6e13e94fbbb2b99ab7c5732750fb43e24abd02 Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 08:13:09 +0200 Subject: [PATCH 14/29] Addressing review's comments Signed-off-by: mzusman --- tests/distributed/test_pipeline_parallel.py | 8 +++++--- vllm/config.py | 6 ++---- vllm/model_executor/models/jamba.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index e54b8c909f359..83f621ee0f9af 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -156,11 +156,13 @@ def iter_params(self, model_name: str): # "internlm/internlm-chat-7b": PPTestSettings.fast(), "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True), "inceptionai/jais-13b-chat": PPTestSettings.fast(), + "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True), "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True), # Uses Llama # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(), + "state-spaces/mamba-130m-hf": PPTestSettings.fast(), "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4), "mosaicml/mpt-7b": PPTestSettings.fast(), "nvidia/Minitron-8B-Base": PPTestSettings.fast(), @@ -178,9 +180,7 @@ def iter_params(self, model_name: str): "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(), "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), "bigcode/starcoder2-3b": PPTestSettings.fast(), - "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2), - "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), - 
"state-spaces/mamba-130m-hf": PPTestSettings.fast() + "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2) # FIXME: Cannot load tokenizer in latest transformers version. # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf` # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True), @@ -234,6 +234,8 @@ def iter_params(self, model_name: str): "OpenGVLab/InternVL2-1B", "microsoft/Phi-3-vision-128k-instruct", "fixie-ai/ultravox-v0_3", + # [LANGUAGE GENERATION - HYBRID ARCH] + "ai21labs/Jamba-tiny-dev", ] diff --git a/vllm/config.py b/vllm/config.py index 8b4d15a7055b9..5e419be4ec8aa 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -704,10 +704,8 @@ def get_num_layers_by_block_type( # Transformers supports layers_block_type @property layers_block_type_value = getattr(self.hf_config, "layers_block_type", [block_type.value] * (end - start)) - return len([ - t for t in layers_block_type_value[start:end] - if t == block_type.value - ]) + return sum(t == block_type.value + for t in layers_block_type_value[start:end]) def get_multimodal_config(self) -> "MultiModalConfig": """ diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index dc1dabe8e478e..5aebbed9f0ba4 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -286,7 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) def get_layer(prefix: str): - layer_idx = int(prefix.split(".")[-1]) + layer_idx = int(prefix.rsplit(".", 1)[1]) layer_class = ALL_DECODER_LAYER_TYPES[ config.layers_block_type[layer_idx]] return layer_class( From 06ab64d0c15f6d1b530c7cefe1d3cb4a7a15d6ce Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 08:20:28 +0200 Subject: [PATCH 15/29] Add mamba/jamba PP support to supported_models page Signed-off-by: mzusman --- docs/source/models/supported_models.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c9b3fa8485ff1..dd265b22bff54 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -128,7 +128,7 @@ Text Generation - FalconMamba - :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc. - ✅︎ - - + - ✅︎ * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. @@ -193,7 +193,7 @@ Text Generation - Jamba - :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc. - ✅︎ - - + - ✅︎ * - :code:`LlamaForCausalLM` - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc. @@ -203,7 +203,7 @@ Text Generation - Mamba - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc. - - - + - ✅︎ * - :code:`MiniCPMForCausalLM` - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc. 
From 46e3c25534ea4df2939d2c9a30f0da1706fa3033 Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 17:44:16 +0200 Subject: [PATCH 16/29] Use only needed number layers per PP Signed-off-by: mzusman --- tests/distributed/test_pipeline_parallel.py | 2 +- vllm/model_executor/models/mamba.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 83f621ee0f9af..85d408efafe96 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -180,7 +180,7 @@ def iter_params(self, model_name: str): "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(), "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), "bigcode/starcoder2-3b": PPTestSettings.fast(), - "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2) + "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2), # FIXME: Cannot load tokenizer in latest transformers version. # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf` # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True), diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index fb0d3fe281abf..c7bf5e146cf37 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -24,6 +24,7 @@ MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -163,7 +164,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = config + self.vllm_config = vllm_config self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config self.backbone = MambaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone")) self.unpadded_vocab_size = config.vocab_size @@ -207,8 +210,11 @@ def forward(self, max_batch_size = (VllmConfig.get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) + + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, self.config.num_hidden_layers, + self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, *self._get_mamba_cache_shape()) ( From f3544689d069c3771f79f29503381d15a5673dd3 Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 18:00:28 +0200 Subject: [PATCH 17/29] Fix `get_num_layers_by_block_type` to always return num of layers regardless of given block_type Signed-off-by: mzusman --- vllm/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 5e419be4ec8aa..f513363bb7ee6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -699,6 +699,8 @@ def get_num_layers_by_block_type( parallel_config: "ParallelConfig", block_type: LayerBlockType = LayerBlockType.attention, ) -> int: + if self.is_attention_free and block_type == LayerBlockType.attention: + return 0 start, end = self.get_layers_start_end_indices(parallel_config) # Transformers supports layers_block_type @property From 529ade68ec802e93d5720a5d9b9a743d166343ee Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 18:02:28 +0200 Subject: [PATCH 18/29] Format Signed-off-by: mzusman --- 
vllm/model_executor/models/mamba.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index c7bf5e146cf37..c9e83334d58fc 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -210,12 +210,12 @@ def forward(self, max_batch_size = (VllmConfig.get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) - + num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, - max_batch_size, *self._get_mamba_cache_shape()) + self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + *self._get_mamba_cache_shape()) ( mamba_cache_tensors, From e7f63387245bdf798e2d8f11a7a651e89d5b61d4 Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 18:22:03 +0200 Subject: [PATCH 19/29] Add comment for workaround Signed-off-by: mzusman --- vllm/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index f513363bb7ee6..e86affc77b934 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -699,6 +699,9 @@ def get_num_layers_by_block_type( parallel_config: "ParallelConfig", block_type: LayerBlockType = LayerBlockType.attention, ) -> int: + # This function relies on 'layers_block_type' in hf_config, + # for hybrid/attention free models w/o this attribute, + # we will need to have workarounds like so if self.is_attention_free and block_type == LayerBlockType.attention: return 0 start, end = self.get_layers_start_end_indices(parallel_config) From 6d863ae2605fb0dd1312fe52ff5d5393fa02bade Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 18:25:16 +0200 Subject: [PATCH 20/29] Fix comment Signed-off-by: mzusman --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index e86affc77b934..4039c78be288b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -700,7 +700,7 @@ def get_num_layers_by_block_type( block_type: LayerBlockType = LayerBlockType.attention, ) -> int: # This function relies on 'layers_block_type' in hf_config, - # for hybrid/attention free models w/o this attribute, + # for hybrid/attention-free models w/o this attribute, # we will need to have workarounds like so if self.is_attention_free and block_type == LayerBlockType.attention: return 0 From d9ef20d6a6a20b9f643f23f4560d7b96a1abbef8 Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 19:27:06 +0200 Subject: [PATCH 21/29] Add IsHybrid interface to return 0 layers on non-hybrid models asking about non-attention layers Signed-off-by: mzusman --- vllm/config.py | 13 ++++++++- vllm/model_executor/models/interfaces.py | 37 ++++++++++++++++++++++++ vllm/model_executor/models/jamba.py | 5 ++-- vllm/model_executor/models/registry.py | 11 ++++++- 4 files changed, 62 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4039c78be288b..c8d1cc9952515 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -284,6 +284,7 @@ def __init__( self._verify_tokenizer_mode() self.is_attention_free = self._init_attention_free() + self.is_hybrid = self._init_is_hybrid() self.has_inner_state = self._init_has_inner_state() if current_platform.is_neuron(): @@ -340,6 +341,10 @@ def _init_attention_free(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) return 
ModelRegistry.is_attention_free_model(architectures) + def _init_is_hybrid(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_hybrid_model(architectures) + def _init_has_inner_state(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) return ModelRegistry.model_has_inner_state(architectures) @@ -704,11 +709,17 @@ def get_num_layers_by_block_type( # we will need to have workarounds like so if self.is_attention_free and block_type == LayerBlockType.attention: return 0 + if not self.is_hybrid and block_type != LayerBlockType.attention: + return 0 + start, end = self.get_layers_start_end_indices(parallel_config) # Transformers supports layers_block_type @property layers_block_type_value = getattr(self.hf_config, "layers_block_type", - [block_type.value] * (end - start)) + None) + if layers_block_type_value is None: + return end - start + return sum(t == block_type.value for t in layers_block_type_value[start:end]) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index c3979eab905db..70b78fe64f2d8 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -363,6 +363,43 @@ def is_attention_free( return isinstance(model, IsAttentionFree) +@runtime_checkable +class IsHybrid(Protocol): + """The interface required for all models like Jamba that have both + attention and mamba blocks, indicates that + hf_config has 'layers_block_type'""" + + is_hybrid: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has both mamba and attention blocks + , also indicates that the model's hf_config has + 'layers_block_type' """ + + +@runtime_checkable +class _IsHybridType(Protocol): + is_hybrid: ClassVar[Literal[True]] + + +@overload +def is_hybrid(model: object) -> TypeIs[IsHybrid]: + ... + + +@overload +def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]: + ... 
+ + +def is_hybrid( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]: + if isinstance(model, type): + return isinstance(model, _IsHybridType) + + return isinstance(model, IsHybrid) + + @runtime_checkable class SupportsCrossEncoding(Protocol): """The interface required for all models that support cross encoding.""" diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5aebbed9f0ba4..7b4c67cedab7b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -28,7 +28,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .interfaces import HasInnerState, SupportsLoRA, SupportsPP +from .interfaces import HasInnerState, IsHybridModel, SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -361,7 +361,8 @@ def forward( return hidden_states -class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP): +class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybridModel): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index e69596aa915b5..4beea4641f5ab 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -21,7 +21,7 @@ from vllm.platforms import current_platform from .adapters import as_embedding_model -from .interfaces import (has_inner_state, is_attention_free, +from .interfaces import (has_inner_state, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, supports_pp) from .interfaces_base import is_pooling_model, is_text_generation_model @@ -218,6 +218,7 @@ class _ModelInfo: supports_pp: bool has_inner_state: bool is_attention_free: bool + is_hybrid: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -239,6 +240,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), + is_hybrid=is_hybrid(model), ) @@ -484,6 +486,13 @@ def is_attention_free_model( model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_attention_free + def is_hybrid_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_hybrid + ModelRegistry = _ModelRegistry({ model_arch: _LazyRegisteredModel( From 029d7104e414a532440c39992bbe8ee889edc339 Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 19:29:07 +0200 Subject: [PATCH 22/29] Typo Signed-off-by: mzusman --- vllm/model_executor/models/jamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 7b4c67cedab7b..6bb4c13ab35df 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -28,7 +28,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .interfaces import HasInnerState, IsHybridModel, SupportsLoRA, SupportsPP +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -362,7 +362,7 @@ def forward( class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, 
SupportsPP, - IsHybridModel): + IsHybrid): packed_modules_mapping = { "qkv_proj": [ "q_proj", From 3bc382322934f0b1abbe27ec738cdf6bc7d736b1 Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 19:37:40 +0200 Subject: [PATCH 23/29] Format Signed-off-by: mzusman --- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5658c37d39dbd..ea1ff42c105ff 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -15,8 +15,8 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, - is_pin_memory_available) +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput From 0f701311845fd3a20cf7b0bd0c57c78ae993c63e Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 9 Dec 2024 23:58:40 +0200 Subject: [PATCH 24/29] Fix mamba logic Signed-off-by: mzusman --- vllm/config.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index b57e78723e86d..995819cdb09ba 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -696,9 +696,12 @@ def get_num_layers_by_block_type( # This function relies on 'layers_block_type' in hf_config, # for hybrid/attention-free models w/o this attribute, # we will need to have workarounds like so - if self.is_attention_free and block_type == LayerBlockType.attention: + attn_block_type = block_type == LayerBlockType.attention + is_full_attn_model = not self.is_hybrid and not self.is_attention_free + + if self.is_attention_free and attn_block_type: return 0 - if not self.is_hybrid and block_type != LayerBlockType.attention: + if is_full_attn_model and attn_block_type: return 0 start, end = self.get_layers_start_end_indices(parallel_config) From fcce742db4435a145bec8d95e77ec650361e3f1c Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 10 Dec 2024 14:22:34 +0200 Subject: [PATCH 25/29] Small fix to logic Signed-off-by: mzusman --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 60a32fb182dbb..e8475efcb56e1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -701,7 +701,7 @@ def get_num_layers_by_block_type( if self.is_attention_free and attn_block_type: return 0 - if is_full_attn_model and attn_block_type: + if is_full_attn_model and not attn_block_type: return 0 start, end = self.get_layers_start_end_indices(parallel_config) From e3d8343357f192a007f13108fc3f163bf5f72a84 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 10 Dec 2024 19:02:58 +0200 Subject: [PATCH 26/29] FIx logic Signed-off-by: mzusman --- vllm/config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index e8475efcb56e1..13dc8cb431137 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -710,7 +710,12 @@ def get_num_layers_by_block_type( layers_block_type_value = getattr(self.hf_config, "layers_block_type", None) if layers_block_type_value is None: - return end - start + if not self.is_hybrid: + return end - start + raise ValueError("The model is an hybrid without a" + "layers_block_type in the 
hf_config," + "cannot determine the num of " + f"{block_type.value} layers") return sum(t == block_type.value for t in layers_block_type_value[start:end]) From 6d6b3d4e9cd4f420d7101035ee951b73233a1879 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 10 Dec 2024 19:47:01 +0200 Subject: [PATCH 27/29] Fixing according to comment Signed-off-by: mzusman --- vllm/config.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 13dc8cb431137..ee8f72f469aae 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -697,28 +697,29 @@ def get_num_layers_by_block_type( # for hybrid/attention-free models w/o this attribute, # we will need to have workarounds like so attn_block_type = block_type == LayerBlockType.attention - is_full_attn_model = not self.is_hybrid and not self.is_attention_free - - if self.is_attention_free and attn_block_type: - return 0 - if is_full_attn_model and not attn_block_type: - return 0 - + is_transformer = not self.is_hybrid and not self.is_attention_free start, end = self.get_layers_start_end_indices(parallel_config) - # Transformers supports layers_block_type @property - layers_block_type_value = getattr(self.hf_config, "layers_block_type", - None) - if layers_block_type_value is None: - if not self.is_hybrid: - return end - start - raise ValueError("The model is an hybrid without a" - "layers_block_type in the hf_config," - "cannot determine the num of " - f"{block_type.value} layers") - - return sum(t == block_type.value - for t in layers_block_type_value[start:end]) + if is_transformer: + # Handle the basic case first + return end - start if attn_block_type else 0 + elif self.is_attention_free: + # Attention free + # Note that this code assumes there + # is only one type of attention-free block type. + return 0 if attn_block_type else end - start + else: + # Hybrid model + layers_block_type_value = getattr(self.hf_config, "layers_block_type", + None) + if layers_block_type_value is None: + raise ValueError("The model is an hybrid without a" + "layers_block_type in the hf_config," + "cannot determine the num of " + f"{block_type.value} layers") + + return sum(t == block_type.value + for t in layers_block_type_value[start:end]) def get_multimodal_config(self) -> "MultiModalConfig": """ From c034ffee11693c7f51f13fc2436e5b00e6ac6706 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 10 Dec 2024 19:48:19 +0200 Subject: [PATCH 28/29] Format Signed-off-by: mzusman --- vllm/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ee8f72f469aae..6425125954114 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -705,13 +705,13 @@ def get_num_layers_by_block_type( return end - start if attn_block_type else 0 elif self.is_attention_free: # Attention free - # Note that this code assumes there + # Note that this code assumes there # is only one type of attention-free block type. 
return 0 if attn_block_type else end - start else: # Hybrid model - layers_block_type_value = getattr(self.hf_config, "layers_block_type", - None) + layers_block_type_value = getattr(self.hf_config, + "layers_block_type", None) if layers_block_type_value is None: raise ValueError("The model is an hybrid without a" "layers_block_type in the hf_config," From 75ff1ef22f094ed22fd6c286a0c5a0fe900c703f Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 10 Dec 2024 20:03:48 +0200 Subject: [PATCH 29/29] Fix comment Signed-off-by: mzusman --- vllm/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 6425125954114..e2d75fc569540 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -694,8 +694,7 @@ def get_num_layers_by_block_type( block_type: LayerBlockType = LayerBlockType.attention, ) -> int: # This function relies on 'layers_block_type' in hf_config, - # for hybrid/attention-free models w/o this attribute, - # we will need to have workarounds like so + # for w/o this attribute, we will need to have workarounds like so attn_block_type = block_type == LayerBlockType.attention is_transformer = not self.is_hybrid and not self.is_attention_free start, end = self.get_layers_start_end_indices(parallel_config)
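To summarize where the series lands, here is a standalone sketch of the three cases that `get_num_layers_by_block_type` distinguishes for a single pipeline stage. The `ToyConfig` class and the example values are assumptions made for illustration; they are not vLLM code.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ToyConfig:
    # Stand-ins for ModelConfig.is_hybrid / is_attention_free and the
    # hf_config.layers_block_type attribute used in the patches above.
    is_hybrid: bool
    is_attention_free: bool
    layers_block_type: Optional[List[str]] = None

def count_block_type(cfg: ToyConfig, start: int, end: int,
                     block_type: str = "attention") -> int:
    attn = block_type == "attention"
    if not cfg.is_hybrid and not cfg.is_attention_free:
        # Plain transformer: every layer in this stage's slice is attention.
        return end - start if attn else 0
    if cfg.is_attention_free:
        # Attention-free model (e.g. pure Mamba): no attention layers at all,
        # and a single non-attention block type fills the whole slice.
        return 0 if attn else end - start
    # Hybrid model (e.g. Jamba): must consult hf_config.layers_block_type.
    if cfg.layers_block_type is None:
        raise ValueError("hybrid model without layers_block_type")
    return sum(t == block_type for t in cfg.layers_block_type[start:end])

llama_like = ToyConfig(is_hybrid=False, is_attention_free=False)
mamba_like = ToyConfig(is_hybrid=False, is_attention_free=True)
jamba_like = ToyConfig(is_hybrid=True, is_attention_free=False,
                       layers_block_type=["mamba", "mamba", "mamba", "attention"] * 2)

assert count_block_type(llama_like, 0, 16) == 16
assert count_block_type(mamba_like, 0, 12, "mamba") == 12
assert count_block_type(jamba_like, 4, 8, "attention") == 1
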