diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 0a5fe19f80ec6..774f2d9d9cdbc 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -1,5 +1,6 @@ import pytest +from tests.models.utils import check_outputs_equal from vllm.worker.model_runner import _get_graph_batch_size MODELS = ["ai21labs/Jamba-tiny-random"] @@ -34,6 +35,34 @@ def test_models( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_batching( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + for_loop_outputs = [] + with vllm_runner(model, dtype=dtype) as vllm_model: + for prompt in example_prompts: + for_loop_outputs.append( + vllm_model.generate_greedy([prompt], max_tokens)[0]) + + batched_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=for_loop_outputs, + outputs_1_lst=batched_outputs, + name_0="for_loop_vllm", + name_1="batched_vllm", + ) + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [20]) @@ -60,6 +89,60 @@ def test_mamba_cache_cg_padding( "Could be related to mamba cache not padded correctly") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_models_preemption_recompute( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # Tests that outputs are identical with and w/o preemtions (recompute) + assert dtype == "float" + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = True + preempt_vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) + + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = False + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=preempt_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="vllm_preepmtions", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the Jamba inner state management doesn't + # collapse in case where the number of incoming requests and + # finished_requests_ids is larger than the maximum mamba block capacity. + # This could generally happen due to the fact that Jamba does support + # statelessness mechanism where it can cleanup new incoming requests in + # a single step. + try: + with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 100, 10) + except ValueError: + pytest.fail("Jamba inner state wasn't cleaned up properly between" + "steps finished requests registered unnecessarily ") + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_state_cleanup( diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6bda18cd4f061..6e59c5e0f74f3 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -374,6 +374,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: for aborted_group in aborted_groups: # Remove the sequence group from the state queue. state_queue.remove(aborted_group) + self._finished_requests_ids.append(aborted_group.request_id) for seq in aborted_group.get_seqs(): if seq.is_finished(): continue diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 60547965063fa..e169a98bb509c 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -32,7 +32,8 @@ filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.models.interfaces import (supports_lora, +from vllm.model_executor.models.interfaces import (has_inner_state, + supports_lora, supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -66,10 +67,10 @@ def _get_quantization_config( def _get_model_initialization_kwargs( - model_class: Type[nn.Module], - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], -) -> Dict[str, Any]: + model_class: Type[nn.Module], + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]: """Get extra kwargs for model initialization.""" extra_kwargs: Dict[str, Any] = {} @@ -90,13 +91,19 @@ def _get_model_initialization_kwargs( extra_kwargs["multimodal_config"] = multimodal_config + if has_inner_state(model_class) and scheduler_config: + extra_kwargs["scheduler_config"] = scheduler_config + return extra_kwargs -def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - cache_config: CacheConfig) -> nn.Module: +def _initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + cache_config: CacheConfig, + scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module: """Initialize a model with the given configurations.""" model_class = get_model_architecture(model_config)[0] quant_config = _get_quantization_config(model_config, load_config) @@ -105,7 +112,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, cache_config=cache_config, quant_config=quant_config, **_get_model_initialization_kwargs( - model_class, lora_config, multimodal_config)) + model_class, lora_config, multimodal_config, + scheduler_config)) class BaseModelLoader(ABC): @@ -266,7 +274,7 @@ def load_model(self, *, model_config: ModelConfig, with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, - cache_config) + cache_config, scheduler_config) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, @@ -306,7 +314,7 @@ def load_model(self, *, model_config: ModelConfig, with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, - cache_config) + cache_config, scheduler_config) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 2697a6996f4ca..6fdacd4469788 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -3,7 +3,7 @@ from typing_extensions import TypeGuard -from vllm.config import LoRAConfig, MultiModalConfig +from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -142,3 +142,49 @@ def _supports_lora( return isinstance(model, _SupportsLoRAType) return isinstance(model, SupportsLoRA) + + +@runtime_checkable +class HasInnerState(Protocol): + """The interface required for all models that has inner state.""" + + has_inner_state: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has inner state. + Models that has inner state usually need access to the scheduler_config + for max_num_seqs ,etc... (Currently only used by Jamba) + """ + + def __init__(self, + *, + scheduler_config: Optional[SchedulerConfig] = None) -> None: + ... + + +@runtime_checkable +class _HasInnerStateType(Protocol): + has_inner_state: ClassVar[Literal[True]] + + def __init__(self, + *, + scheduler_config: Optional[SchedulerConfig] = None) -> None: + ... + + +@overload +def has_inner_state(model: object) -> TypeGuard[HasInnerState]: + ... + + +@overload +def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]: + ... + + +def has_inner_state( + model: Union[Type[object], object] +) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]: + if isinstance(model, type): + return isinstance(model, _HasInnerStateType) + + return isinstance(model, HasInnerState) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 4524d8df86b9d..d4e4f0055aa2b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -32,10 +32,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -612,7 +614,7 @@ def forward( return hidden_states -class JambaForCausalLM(nn.Module): +class JambaForCausalLM(nn.Module, HasInnerState): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -640,9 +642,11 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, ) -> None: super().__init__() self.config = config + self.scheduler_config = scheduler_config self.model = JambaModel(config, cache_config=cache_config, quant_config=quant_config, @@ -689,6 +693,8 @@ def forward(self, for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(request_ids_to_seq_ids) @@ -696,9 +702,8 @@ def forward(self, current_seqlen_agnostic_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) + batch_size, + finished_requests_ids) else: # CUDA graph capturing runs current_seqlen_agnostic_cache, indices = ( @@ -760,10 +765,15 @@ def _assign_seq_id_to_mamba_cache(self, cur_rid: str, return indices_for_current_run def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int, + finished_requests_ids: List[str] ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: indices_for_current_run = [] for request_id, seqs_id in request_ids_to_seq_ids.items(): + if request_id in finished_requests_ids: + # Do not allocate cache for requests that run + # and finish right after + continue indices_for_current_run += self._assign_seq_id_to_mamba_cache( request_id, seqs_id) ## Pad the batch in case of running batch that was not captured via CG @@ -787,16 +797,17 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): assert all( key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] cg_batch_size = input_buffers['input_ids'].shape[0] ( current_mamba_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - cg_batch_size) + cg_batch_size, + finished_requests_ids) self.current_indices = indices - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) for input_buffer, current_cache_buffer in zip( input_buffers["seqlen_agnostic_capture_inputs"], @@ -860,9 +871,12 @@ def _prepare_mamba_cache(self): layers_type = self.config.layers_block_type mamba_layers = sum( [layer_type == "mamba" for layer_type in layers_type]) - max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config else + max(_BATCH_SIZES_TO_CAPTURE)) + 10 conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: buffer = (torch.empty(size=(mamba_layers, max_batch_size) + conv_state_shape,