From 4baae800758b27da975dcc5190646a0d3a499386 Mon Sep 17 00:00:00 2001
From: Mor Zusman
Date: Sun, 14 Jul 2024 16:31:06 +0300
Subject: [PATCH 01/20] Add Jamba tests, preemption, the turning door bug test and cleanup

---
 tests/models/test_jamba.py | 76 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 76 insertions(+)

diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py
index 0a5fe19f80ec6..d3b915227388a 100644
--- a/tests/models/test_jamba.py
+++ b/tests/models/test_jamba.py
@@ -60,6 +60,82 @@ def test_mamba_cache_cg_padding(
         "Could be related to mamba cache not padded correctly")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+def test_models_preemption_recompute(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # Tests that outputs are identical with and w/o preemptions (recompute)
+    assert dtype == "float"
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_model.model.llm_engine.scheduler[
+            0].ENABLE_ARTIFICIAL_PREEMPT = True
+        preempt_vllm_outputs = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+        vllm_model.model.llm_engine.scheduler[
+            0].ENABLE_ARTIFICIAL_PREEMPT = False
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=preempt_vllm_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="vllm_preemptions",
+        name_1="vllm",
+    )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
+    vllm_runner,
+    model: str,
+    dtype: str,
+    example_prompts,
+) -> None:
+    # This test is for verifying that the Jamba inner state management doesn't
+    # collapse in case where the number of incoming requests and
+    # finished_requests_ids is larger than the maximum mamba block capacity.
+    # This could generally happen due to the fact that Jamba does support
+    # statelessness mechanism where it can cleanup new incoming requests in
+    # a single step.
+    try:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
+            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
+    except ValueError:
+        pytest.fail("Jamba inner state wasn't cleaned up properly between"
+                    "steps finished requests registered unnecessarily ")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_cleanup_upon_aborted_requests(
+    vllm_runner,
+    model: str,
+    dtype: str,
+    example_prompts,
+) -> None:
+    # This test is for verifying that the Jamba inner state management doesn't
+    # collapse in case where the number of incoming requests and
+    # finished_requests_ids is larger than the maximum mamba block capacity.
+    # This could generally happen due to the fact that Jamba does support
+    # statelessness mechanism where it can cleanup new incoming requests in
+    # a single step.
+ try: + with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 100, 10) + except ValueError: + pytest.fail("Jamba inner state wasn't cleaned up properly between" + "steps finished requests registered unnecessarily ") + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_state_cleanup( From f2ebf514586baac9d35441ccff3d0ef87917c94e Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 16:32:31 +0300 Subject: [PATCH 02/20] Add aborted requests ids to the finished_requests_ids --- vllm/core/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6bda18cd4f061..6e59c5e0f74f3 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -374,6 +374,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: for aborted_group in aborted_groups: # Remove the sequence group from the state queue. state_queue.remove(aborted_group) + self._finished_requests_ids.append(aborted_group.request_id) for seq in aborted_group.get_seqs(): if seq.is_finished(): continue From f6603d2139ddac0215e846bbac401e7e13aec49d Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 16:33:01 +0300 Subject: [PATCH 03/20] Add interface for HasInnerState models --- vllm/model_executor/models/interfaces.py | 35 +++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 2697a6996f4ca..4eb7f8167e581 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -3,7 +3,7 @@ from typing_extensions import TypeGuard -from vllm.config import LoRAConfig, MultiModalConfig +from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -142,3 +142,36 @@ def _supports_lora( return isinstance(model, _SupportsLoRAType) return isinstance(model, SupportsLoRA) + + +@runtime_checkable +class HasInnerState(Protocol): + """The interface required for all models that has inner state.""" + + has_inner_state: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has inner state. + Models that has inner state usually need access to the scheduler_config + for max_num_seqs ,etc... (Currently only used by Jamba) + """ + + def __init__(self, + *, + scheduler_config: Optional[SchedulerConfig] = None) -> None: + ... + + +@overload +def has_inner_state(model: object) -> TypeGuard[HasInnerState]: + ... + + +@overload +def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]: + ... 
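
The HasInnerState interface follows the same structural-typing pattern as the SupportsLoRA helpers earlier in this file: a runtime_checkable Protocol whose only real requirement is a class-level flag, plus a TypeGuard-returning helper so static type checkers can narrow the checked object. A minimal, self-contained sketch of the pattern (the names HasFlag, Stateful, and has_flag are invented for illustration and are not part of vLLM):

    from typing import ClassVar, Literal, Protocol, runtime_checkable

    from typing_extensions import TypeGuard


    @runtime_checkable
    class HasFlag(Protocol):
        """Matched by any object that exposes a class-level `flag = True`."""

        flag: ClassVar[Literal[True]]


    class Stateful:
        flag: ClassVar[Literal[True]] = True


    class Stateless:
        pass


    def has_flag(obj: object) -> TypeGuard[HasFlag]:
        # isinstance() against a runtime_checkable Protocol only checks that
        # the required attributes exist; the TypeGuard return type is what
        # lets a type checker narrow `obj` after the call.
        return isinstance(obj, HasFlag)


    assert has_flag(Stateful())
    assert not has_flag(Stateless())

The pair of @overload declarations above affects only static typing (narrowing a class versus an instance); at runtime the check boils down to the isinstance call in the implementation that follows.
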
+ + +def has_inner_state( + model: Union[Type[object], object] +) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]: + return isinstance(model, HasInnerState) From 607dc751b15655df33c7f3738179b91229711ee5 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 16:34:45 +0300 Subject: [PATCH 04/20] Fix the "turning door" bug --- vllm/model_executor/models/jamba.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 4524d8df86b9d..a95906fc4939f 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -689,6 +689,8 @@ def forward(self, for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(request_ids_to_seq_ids) @@ -696,9 +698,8 @@ def forward(self, current_seqlen_agnostic_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) + batch_size, + finished_requests_ids) else: # CUDA graph capturing runs current_seqlen_agnostic_cache, indices = ( @@ -760,10 +761,15 @@ def _assign_seq_id_to_mamba_cache(self, cur_rid: str, return indices_for_current_run def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int, + finished_requests_ids: List[str] ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: indices_for_current_run = [] for request_id, seqs_id in request_ids_to_seq_ids.items(): + if request_id in finished_requests_ids: + # Do not allocate cache for requests that run + # and finish right after + continue indices_for_current_run += self._assign_seq_id_to_mamba_cache( request_id, seqs_id) ## Pad the batch in case of running batch that was not captured via CG @@ -787,16 +793,17 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): assert all( key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] cg_batch_size = input_buffers['input_ids'].shape[0] ( current_mamba_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - cg_batch_size) + cg_batch_size, + finished_requests_ids) self.current_indices = indices - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) for input_buffer, current_cache_buffer in zip( input_buffers["seqlen_agnostic_capture_inputs"], From 4e003822a9b560cba6c3298457525d7ceca7ccac Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 16:35:02 +0300 Subject: [PATCH 05/20] Add scheduler config to has inner state models --- vllm/model_executor/model_loader/loader.py | 32 ++++++++++++++-------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 60547965063fa..e169a98bb509c 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -32,7 
+32,8 @@ filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.models.interfaces import (supports_lora, +from vllm.model_executor.models.interfaces import (has_inner_state, + supports_lora, supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -66,10 +67,10 @@ def _get_quantization_config( def _get_model_initialization_kwargs( - model_class: Type[nn.Module], - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], -) -> Dict[str, Any]: + model_class: Type[nn.Module], + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]: """Get extra kwargs for model initialization.""" extra_kwargs: Dict[str, Any] = {} @@ -90,13 +91,19 @@ def _get_model_initialization_kwargs( extra_kwargs["multimodal_config"] = multimodal_config + if has_inner_state(model_class) and scheduler_config: + extra_kwargs["scheduler_config"] = scheduler_config + return extra_kwargs -def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], - cache_config: CacheConfig) -> nn.Module: +def _initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + cache_config: CacheConfig, + scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module: """Initialize a model with the given configurations.""" model_class = get_model_architecture(model_config)[0] quant_config = _get_quantization_config(model_config, load_config) @@ -105,7 +112,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, cache_config=cache_config, quant_config=quant_config, **_get_model_initialization_kwargs( - model_class, lora_config, multimodal_config)) + model_class, lora_config, multimodal_config, + scheduler_config)) class BaseModelLoader(ABC): @@ -266,7 +274,7 @@ def load_model(self, *, model_config: ModelConfig, with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, - cache_config) + cache_config, scheduler_config) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, @@ -306,7 +314,7 @@ def load_model(self, *, model_config: ModelConfig, with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, - cache_config) + cache_config, scheduler_config) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
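
The loader change follows the pattern already used for lora_config and multimodal_config: inspect the model class and pass the extra keyword argument only when the class declares that it needs it, so constructors of models without inner state are untouched. A rough sketch of that conditional dispatch — build_init_kwargs is an invented name standing in for _get_model_initialization_kwargs, and the attribute probe stands in for the has_inner_state() guard:

    from typing import Any, Dict, Optional, Type


    def build_init_kwargs(model_class: Type[object],
                          scheduler_config: Optional[object] = None
                          ) -> Dict[str, Any]:
        extra_kwargs: Dict[str, Any] = {}
        # Only classes advertising inner state receive the scheduler config;
        # the real loader uses has_inner_state(model_class) for this check.
        if getattr(model_class, "has_inner_state", False) and scheduler_config:
            extra_kwargs["scheduler_config"] = scheduler_config
        return extra_kwargs

    # model = model_class(config=..., cache_config=..., quant_config=...,
    #                     **build_init_kwargs(model_class, scheduler_config))
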
initialize_dummy_weights(model) From 5279333bef31dbea56206e8c74d16dc99af1abc6 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 16:35:17 +0300 Subject: [PATCH 06/20] Add has inner config to jamba and create mamba cache max capacity according to max_num_seqs --- vllm/model_executor/models/jamba.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index a95906fc4939f..e602aaa92a160 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig, SchedulerConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -612,7 +612,7 @@ def forward( return hidden_states -class JambaForCausalLM(nn.Module): +class JambaForCausalLM(nn.Module, HasInnerState): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -640,9 +640,11 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, ) -> None: super().__init__() self.config = config + self.scheduler_config = scheduler_config self.model = JambaModel(config, cache_config=cache_config, quant_config=quant_config, @@ -867,9 +869,12 @@ def _prepare_mamba_cache(self): layers_type = self.config.layers_block_type mamba_layers = sum( [layer_type == "mamba" for layer_type in layers_type]) - max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config else + max(_BATCH_SIZES_TO_CAPTURE)) + 10 conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: buffer = (torch.empty(size=(mamba_layers, max_batch_size) + conv_state_shape, From 17efa066b90330493b54155519aa6824dc0261ad Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 16:42:13 +0300 Subject: [PATCH 07/20] Add import --- vllm/model_executor/models/jamba.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index e602aaa92a160..ab0e5b558b075 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -35,7 +35,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE +from vllm.worker.model_runner import ( + _BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size, +) KVCache = Tuple[torch.Tensor, torch.Tensor] From b23af539fb5d21ee0d6ef97ac894bb6a217e586e Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 17:30:56 +0300 Subject: [PATCH 08/20] Remove redundant test --- tests/models/test_jamba.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index d3b915227388a..1e8c95480b000 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ 
-114,28 +114,6 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( "steps finished requests registered unnecessarily ") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_cleanup_upon_aborted_requests( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is for verifying that the Jamba inner state management doesn't - # collapse in case where the number of incoming requests and - # finished_requests_ids is larger than the maximum mamba block capacity. - # This could generally happen due to the fact that Jamba does support - # statelessness mechanism where it can cleanup new incoming requests in - # a single step. - try: - with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 100, 10) - except ValueError: - pytest.fail("Jamba inner state wasn't cleaned up properly between" - "steps finished requests registered unnecessarily ") - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_state_cleanup( From 481d62b02026d4afe5737771b8b42b893e06e291 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 17:34:31 +0300 Subject: [PATCH 09/20] Add test --- tests/models/test_jamba.py | 148 +++++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 6 deletions(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 1e8c95480b000..954390d6a69d3 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -1,5 +1,6 @@ import pytest +from tests.models.utils import check_outputs_equal from vllm.worker.model_runner import _get_graph_batch_size MODELS = ["ai21labs/Jamba-tiny-random"] @@ -21,17 +22,152 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + hf_logprobs_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs=2) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + vllm_logprobs_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs=2) for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + _, hf_output_str = hf_outputs[i] + hf_output_ids, _, hf_output_logprobs = hf_logprobs_outputs[i] + + _, vllm_output_str = vllm_outputs[i] + vllm_output_ids, _, vllm_output_logprobs = vllm_logprobs_outputs[i] + + if hf_output_str != vllm_output_str: + first_diff_index = [ + hf_id == vllm_id + for hf_id, vllm_id in zip(hf_output_ids, vllm_output_ids) + ].index(False) + hf_disagreement_logprobs = hf_output_logprobs[first_diff_index] + vllm_disagreement_logprobs = { + k: v.logprob + for k, v in vllm_output_logprobs[first_diff_index].items() + } + + assert (hf_output_ids[first_diff_index] + in vllm_disagreement_logprobs), ( + f"Test{i}:different outputs\n" + f"HF: {hf_output_str!r}\n" + f"vLLM: {vllm_output_str!r}\n", + f"Disagreement in {first_diff_index}th token. 
" + f"HF id: {hf_output_ids[first_diff_index]}, " + f"vLLM id: {vllm_output_ids[first_diff_index]})\n", + "HF top token not in vLLM top 2 tokens") + + vllm_disagreement_logprobs_values = list( + vllm_disagreement_logprobs.values()) + vllm_logprobs_diff = abs(vllm_disagreement_logprobs_values[0] - + vllm_disagreement_logprobs_values[1]) + vllm_hf_diff = abs( + hf_disagreement_logprobs[hf_output_ids[first_diff_index]] - + vllm_disagreement_logprobs[hf_output_ids[first_diff_index]]) + + assert (vllm_logprobs_diff < vllm_hf_diff + or vllm_logprobs_diff < 1e-4), ( + f"Test{i}:different outputs\n" + f"HF: {hf_output_str!r}\n" + f"vLLM: {vllm_output_str!r}\n", + f"Disagreement in {first_diff_index}th token. " + f"HF id: {hf_output_ids[first_diff_index]}, " + f"vLLM id: {vllm_output_ids[first_diff_index]})\n", + f"HF top token in vLLM top 2 tokens, " + f"but logprobs diff is too large. " + f"vLLM top 2 logprob diff: {vllm_logprobs_diff}\n", + f"HF to vLLM diff of top HF token: {vllm_hf_diff}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [15]) +def test_batching( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + # assert dtype == "float" + + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: + for_loop_outputs = [] + for_loop_logprobs_outputs = [] + for prompt in example_prompts: + for_loop_outputs.append( + vllm_model.generate_greedy([prompt], max_tokens)[0]) + for_loop_logprobs_outputs.append( + vllm_model.generate_greedy_logprobs([prompt], + max_tokens, + num_logprobs=2)[0]) + + batched_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + batched_logprobs_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs=2) + + for i in range(len(example_prompts)): + _, for_loop_output_str = for_loop_outputs[i] + (for_loop_output_ids, _, + for_loop_output_logprobs) = for_loop_logprobs_outputs[i] + + _, batched_output_str = batched_outputs[i] + (batched_output_ids, _, + batched_output_logprobs) = batched_logprobs_outputs[i] + + if for_loop_output_str != batched_output_str: + first_diff_index = [ + for_loop_id == batched_id for for_loop_id, batched_id in zip( + for_loop_output_ids, batched_output_ids) + ].index(False) + for_loop_disagreement_logprobs = { + k: v.logprob + for k, v in for_loop_output_logprobs[first_diff_index].items() + } + batched_disagreement_logprobs = { + k: v.logprob + for k, v in batched_output_logprobs[first_diff_index].items() + } + + assert ( + for_loop_output_ids[first_diff_index] + in batched_disagreement_logprobs), ( + f"Test{i}:different outputs\n" + f"For-loop: {for_loop_output_str!r}\n", + f"Batched: {batched_output_str!r}\n", + f"Disagreement in {first_diff_index}th token. 
" + f"For-loop id: {for_loop_output_ids[first_diff_index]}, " + f"Batched id: {batched_output_ids[first_diff_index]})\n", + "For-loop top token not in batched top 2 tokens") + + batched_disagreement_logprobs_values = list( + batched_disagreement_logprobs.values()) + batched_logprobs_diff = abs( + batched_disagreement_logprobs_values[0] - + batched_disagreement_logprobs_values[1]) + batched_for_loop_diff = abs( + for_loop_disagreement_logprobs[ + for_loop_output_ids[first_diff_index]] - + batched_disagreement_logprobs[ + for_loop_output_ids[first_diff_index]]) + + assert ( + batched_logprobs_diff < batched_for_loop_diff + or batched_logprobs_diff < 1e-4), ( + f"Test{i}:different outputs\n" + f"For-loop: {for_loop_output_str!r}\n" + f"Batched: {batched_output_str!r}\n", + f"Disagreement in {first_diff_index}th token. " + f"For-loop id: {for_loop_output_ids[first_diff_index]}, " + f"Batched id: {batched_output_ids[first_diff_index]})\n", + f"For-loop top token in batched top 2 tokens, " + f"but logprobs diff is too large. " + f"Batched top 2 logprob diff: {batched_logprobs_diff}\n", + f"For-loop to batched diff of top for-loop token: " + f"{batched_logprobs_diff}") @pytest.mark.parametrize("model", MODELS) From 097eb8b3d2dd322d13f13a99cab4dab1c439eabf Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 17:37:49 +0300 Subject: [PATCH 10/20] Add imports --- vllm/model_executor/models/jamba.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index ab0e5b558b075..9332124e4cf16 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -32,6 +32,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput From 57559b3d64f6645f32bdcaf4f79e7c2cd0f64ffe Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Jul 2024 17:38:27 +0300 Subject: [PATCH 11/20] Format --- vllm/model_executor/models/jamba.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 9332124e4cf16..d4e4f0055aa2b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -36,10 +36,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.worker.model_runner import ( - _BATCH_SIZES_TO_CAPTURE, - _get_graph_batch_size, -) +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) KVCache = Tuple[torch.Tensor, torch.Tensor] From 8c6e780882b620c70fd5d40f84a9b34a4e2ec8b5 Mon Sep 17 00:00:00 2001 From: 
Mor Zusman Date: Mon, 15 Jul 2024 12:11:31 +0300 Subject: [PATCH 12/20] Switch jamba test (models/batching) to regular tokens comparison --- tests/models/test_jamba.py | 146 +++++-------------------------------- 1 file changed, 20 insertions(+), 126 deletions(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 954390d6a69d3..c689c85677f66 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -1,6 +1,6 @@ import pytest -from tests.models.utils import check_outputs_equal +from tests.models.utils import check_logprobs_close, check_outputs_equal from vllm.worker.model_runner import _get_graph_batch_size MODELS = ["ai21labs/Jamba-tiny-random"] @@ -22,67 +22,21 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - hf_logprobs_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs=2) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - vllm_logprobs_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs=2) - for i in range(len(example_prompts)): - _, hf_output_str = hf_outputs[i] - hf_output_ids, _, hf_output_logprobs = hf_logprobs_outputs[i] - - _, vllm_output_str = vllm_outputs[i] - vllm_output_ids, _, vllm_output_logprobs = vllm_logprobs_outputs[i] - - if hf_output_str != vllm_output_str: - first_diff_index = [ - hf_id == vllm_id - for hf_id, vllm_id in zip(hf_output_ids, vllm_output_ids) - ].index(False) - hf_disagreement_logprobs = hf_output_logprobs[first_diff_index] - vllm_disagreement_logprobs = { - k: v.logprob - for k, v in vllm_output_logprobs[first_diff_index].items() - } - - assert (hf_output_ids[first_diff_index] - in vllm_disagreement_logprobs), ( - f"Test{i}:different outputs\n" - f"HF: {hf_output_str!r}\n" - f"vLLM: {vllm_output_str!r}\n", - f"Disagreement in {first_diff_index}th token. " - f"HF id: {hf_output_ids[first_diff_index]}, " - f"vLLM id: {vllm_output_ids[first_diff_index]})\n", - "HF top token not in vLLM top 2 tokens") - - vllm_disagreement_logprobs_values = list( - vllm_disagreement_logprobs.values()) - vllm_logprobs_diff = abs(vllm_disagreement_logprobs_values[0] - - vllm_disagreement_logprobs_values[1]) - vllm_hf_diff = abs( - hf_disagreement_logprobs[hf_output_ids[first_diff_index]] - - vllm_disagreement_logprobs[hf_output_ids[first_diff_index]]) - - assert (vllm_logprobs_diff < vllm_hf_diff - or vllm_logprobs_diff < 1e-4), ( - f"Test{i}:different outputs\n" - f"HF: {hf_output_str!r}\n" - f"vLLM: {vllm_output_str!r}\n", - f"Disagreement in {first_diff_index}th token. " - f"HF id: {hf_output_ids[first_diff_index]}, " - f"vLLM id: {vllm_output_ids[first_diff_index]})\n", - f"HF top token in vLLM top 2 tokens, " - f"but logprobs diff is too large. " - f"vLLM top 2 logprob diff: {vllm_logprobs_diff}\n", - f"HF to vLLM diff of top HF token: {vllm_hf_diff}") + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [15]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) def test_batching( vllm_runner, example_prompts, @@ -92,82 +46,22 @@ def test_batching( ) -> None: # To pass the small model tests, we need full precision. 
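
Both rewritten tests now delegate the comparison to check_outputs_equal from tests/models/utils.py. Its implementation is not part of this series; the behaviour the Jamba tests rely on is, roughly, pairwise equality of token ids and decoded strings, which can be approximated as:

    from typing import List, Tuple


    def check_outputs_equal_sketch(outputs_0_lst: List[Tuple[List[int], str]],
                                   outputs_1_lst: List[Tuple[List[int], str]],
                                   name_0: str, name_1: str) -> None:
        # Approximation of the shared helper: both runs must agree exactly on
        # the generated token ids and on the decoded text for every prompt.
        assert len(outputs_0_lst) == len(outputs_1_lst)
        for i, (out_0, out_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
            ids_0, str_0 = out_0
            ids_1, str_1 = out_1
            assert str_0 == str_1, (
                f"Test{i}:\n{name_0}: {str_0!r}\n{name_1}: {str_1!r}")
            assert ids_0 == ids_1, (
                f"Test{i}:\n{name_0}: {ids_0}\n{name_1}: {ids_1}")
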
# assert dtype == "float" - - with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: - for_loop_outputs = [] - for_loop_logprobs_outputs = [] + for_loop_outputs = [] + with vllm_runner(model, dtype=dtype) as vllm_model: for prompt in example_prompts: for_loop_outputs.append( - vllm_model.generate_greedy([prompt], max_tokens)[0]) - for_loop_logprobs_outputs.append( - vllm_model.generate_greedy_logprobs([prompt], - max_tokens, - num_logprobs=2)[0]) - - batched_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - batched_logprobs_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs=2) - - for i in range(len(example_prompts)): - _, for_loop_output_str = for_loop_outputs[i] - (for_loop_output_ids, _, - for_loop_output_logprobs) = for_loop_logprobs_outputs[i] + vllm_model.generate_greedy([prompt], max_tokens)[0] + ) - _, batched_output_str = batched_outputs[i] - (batched_output_ids, _, - batched_output_logprobs) = batched_logprobs_outputs[i] + batched_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - if for_loop_output_str != batched_output_str: - first_diff_index = [ - for_loop_id == batched_id for for_loop_id, batched_id in zip( - for_loop_output_ids, batched_output_ids) - ].index(False) - for_loop_disagreement_logprobs = { - k: v.logprob - for k, v in for_loop_output_logprobs[first_diff_index].items() - } - batched_disagreement_logprobs = { - k: v.logprob - for k, v in batched_output_logprobs[first_diff_index].items() - } - - assert ( - for_loop_output_ids[first_diff_index] - in batched_disagreement_logprobs), ( - f"Test{i}:different outputs\n" - f"For-loop: {for_loop_output_str!r}\n", - f"Batched: {batched_output_str!r}\n", - f"Disagreement in {first_diff_index}th token. " - f"For-loop id: {for_loop_output_ids[first_diff_index]}, " - f"Batched id: {batched_output_ids[first_diff_index]})\n", - "For-loop top token not in batched top 2 tokens") - - batched_disagreement_logprobs_values = list( - batched_disagreement_logprobs.values()) - batched_logprobs_diff = abs( - batched_disagreement_logprobs_values[0] - - batched_disagreement_logprobs_values[1]) - batched_for_loop_diff = abs( - for_loop_disagreement_logprobs[ - for_loop_output_ids[first_diff_index]] - - batched_disagreement_logprobs[ - for_loop_output_ids[first_diff_index]]) + check_outputs_equal( + outputs_0_lst=for_loop_outputs, + outputs_1_lst=batched_outputs, + name_0="for_loop_vllm", + name_1="batched_vllm", + ) - assert ( - batched_logprobs_diff < batched_for_loop_diff - or batched_logprobs_diff < 1e-4), ( - f"Test{i}:different outputs\n" - f"For-loop: {for_loop_output_str!r}\n" - f"Batched: {batched_output_str!r}\n", - f"Disagreement in {first_diff_index}th token. " - f"For-loop id: {for_loop_output_ids[first_diff_index]}, " - f"Batched id: {batched_output_ids[first_diff_index]})\n", - f"For-loop top token in batched top 2 tokens, " - f"but logprobs diff is too large. 
" - f"Batched top 2 logprob diff: {batched_logprobs_diff}\n", - f"For-loop to batched diff of top for-loop token: " - f"{batched_logprobs_diff}") @pytest.mark.parametrize("model", MODELS) From 4e029de22fa0e7b642cd13a3bf4111df03195249 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Jul 2024 12:20:59 +0300 Subject: [PATCH 13/20] Add _HasInnerStateType to support isinstance with a type and not just the model instance --- vllm/model_executor/models/interfaces.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 4eb7f8167e581..cfc970265897a 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -160,6 +160,15 @@ def __init__(self, scheduler_config: Optional[SchedulerConfig] = None) -> None: ... +@runtime_checkable +class _HasInnerStateType(Protocol): + has_inner_state: ClassVar[Literal[True]] + + + def __init__(self, + *, + scheduler_config: Optional[SchedulerConfig] = None) -> None: + ... @overload def has_inner_state(model: object) -> TypeGuard[HasInnerState]: @@ -174,4 +183,8 @@ def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]: def has_inner_state( model: Union[Type[object], object] ) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]: + if isinstance(model, type): + return isinstance(model, _HasInnerStateType) + return isinstance(model, HasInnerState) + From ed1b84ed8e60694d749df4a69073699398b50cee Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Jul 2024 12:23:23 +0300 Subject: [PATCH 14/20] Format --- tests/models/test_jamba.py | 9 ++++----- vllm/model_executor/models/interfaces.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index c689c85677f66..88d9296708a72 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -1,6 +1,6 @@ import pytest -from tests.models.utils import check_logprobs_close, check_outputs_equal +from tests.models.utils import check_outputs_equal from vllm.worker.model_runner import _get_graph_batch_size MODELS = ["ai21labs/Jamba-tiny-random"] @@ -50,10 +50,10 @@ def test_batching( with vllm_runner(model, dtype=dtype) as vllm_model: for prompt in example_prompts: for_loop_outputs.append( - vllm_model.generate_greedy([prompt], max_tokens)[0] - ) + vllm_model.generate_greedy([prompt], max_tokens)[0]) - batched_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + batched_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) check_outputs_equal( outputs_0_lst=for_loop_outputs, @@ -63,7 +63,6 @@ def test_batching( ) - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [20]) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index cfc970265897a..6fdacd4469788 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -160,16 +160,17 @@ def __init__(self, scheduler_config: Optional[SchedulerConfig] = None) -> None: ... + @runtime_checkable class _HasInnerStateType(Protocol): has_inner_state: ClassVar[Literal[True]] - def __init__(self, *, scheduler_config: Optional[SchedulerConfig] = None) -> None: ... + @overload def has_inner_state(model: object) -> TypeGuard[HasInnerState]: ... 
@@ -187,4 +188,3 @@ def has_inner_state( return isinstance(model, _HasInnerStateType) return isinstance(model, HasInnerState) - From c66aa3dc5d11f35668bff001a855c573cd999959 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Jul 2024 12:25:25 +0300 Subject: [PATCH 15/20] Revert test to reduce diff --- tests/models/test_jamba.py | 42 +++++++------------------------------- 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 88d9296708a72..988f9d9f21c13 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -26,41 +26,13 @@ def test_models( with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_batching( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # To pass the small model tests, we need full precision. - # assert dtype == "float" - for_loop_outputs = [] - with vllm_runner(model, dtype=dtype) as vllm_model: - for prompt in example_prompts: - for_loop_outputs.append( - vllm_model.generate_greedy([prompt], max_tokens)[0]) - - batched_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - check_outputs_equal( - outputs_0_lst=for_loop_outputs, - outputs_1_lst=batched_outputs, - name_0="for_loop_vllm", - name_1="batched_vllm", - ) + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") @pytest.mark.parametrize("model", MODELS) From ba9691a3cf579ea0e4a2c911800dcd593eec98cf Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Jul 2024 12:33:07 +0300 Subject: [PATCH 16/20] Put back the batching test --- tests/models/test_jamba.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 988f9d9f21c13..2e84f005f610a 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -34,6 +34,34 @@ def test_models( assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_batching( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. 
+ # assert dtype == "float" + for_loop_outputs = [] + with vllm_runner(model, dtype=dtype) as vllm_model: + for prompt in example_prompts: + for_loop_outputs.append( + vllm_model.generate_greedy([prompt], max_tokens)[0]) + + batched_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=for_loop_outputs, + outputs_1_lst=batched_outputs, + name_0="for_loop_vllm", + name_1="batched_vllm", + ) + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) From 757245056f7b36e49fd28041858d5a956caec65b Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Jul 2024 12:36:14 +0300 Subject: [PATCH 17/20] Remove comment --- tests/models/test_jamba.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 2e84f005f610a..86a9ab6db4f50 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -45,7 +45,6 @@ def test_batching( max_tokens: int, ) -> None: # To pass the small model tests, we need full precision. - # assert dtype == "float" for_loop_outputs = [] with vllm_runner(model, dtype=dtype) as vllm_model: for prompt in example_prompts: From d25c59036ff8691e7e719ccd8ab75e7182eb3a4d Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Jul 2024 12:37:05 +0300 Subject: [PATCH 18/20] Format --- tests/models/test_jamba.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 86a9ab6db4f50..f83dff683d43b 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -34,6 +34,7 @@ def test_models( assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) From ca28c521ac0b20ad068efae82b5ae5106ac85f11 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Jul 2024 15:49:33 +0300 Subject: [PATCH 19/20] Reduce the max_tokens to 20 --- tests/models/test_jamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index f83dff683d43b..0891a76856404 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -91,7 +91,7 @@ def test_mamba_cache_cg_padding( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("max_tokens", [20]) def test_models_preemption_recompute( hf_runner, vllm_runner, From acfa4c7fa5c34ac1326c22839547b8cb30a43ec2 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Jul 2024 17:32:08 +0300 Subject: [PATCH 20/20] set batching test to 20 --- tests/models/test_jamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 0891a76856404..774f2d9d9cdbc 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -37,7 +37,7 @@ def test_models( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("max_tokens", [20]) def test_batching( vllm_runner, example_prompts,
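
Taken together, the model and loader changes in this series (patches 04-07) amount to a small state-management discipline around Jamba's mamba cache: release the slots of finished requests before allocating for the current step, never allocate for a request that arrives and finishes in the same step (the "turning door" case), and size the pool from the scheduler's max_num_seqs rather than from the largest CUDA-graph batch size. The sketch below is illustrative only — MambaCacheSketch and its per-request (rather than per-request-and-sequence) bookkeeping are simplifications, not the actual JambaForCausalLM code:

    from typing import Dict, List


    class MambaCacheSketch:
        """Illustrative only: the ordering the 'turning door' fix enforces."""

        def __init__(self, capacity: int) -> None:
            # In the real model the capacity is derived from the scheduler
            # config (roughly the CUDA-graph batch size for max_num_seqs,
            # plus some headroom).
            self.assigned: Dict[str, int] = {}
            self.free: List[int] = list(range(capacity))

        def prepare_step(self, request_ids_to_seq_ids: Dict[str, List[int]],
                         finished_requests_ids: List[str]) -> List[int]:
            # 1) Release the state of finished requests *before* allocating
            #    anything for this step.
            for rid in finished_requests_ids:
                idx = self.assigned.pop(rid, None)
                if idx is not None:
                    self.free.append(idx)
            # 2) Skip requests that arrive and finish in the same step, so
            #    they never consume a slot.
            indices: List[int] = []
            for rid in request_ids_to_seq_ids:
                if rid in finished_requests_ids:
                    continue
                if rid not in self.assigned:
                    self.assigned[rid] = self.free.pop()
                indices.append(self.assigned[rid])
            # (The real code additionally pads the batch up to the captured
            # CUDA-graph batch size.)
            return indices

The scheduler-side half of the fix (patch 02) routes aborted requests through the same path by appending their ids to _finished_requests_ids, so an abort frees its cache slot exactly like a normal completion.
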