diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index cc1fd19252019..6fe5e6f76653b 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -5,6 +5,8 @@ import pytest +from tests.kernels.utils import override_backend_env_variable + from ..models.utils import check_logprobs_close, check_outputs_equal MODELS = [ @@ -19,10 +21,11 @@ @pytest.mark.parametrize("tp_size", [1]) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs", [None, 5]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"]) def test_multi_step_llm( hf_runner, vllm_runner, @@ -36,6 +39,8 @@ def test_multi_step_llm( num_scheduler_steps: int, num_prompts: int, num_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test vLLM engine with multi-step scheduling via sync LLM Engine. @@ -63,6 +68,7 @@ def test_multi_step_llm( num_logprobs: corresponds to the `logprobs` argument to the OpenAI completions endpoint; `None` -> 1 logprob returned. """ + override_backend_env_variable(monkeypatch, attention_backend) prompts = example_prompts if len(prompts) < num_prompts: @@ -110,10 +116,11 @@ def test_multi_step_llm( @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("tp_size", [1]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"]) def test_multi_step_llm_w_prompt_logprobs( vllm_runner, example_prompts, @@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs( num_prompts: int, num_logprobs: Optional[int], num_prompt_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test prompt logprobs with multi-step scheduling via sync LLM Engine. @@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs( note that this argument is not supported by the OpenAI completions endpoint. """ + override_backend_env_variable(monkeypatch, attention_backend) prompts = example_prompts if len(prompts) < num_prompts: @@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs( @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs", [None, 5]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"]) def test_multi_step_llm_chunked_prefill_prefix_cache( vllm_runner, example_prompts, @@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( num_scheduler_steps: int, num_prompts: int, num_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test vLLM engine with multi-step+"single-step chunked prefill"+APC. @@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( # # The Incorrect scheduling behavior - if it occurs - will cause an exception # in the model runner resulting from `do_sample=False`. + override_backend_env_variable(monkeypatch, attention_backend) + assert len(example_prompts) >= 2 challenge_prompts = copy.deepcopy(example_prompts) challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..702401b135de4 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -256,7 +256,9 @@ def prepare_graph_input_buffers(self, def begin_forward(self, model_input): assert not self._is_graph_capturing state = self - if model_input.attn_metadata.use_cuda_graph: + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + is_decode = model_input.attn_metadata.num_prefills == 0 + if use_cuda_graph and is_decode: batch_size = model_input.input_tokens.shape[0] state = (self.runner.graph_runners[model_input.virtual_engine] [batch_size].attn_state) @@ -429,10 +431,24 @@ def advance_step(self, Update metadata in-place to advance one decode step. """ - assert not turn_prefills_into_decodes, \ - ("Chunked prefill is not supported with flashinfer yet." - "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " - "specific parameter.") + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + # Flashinfer doesn't support speculative decoding + chunked-prefill + # + multi-step scheduling yet. + assert self.decode_query_len == 1 + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens_tensor is not None assert num_seqs > 0 assert num_queries > 0 diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 26fd486130ce6..7547d24bdb2c1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,8 +6,19 @@ import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, - Tuple, Type, TypeVar, Union) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) import numpy as np import torch @@ -35,26 +46,41 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models.utils import set_cpu_offload_max_bytes -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs, MultiModalPlaceholderMap, - MultiModalRegistry) +from vllm.multimodal import ( + MULTIMODAL_REGISTRY, + BatchedTensorInputs, + MultiModalKwargs, + MultiModalPlaceholderMap, + MultiModalRegistry, +) from vllm.platforms import current_platform from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( - LRUCacheWorkerPromptAdapterManager) + LRUCacheWorkerPromptAdapterManager, +) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, - async_tensor_h2d, flatten_2d_lists, - is_pin_memory_available, supports_dynamo, - weak_ref_tensor) +from vllm.utils import ( + DeviceMemoryProfiler, + GiB_bytes, + PyObjectCache, + async_tensor_h2d, + flatten_2d_lists, + is_pin_memory_available, + supports_dynamo, + weak_ref_tensor, +) from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, + ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) + _init_sampling_metadata_from_tensor_dict, + dump_input_when_exception, +) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -65,7 +91,7 @@ _NUM_WARMUP_ITERS = 2 -TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") +TModelInputForGPU = TypeVar("TModelInputForGPU", bound="ModelInputForGPU") # For now, bump up cache limits for recompilations during CUDA graph warmups. torch._dynamo.config.cache_size_limit = 128 @@ -80,6 +106,7 @@ class ModelInputForGPU(ModelRunnerInputBase): runners that run additional steps should subclass this method to add additional fields. """ + input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None token_types: Optional[torch.Tensor] = None @@ -122,7 +149,8 @@ def from_broadcasted_tensor_dict( ) -> TModelInputForGPU: if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) + attn_backend, tensor_dict + ) return cls(**tensor_dict) # Exclude `async_callback` to be able to pickle this object @@ -135,7 +163,7 @@ def __getstate__(self): # How can we update this callback to properly pass it to the engine? def __setstate__(self, state): self.__dict__.update(state) - self.__dict__.update({'async_callback': None}) + self.__dict__.update({"async_callback": None}) @dataclass(frozen=True) @@ -143,6 +171,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): """ Used by the ModelRunner. """ + sampling_metadata: Optional["SamplingMetadata"] = None # Used for speculative decoding. We do not broadcast it because it is only # used by the driver worker. @@ -162,8 +191,9 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "finished_requests_ids": self.finished_requests_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) + _add_sampling_metadata_broadcastable_dict( + tensor_dict, self.sampling_metadata + ) return tensor_dict @classmethod @@ -175,7 +205,8 @@ def from_broadcasted_tensor_dict( tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) + attn_backend, tensor_dict + ) return cls(**tensor_dict) @@ -214,13 +245,11 @@ def __init__( block_tables: Optional[Dict[int, List[int]]], computed_block_nums: List[int], n_seqs: int = 0, - # Input tokens and positions. input_tokens: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None, token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, - # The sequence length (may be capped to the sliding window). seq_lens: Optional[List[int]] = None, # The original sequence length (before applying sliding window). @@ -232,22 +261,19 @@ def __init__( context_lens: Optional[List[int]] = None, # The current sliding window block. curr_sliding_window_blocks: Optional[List[int]] = None, - # LoRA inputs. lora_index_mapping: Optional[List[List[int]]] = None, lora_prompt_mapping: Optional[List[List[int]]] = None, lora_requests: Optional[Set[LoRARequest]] = None, - # Prompt adapter inputs. prompt_adapter_index_mapping: Optional[List[int]] = None, prompt_adapter_prompt_mapping: Optional[List[int]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - # Multi-modal inputs. multi_modal_kwargs: Optional[MultiModalKwargs] = None, - multi_modal_placeholder_maps: Optional[Dict[ - str, MultiModalPlaceholderMap]] = None, - + multi_modal_placeholder_maps: Optional[ + Dict[str, MultiModalPlaceholderMap] + ] = None, # Whether the prefix cache is hit (prefill only). prefix_cache_hit: bool = False, reinit: bool = False, @@ -317,8 +343,9 @@ def __init__( self.context_lens[seq_id] = 0 if curr_sliding_window_blocks: - self.curr_sliding_window_blocks = \ + self.curr_sliding_window_blocks = ( curr_sliding_window_blocks + ) else: for seq_id in range(len(self.seq_ids)): self.curr_sliding_window_blocks[seq_id] = 0 @@ -339,14 +366,16 @@ def __init__( self.lora_requests.clear() if prompt_adapter_index_mapping: - self.prompt_adapter_index_mapping = \ + self.prompt_adapter_index_mapping = ( prompt_adapter_index_mapping + ) else: self.prompt_adapter_index_mapping.clear() if prompt_adapter_prompt_mapping: - self.prompt_adapter_prompt_mapping = \ + self.prompt_adapter_prompt_mapping = ( prompt_adapter_prompt_mapping + ) else: self.prompt_adapter_prompt_mapping.clear() @@ -359,17 +388,20 @@ def __init__( self.orig_seq_lens = orig_seq_lens or [] self.query_lens = query_lens or [] self.context_lens = context_lens or [] - self.curr_sliding_window_blocks = \ + self.curr_sliding_window_blocks = ( curr_sliding_window_blocks or [] + ) self.lora_index_mapping = lora_index_mapping or [] self.lora_prompt_mapping = lora_prompt_mapping or [] self.lora_requests = lora_requests or set() self.prompt_adapter_index_mapping = ( - prompt_adapter_index_mapping or []) + prompt_adapter_index_mapping or [] + ) self.prompt_adapter_prompt_mapping = ( - prompt_adapter_prompt_mapping or []) + prompt_adapter_prompt_mapping or [] + ) self.prompt_adapter_request = prompt_adapter_request self.multi_modal_kwargs = multi_modal_kwargs @@ -403,7 +435,8 @@ def gen_inter_data_builder(self, num_seqs: int): seq_ids=[0] * num_seqs, is_prompt=True, block_tables=None, - computed_block_nums=[]) + computed_block_nums=[], + ) def init_cached_inter_data(self, *args, **kwargs): assert len(args) == 0 @@ -415,7 +448,8 @@ def init_cached_inter_data(self, *args, **kwargs): inter_data_cache = self.runner.inter_data_cache if num_seqs not in inter_data_cache: inter_data_cache[num_seqs] = PyObjectCache( - self.gen_inter_data_builder(num_seqs)) + self.gen_inter_data_builder(num_seqs) + ) obj = inter_data_cache[num_seqs].get_object() obj.__init__(*args, **kwargs) @@ -425,9 +459,11 @@ def reset_cached_inter_data(self): for cache in self.runner.inter_data_cache.values(): cache.reset() - def __init__(self, - runner: "GPUModelRunnerBase", - finished_requests_ids: Optional[List[str]] = None): + def __init__( + self, + runner: "GPUModelRunnerBase", + finished_requests_ids: Optional[List[str]] = None, + ): super().__init__() # Compute functions for each sequence in a sequence group. # WARNING: The order of the functions matters! @@ -451,8 +487,9 @@ def __init__(self, self.sliding_window = self.runner.sliding_window self.block_size = self.runner.block_size self.enable_lora = self.runner.lora_config is not None - self.enable_prompt_adapter = (self.runner.prompt_adapter_config - is not None) + self.enable_prompt_adapter = ( + self.runner.prompt_adapter_config is not None + ) self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper self.finished_requests_ids = finished_requests_ids self.decode_only = True @@ -460,24 +497,33 @@ def __init__(self, # Intermediate data (data in CPU before going to GPU) for # the current sequence group. self.inter_data_list: List[ - ModelInputForGPUBuilder.InterDataForSeqGroup] = [] + ModelInputForGPUBuilder.InterDataForSeqGroup + ] = [] # Attention metadata inputs. self.attn_metadata_builder = self.attn_backend.make_metadata_builder( - weakref.proxy(self)) + weakref.proxy(self) + ) # Engine/Model configurations. self.chunked_prefill_enabled = ( self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled) + and self.scheduler_config.chunked_prefill_enabled + ) if self.sliding_window is not None: self.sliding_window_blocks = ( - self.sliding_window + self.block_size - 1) // self.block_size - self.block_aligned_sliding_window = \ + self.sliding_window + self.block_size - 1 + ) // self.block_size + self.block_aligned_sliding_window = ( self.sliding_window_blocks * self.block_size + ) - def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): + def _compute_lens( + self, + inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata, + ): """Compute context length, sequence length and tokens for the given sequence data. """ @@ -491,8 +537,10 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, if inter_data.is_prompt: context_len = seq_data.get_num_computed_tokens() seq_len = min(seq_len, context_len + token_chunk_size) - elif self.runner.scheduler_config.is_multi_step or \ - self.runner.model_config.is_encoder_decoder: + elif ( + self.runner.scheduler_config.is_multi_step + or self.runner.model_config.is_encoder_decoder + ): context_len = seq_len - 1 else: context_len = seq_data.get_num_computed_tokens() @@ -507,23 +555,28 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.token_types[seq_idx].extend( - token_types if token_types else []) + token_types if token_types else [] + ) inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: if inter_data.mrope_input_positions is None: inter_data.mrope_input_positions = [None] * inter_data.n_seqs - inter_data.mrope_input_positions[ - seq_idx] = MRotaryEmbedding.get_next_input_positions( + inter_data.mrope_input_positions[seq_idx] = ( + MRotaryEmbedding.get_next_input_positions( seq_data.mrope_position_delta, context_len, seq_len, ) + ) def _compute_for_prefix_cache_hit( - self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): + self, + inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata, + ): """Check if hit prefix cache (i.e., some blocks are already computed). If hit, update input tokens and positions to only compute the remaining blocks. @@ -531,10 +584,12 @@ def _compute_for_prefix_cache_hit( computed_block_nums = inter_data.computed_block_nums # Note that prefix caching does not support sliding window. - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and inter_data.is_prompt) + prefix_cache_hit = ( + computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and inter_data.is_prompt + ) inter_data.prefix_cache_hit = prefix_cache_hit if not prefix_cache_hit: @@ -545,8 +600,9 @@ def _compute_for_prefix_cache_hit( # this may be larger than the sequence length if chunked # prefill is enabled. prefix_cache_len = len(computed_block_nums) * self.block_size - seq_group_metadata.seq_data[inter_data.seq_ids[ - seq_idx]].update_num_cached_tokens(prefix_cache_len) + seq_group_metadata.seq_data[ + inter_data.seq_ids[seq_idx] + ].update_num_cached_tokens(prefix_cache_len) # The number of so far computed prompt tokens in this sequence. context_len = inter_data.context_lens[seq_idx] @@ -561,34 +617,44 @@ def _compute_for_prefix_cache_hit( elif context_len < prefix_cache_len < seq_len: # Partial hit. Compute the missing part. uncomputed_start = prefix_cache_len - context_len - inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][uncomputed_start:] + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[seq_idx][ + uncomputed_start: + ] inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][uncomputed_start:] + seq_idx + ][uncomputed_start:] inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - uncomputed_start:] + uncomputed_start: + ] context_len = prefix_cache_len inter_data.context_lens[seq_idx] = context_len - inter_data.query_lens[ - seq_idx] = inter_data.seq_lens[seq_idx] - context_len + inter_data.query_lens[seq_idx] = ( + inter_data.seq_lens[seq_idx] - context_len + ) elif seq_len <= prefix_cache_len: # Full hit. Only compute the last token to avoid # erroneous behavior. FIXME: Ideally we should directly # mark all tokens as computed in the scheduler and do not # schedule this sequence, so this case should not happen. - inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][-1:] + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[seq_idx][ + -1: + ] inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][-1:] + seq_idx + ][-1:] inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - -1:] + -1: + ] inter_data.query_lens[seq_idx] = 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 - def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, - seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): + def _compute_for_sliding_window( + self, + inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata, + ): """Update seq_len and curr_sliding_window_block for the given sequence data (only required by decoding) if sliding window is enabled. """ @@ -601,18 +667,24 @@ def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, curr_sliding_window_block = self.sliding_window_blocks # number of elements in last block suff_len = inter_data.seq_lens[seq_idx] % self.block_size - sliding_seq_len = min(inter_data.seq_lens[seq_idx], - self.block_aligned_sliding_window + suff_len) + sliding_seq_len = min( + inter_data.seq_lens[seq_idx], + self.block_aligned_sliding_window + suff_len, + ) if suff_len > 0: curr_sliding_window_block += 1 - inter_data.curr_sliding_window_blocks[ - seq_idx] = curr_sliding_window_block + inter_data.curr_sliding_window_blocks[seq_idx] = ( + curr_sliding_window_block + ) inter_data.seq_lens[seq_idx] = sliding_seq_len - def _compute_lora_input(self, inter_data: InterDataForSeqGroup, - seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): + def _compute_lora_input( + self, + inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata, + ): """If LoRA is enabled, compute LoRA index and prompt mapping.""" if not self.enable_lora: return @@ -622,6 +694,7 @@ def _compute_lora_input(self, inter_data: InterDataForSeqGroup, inter_data.lora_requests.add(seq_group_metadata.lora_request) query_len = inter_data.query_lens[seq_idx] inter_data.lora_index_mapping.append([lora_id] * query_len) +<<<<<<< HEAD sampling_params = seq_group_metadata.sampling_params if sampling_params and sampling_params.prompt_logprobs is not None: inter_data.lora_prompt_mapping.append([lora_id] * query_len) @@ -629,12 +702,25 @@ def _compute_lora_input(self, inter_data: InterDataForSeqGroup, inter_data.lora_prompt_mapping.append([lora_id]) else: inter_data.lora_prompt_mapping.append([]) +======= + inter_data.lora_prompt_mapping.append( + [lora_id] + * ( + query_len + if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + is not None + else 1 + ) + ) +>>>>>>> 817a0867 (Add multipstep chunked-prefill support where prefill turns into decode after the first single step.) def _compute_prompt_adapter_input( - self, inter_data: InterDataForSeqGroup, - seq_group_metadata: SequenceGroupMetadata): - """If prompt adapter is enabled, compute index and prompt mapping. - """ + self, + inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata, + ): + """If prompt adapter is enabled, compute index and prompt mapping.""" # Note that when is_prompt=True, we expect only one sequence # in the group. if not self.enable_prompt_adapter: @@ -648,25 +734,33 @@ def _compute_prompt_adapter_input( assert inter_data.n_seqs == 1 query_len = inter_data.query_lens[0] inter_data.prompt_adapter_request = ( - seq_group_metadata.prompt_adapter_request) + seq_group_metadata.prompt_adapter_request + ) num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens inter_data.prompt_adapter_index_mapping = [ prompt_adapter_id ] * num_tokens + [0] * (query_len - num_tokens) inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * ( - query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs else 1) + query_len + if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1 + ) - def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, - seq_group_metadata: SequenceGroupMetadata): + def _compute_multi_modal_input( + self, + inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata, + ): """If multi-modal data is given, add it to the input.""" # NOTE: mm_data only includes the subset of multi-modal items that # intersect with the current prefill positions. positions = inter_data.input_positions[0] mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( seq_group_metadata, - range(positions[0], positions[0] + len(positions))) + range(positions[0], positions[0] + len(positions)), + ) if not mm_data: return @@ -687,17 +781,19 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, ( "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw'.") + "returns 'image_grid_thw' or 'video_grid_thw'." + ) hf_config = self.runner.model_config.hf_config inter_data.mrope_input_positions = [None] * inter_data.n_seqs for seq_idx in range(inter_data.n_seqs): seq_data = seq_group_metadata.seq_data[ - inter_data.seq_ids[seq_idx]] + inter_data.seq_ids[seq_idx] + ] token_ids = seq_data.get_token_ids() - mrope_input_positions, mrope_position_delta = \ + mrope_input_positions, mrope_position_delta = ( MRotaryEmbedding.get_input_positions( token_ids, image_grid_thw=image_grid_thw, @@ -706,15 +802,16 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, video_token_id=hf_config.video_token_id, vision_start_token_id=hf_config.vision_start_token_id, vision_end_token_id=hf_config.vision_end_token_id, - spatial_merge_size=hf_config.vision_config. - spatial_merge_size, + spatial_merge_size=hf_config.vision_config.spatial_merge_size, context_len=inter_data.context_lens[seq_idx], seq_len=inter_data.seq_lens[seq_idx], ) + ) seq_data.mrope_position_delta = mrope_position_delta - inter_data.mrope_input_positions[ - seq_idx] = mrope_input_positions + inter_data.mrope_input_positions[seq_idx] = ( + mrope_input_positions + ) def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): """Add a sequence group to the builder.""" @@ -739,7 +836,8 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): computed_block_nums=seq_group_metadata.computed_block_nums, reinit=True, reinit_use_defaults=True, - encoder_seq_len=encoder_seq_len) + encoder_seq_len=encoder_seq_len, + ) self.inter_data_list.append(inter_data) @@ -749,6 +847,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): for per_seq_group_fn in self.per_seq_group_compute_fns: per_seq_group_fn(inter_data, seq_group_metadata) +<<<<<<< HEAD def _use_captured_graph(self, batch_size: int, decode_only: bool, @@ -758,11 +857,30 @@ def _use_captured_graph(self, and max_decode_seq_len <= self.runner.max_seq_len_to_capture and max_encoder_seq_len <= self.runner.max_seq_len_to_capture and batch_size <= self.runner.max_batchsize_to_capture) +======= + def _use_captured_graph( + self, + batch_size: int, + decode_only: bool, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0, + ) -> bool: + return ( + decode_only + and not self.runner.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.runner.max_seq_len_to_capture + and max_encoder_seq_len <= self.runner.max_seq_len_to_capture + and batch_size <= self.runner.max_batchsize_to_capture + ) +>>>>>>> 817a0867 (Add multipstep chunked-prefill support where prefill turns into decode after the first single step.) - def _get_cuda_graph_pad_size(self, - num_seqs: int, - max_decode_seq_len: int, - max_encoder_seq_len: int = 0) -> int: + def _get_cuda_graph_pad_size( + self, + num_seqs: int, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0, + ) -> int: """ Determine the number of padding sequences required for running in CUDA graph mode. Returns -1 if CUDA graphs cannot be used. @@ -786,10 +904,12 @@ def _get_cuda_graph_pad_size(self, int: Returns the determined number of padding sequences. If CUDA graphs is not viable, returns -1. """ - is_mscp: bool = self.runner.scheduler_config.is_multi_step and \ - self.runner.scheduler_config.chunked_prefill_enabled + is_mscp: bool = ( + self.runner.scheduler_config.is_multi_step + and self.runner.scheduler_config.chunked_prefill_enabled + ) decode_only = self.decode_only or is_mscp - if not decode_only: + if not decode_only or self.runner.is_profile_run: # Early exit so we can treat num_seqs as the batch_size below. return -1 @@ -797,9 +917,9 @@ def _get_cuda_graph_pad_size(self, # tokens being scheduled. This conflation of num_seqs as batch_size # is valid as this is a decode-only case. batch_size = num_seqs - if not self._use_captured_graph(batch_size, decode_only, - max_decode_seq_len, - max_encoder_seq_len): + if not self._use_captured_graph( + batch_size, decode_only, max_decode_seq_len, max_encoder_seq_len + ): return -1 graph_batch_size = VllmConfig.get_graph_batch_size(batch_size) @@ -825,8 +945,10 @@ def build(self) -> ModelInputForGPU: return self.model_input_cls() mrope_input_positions: Optional[List[List[int]]] = None - if any(inter_data.mrope_input_positions is not None - for inter_data in self.inter_data_list): + if any( + inter_data.mrope_input_positions is not None + for inter_data in self.inter_data_list + ): mrope_input_positions = [[] for _ in range(3)] for idx in range(3): for inter_data in self.inter_data_list: @@ -834,11 +956,13 @@ def build(self) -> ModelInputForGPU: if msections is None: for _seq_input_positions in inter_data.input_positions: mrope_input_positions[idx].extend( - _seq_input_positions) + _seq_input_positions + ) else: for _seq_mrope_input_positions in msections: mrope_input_positions[idx].extend( - _seq_mrope_input_positions[idx]) + _seq_mrope_input_positions[idx] + ) input_positions = None else: input_positions = [] @@ -854,23 +978,25 @@ def build(self) -> ModelInputForGPU: seq_lens.extend(inter_data.seq_lens) query_lens.extend(inter_data.query_lens) if not inter_data.is_prompt: - max_decode_seq_len = max(max_decode_seq_len, - max(inter_data.seq_lens)) + max_decode_seq_len = max( + max_decode_seq_len, max(inter_data.seq_lens) + ) if self.runner.model_config.is_encoder_decoder: - max_encoder_seq_len = max(max_encoder_seq_len, - inter_data.encoder_seq_len) + max_encoder_seq_len = max( + max_encoder_seq_len, inter_data.encoder_seq_len + ) # Mapping from request IDs to sequence IDs. Used for Jamba models # that manages the cache by itself. request_ids_to_seq_ids = { - data.request_id: data.seq_ids - for data in self.inter_data_list + data.request_id: data.seq_ids for data in self.inter_data_list } cuda_graph_pad_size = self._get_cuda_graph_pad_size( num_seqs=len(seq_lens), max_decode_seq_len=max_decode_seq_len, - max_encoder_seq_len=max_encoder_seq_len) + max_encoder_seq_len=max_encoder_seq_len, + ) batch_size = len(input_tokens) if cuda_graph_pad_size != -1: @@ -883,78 +1009,106 @@ def build(self) -> ModelInputForGPU: if cuda_graph_pad_size: input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) assert self.runner.device is not None - input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, - self.runner.device, - self.runner.pin_memory) + input_tokens_tensor = async_tensor_h2d( + input_tokens, torch.long, self.runner.device, self.runner.pin_memory + ) - token_types_tensor = async_tensor_h2d(token_types, torch.long, - self.runner.device, - self.runner.pin_memory) \ - if token_types else None + token_types_tensor = ( + async_tensor_h2d( + token_types, + torch.long, + self.runner.device, + self.runner.pin_memory, + ) + if token_types + else None + ) if mrope_input_positions is not None: for idx in range(3): mrope_input_positions[idx].extend( - itertools.repeat(0, cuda_graph_pad_size)) - input_positions_tensor = async_tensor_h2d(mrope_input_positions, - torch.long, - self.runner.device, - self.runner.pin_memory) + itertools.repeat(0, cuda_graph_pad_size) + ) + input_positions_tensor = async_tensor_h2d( + mrope_input_positions, + torch.long, + self.runner.device, + self.runner.pin_memory, + ) else: input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) - input_positions_tensor = async_tensor_h2d(input_positions, - torch.long, - self.runner.device, - self.runner.pin_memory) + input_positions_tensor = async_tensor_h2d( + input_positions, + torch.long, + self.runner.device, + self.runner.pin_memory, + ) # Sequence and query lengths. if cuda_graph_pad_size: seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) # Attention metadata. attn_metadata = self.attn_metadata_builder.build( - seq_lens, query_lens, cuda_graph_pad_size, batch_size) + seq_lens, query_lens, cuda_graph_pad_size, batch_size + ) # LoRA data. lora_requests = set() lora_mapping = None if self.enable_lora: - lora_requests = set(r for data in self.inter_data_list - for r in data.lora_requests) - lora_index_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_index_mapping) - for inter_data in self.inter_data_list - ]) + lora_requests = set( + r for data in self.inter_data_list for r in data.lora_requests + ) + lora_index_mapping = flatten_2d_lists( + [ + flatten_2d_lists(inter_data.lora_index_mapping) + for inter_data in self.inter_data_list + ] + ) if cuda_graph_pad_size: lora_index_mapping.extend( - itertools.repeat(0, cuda_graph_pad_size)) - lora_prompt_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_prompt_mapping) - for inter_data in self.inter_data_list - ]) + itertools.repeat(0, cuda_graph_pad_size) + ) + lora_prompt_mapping = flatten_2d_lists( + [ + flatten_2d_lists(inter_data.lora_prompt_mapping) + for inter_data in self.inter_data_list + ] + ) lora_mapping = LoRAMapping( - **dict(index_mapping=lora_index_mapping, - prompt_mapping=lora_prompt_mapping, - is_prefill=not self.decode_only)) + **dict( + index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=not self.decode_only, + ) + ) # Prompt adapter data. prompt_adapter_requests: Set[PromptAdapterRequest] = set() prompt_adapter_mapping = None if self.enable_prompt_adapter: prompt_adapter_requests = set( - data.prompt_adapter_request for data in self.inter_data_list - if data.prompt_adapter_request is not None) - prompt_adapter_index_mapping = flatten_2d_lists([ - inter_data.prompt_adapter_index_mapping - for inter_data in self.inter_data_list - ]) + data.prompt_adapter_request + for data in self.inter_data_list + if data.prompt_adapter_request is not None + ) + prompt_adapter_index_mapping = flatten_2d_lists( + [ + inter_data.prompt_adapter_index_mapping + for inter_data in self.inter_data_list + ] + ) if cuda_graph_pad_size: prompt_adapter_index_mapping.extend( - itertools.repeat(0, cuda_graph_pad_size)) - prompt_adapter_prompt_mapping = flatten_2d_lists([ - inter_data.prompt_adapter_prompt_mapping - for inter_data in self.inter_data_list - ]) + itertools.repeat(0, cuda_graph_pad_size) + ) + prompt_adapter_prompt_mapping = flatten_2d_lists( + [ + inter_data.prompt_adapter_prompt_mapping + for inter_data in self.inter_data_list + ] + ) prompt_adapter_mapping = PromptAdapterMapping( prompt_adapter_index_mapping, prompt_adapter_prompt_mapping, @@ -962,7 +1116,8 @@ def build(self) -> ModelInputForGPU: # Multi-modal data. multi_modal_kwargs_list = [ - data.multi_modal_kwargs for data in self.inter_data_list + data.multi_modal_kwargs + for data in self.inter_data_list if data.multi_modal_kwargs is not None ] multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) @@ -980,13 +1135,15 @@ def build(self) -> ModelInputForGPU: request_ids_to_seq_ids=request_ids_to_seq_ids, finished_requests_ids=self.finished_requests_ids, prompt_adapter_mapping=prompt_adapter_mapping, - prompt_adapter_requests=prompt_adapter_requests) + prompt_adapter_requests=prompt_adapter_requests, + ) class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): """ Helper class for shared methods between GPU model runners. """ + _model_input_cls: Type[TModelInputForGPU] _builder_cls: Type[ModelInputForGPUBuilder] @@ -999,7 +1156,6 @@ def __init__( input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - ModelRunnerBase.__init__(self, vllm_config) model_config = self.model_config cache_config = self.cache_config @@ -1020,11 +1176,14 @@ def __init__( self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) ] - self.graph_memory_pool: Optional[Tuple[ - int, int]] = None # Set during graph capture. + self.graph_memory_pool: Optional[Tuple[int, int]] = ( + None # Set during graph capture. + ) self.has_inner_state = model_config.has_inner_state + self.is_profile_run = False + # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table @@ -1033,7 +1192,8 @@ def __init__( # (max batch size to capture, max seq len to capture / block size). self.graph_block_tables = np.zeros( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), - dtype=np.int32) + dtype=np.int32, + ) # Attention-free but stateful models like Mamba need a placeholder attn # backend, as the attention metadata is needed to manage internal state. @@ -1041,28 +1201,36 @@ def __init__( # used for speculative decoding to avoid a divide-by-zero in # model_config.get_head_size() num_attn_heads = self.model_config.get_num_attention_heads( - self.parallel_config) - needs_attn_backend = (num_attn_heads != 0 - or self.model_config.is_attention_free) - - self.attn_backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - ) if needs_attn_backend else None + self.parallel_config + ) + needs_attn_backend = ( + num_attn_heads != 0 or self.model_config.is_attention_free + ) + + self.attn_backend = ( + get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + if needs_attn_backend + else None + ) if self.attn_backend: self.attn_state = self.attn_backend.get_state_cls()( - weakref.proxy(self)) + weakref.proxy(self) + ) else: self.attn_state = CommonAttentionState(weakref.proxy(self)) # Multi-modal data support self.input_registry = input_registry self.mm_registry = mm_registry - self.multi_modal_input_mapper = mm_registry \ - .create_input_mapper(model_config) + self.multi_modal_input_mapper = mm_registry.create_input_mapper( + model_config + ) self.mm_registry.init_mm_limits_per_prompt(self.model_config) # Lazy initialization @@ -1072,7 +1240,8 @@ def __init__( self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) + int(self.cache_config.cpu_offload_gb * 1024**3) + ) # Used to cache python objects self.inter_data_cache: Dict[int, PyObjectCache] = {} @@ -1083,9 +1252,11 @@ def __init__( # prepare_model_inputs() call. This clobbers the cached # SequenceGroupToSample objects, as we reset the cache during # every prepare_model_inputs() call. - self.sampling_metadata_cache: SamplingMetadataCache = \ - SamplingMetadataCache() \ - if self.parallel_config.pipeline_parallel_size == 1 else None + self.sampling_metadata_cache: SamplingMetadataCache = ( + SamplingMetadataCache() + if self.parallel_config.pipeline_parallel_size == 1 + else None + ) def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) @@ -1093,8 +1264,10 @@ def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB", - self.model_memory_usage / float(2**30)) + logger.info( + "Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30), + ) if self.lora_config: assert supports_lora( @@ -1102,15 +1275,18 @@ def load_model(self) -> None: ), f"{self.model.__class__.__name__} does not support LoRA yet." if supports_multimodal(self.model): - logger.warning("Regarding multimodal models, vLLM currently " - "only supports adding LoRA to language model.") + logger.warning( + "Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model." + ) # It's necessary to distinguish between the max_position_embeddings # of VLMs and LLMs. if hasattr(self.model.config, "max_position_embeddings"): max_pos_embeddings = self.model.config.max_position_embeddings else: max_pos_embeddings = ( - self.model.config.text_config.max_position_embeddings) + self.model.config.text_config.max_position_embeddings + ) self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, @@ -1127,11 +1303,15 @@ def load_model(self) -> None: if self.prompt_adapter_config: self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, self.device, - self.prompt_adapter_config) + self.scheduler_config.max_num_batched_tokens, + self.device, + self.prompt_adapter_config, + ) self.model = ( self.prompt_adapter_manager.create_prompt_adapter_manager( - self.model)) + self.model + ) + ) if self.kv_cache_dtype == "fp8" and current_platform.is_rocm(): # Currently only ROCm accepts kv-cache scaling factors @@ -1144,21 +1324,27 @@ def load_model(self) -> None: "deprecated and will be removed. Please include " "kv cache scaling factors in the model checkpoint.", FutureWarning, - stacklevel=2) + stacklevel=2, + ) self.model.load_kv_cache_scales( - self.model_config.quantization_param_path) - logger.info("Loaded KV cache scaling factors from %s", - self.model_config.quantization_param_path) + self.model_config.quantization_param_path + ) + logger.info( + "Loaded KV cache scaling factors from %s", + self.model_config.quantization_param_path, + ) else: raise RuntimeError( "Using FP8 KV cache and scaling factors provided but " "model %s does not support loading scaling factors.", - self.model.__class__) + self.model.__class__, + ) else: logger.warning( "Using FP8 KV cache but no scaling factors " "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!") + "This may lead to less accurate results!" + ) if self.vllm_config.compilation_config.level ==\ CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): @@ -1167,7 +1353,8 @@ def load_model(self) -> None: self.model = torch.compile( self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend) + backend=backend, + ) def save_sharded_state( self, @@ -1176,6 +1363,7 @@ def save_sharded_state( max_size: Optional[int] = None, ) -> None: from vllm.model_executor.model_loader.loader import ShardedStateLoader + ShardedStateLoader.save_model( self.model, path, @@ -1188,6 +1376,7 @@ def save_tensorized_model( tensorizer_config: TensorizerConfig, ) -> None: from vllm.model_executor.model_loader.loader import TensorizerLoader + TensorizerLoader.save_model( self.model, tensorizer_config=tensorizer_config, @@ -1200,7 +1389,7 @@ def get_max_block_per_batch(self) -> int: def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_requests_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None, ) -> TModelInputForGPU: """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not @@ -1227,6 +1416,7 @@ def _prepare_model_input_tensors( @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. + self.is_profile_run = True sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs @@ -1246,8 +1436,9 @@ def profile_run(self) -> None: lora_int_id=lora_id, lora_path="/not/a/real/path", ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) + self.lora_manager.add_dummy_lora( + dummy_lora_request, rank=LORA_WARMUP_RANK + ) dummy_lora_requests.append(dummy_lora_request) dummy_lora_requests_per_seq = [ dummy_lora_requests[idx % len(dummy_lora_requests)] @@ -1265,29 +1456,35 @@ def profile_run(self) -> None: # of images processed. max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( - self.model_config) + self.model_config + ) if max_mm_tokens > 0: max_num_seqs_orig = max_num_seqs - max_num_seqs = min(max_num_seqs, - max_num_batched_tokens // max_mm_tokens) + max_num_seqs = min( + max_num_seqs, max_num_batched_tokens // max_mm_tokens + ) if max_num_seqs < 1: - expr = (f"min({max_num_seqs_orig}, " - f"{max_num_batched_tokens} // {max_mm_tokens})") + expr = ( + f"min({max_num_seqs_orig}, " + f"{max_num_batched_tokens} // {max_mm_tokens})" + ) logger.warning( "Computed max_num_seqs (%s) to be less than 1. " - "Setting it to the minimum value of 1.", expr) + "Setting it to the minimum value of 1.", + expr, + ) max_num_seqs = 1 batch_size = 0 for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) + seq_len = max_num_batched_tokens // max_num_seqs + ( + group_id < max_num_batched_tokens % max_num_seqs + ) batch_size += seq_len - dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) + dummy_data = self.input_registry.dummy_data_for_profiling( + self.model_config, seq_len, self.mm_registry + ) seq = SequenceGroupMetadata( request_id=str(group_id), @@ -1296,7 +1493,8 @@ def profile_run(self) -> None: sampling_params=sampling_params, block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, + if dummy_lora_requests_per_seq + else None, multi_modal_data=dummy_data.multi_modal_data, multi_modal_placeholders=dummy_data.multi_modal_placeholders, ) @@ -1317,16 +1515,19 @@ def profile_run(self) -> None: ] finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) + seqs, finished_requests_ids=finished_requests_ids + ) intermediate_tensors = None if not get_pp_group().is_first_rank: intermediate_tensors = self.model.make_empty_intermediate_tensors( batch_size=batch_size, dtype=self.model_config.dtype, - device=self.device) + device=self.device, + ) self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() + self.is_profile_run = False return def remove_all_loras(self): @@ -1334,8 +1535,9 @@ def remove_all_loras(self): raise RuntimeError("LoRA is not enabled.") self.lora_manager.remove_all_adapters() - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: + def set_active_loras( + self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping + ) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) @@ -1366,15 +1568,19 @@ def remove_all_prompt_adapters(self): self.prompt_adapter_manager.remove_all_adapters() def set_active_prompt_adapters( - self, prompt_adapter_requests: Set[PromptAdapterRequest], - prompt_adapter_mapping: PromptAdapterMapping) -> None: + self, + prompt_adapter_requests: Set[PromptAdapterRequest], + prompt_adapter_mapping: PromptAdapterMapping, + ) -> None: if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") self.prompt_adapter_manager.set_active_adapters( - prompt_adapter_requests, prompt_adapter_mapping) + prompt_adapter_requests, prompt_adapter_mapping + ) def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: + self, prompt_adapter_request: PromptAdapterRequest + ) -> bool: if not self.prompt_adapter_manager: raise RuntimeError("PromptAdapter is not enabled.") return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) @@ -1409,14 +1615,18 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: per sequence in the batch. """ assert not self.model_config.enforce_eager - logger.info("Capturing cudagraphs for decoding. This may lead to " - "unexpected consequences if the model is not static. To " - "run the model in eager mode, set 'enforce_eager=True' or " - "use '--enforce-eager' in the CLI.") - logger.info("If out-of-memory error occurs during cudagraph capture," - " consider decreasing `gpu_memory_utilization` or " - "switching to eager mode. You can also reduce the " - "`max_num_seqs` as needed to decrease memory usage.") + logger.info( + "Capturing cudagraphs for decoding. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI." + ) + logger.info( + "If out-of-memory error occurs during cudagraph capture," + " consider decreasing `gpu_memory_utilization` or " + "switching to eager mode. You can also reduce the " + "`max_num_seqs` as needed to decrease memory usage." + ) start_time = time.perf_counter() start_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -1429,23 +1639,27 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # Prepare dummy previous_hidden_states only if needed by the model. # This is used by draft models such as EAGLE. previous_hidden_states = None - if "previous_hidden_states" in inspect.signature( - self.model.forward).parameters: + if ( + "previous_hidden_states" + in inspect.signature(self.model.forward).parameters + ): previous_hidden_states = torch.empty( - [max_batch_size, - self.model_config.get_hidden_size()], + [max_batch_size, self.model_config.get_hidden_size()], dtype=self.model_config.dtype, - device=self.device) + device=self.device, + ) intermediate_inputs = None if not get_pp_group().is_first_rank: intermediate_inputs = self.model.make_empty_intermediate_tensors( batch_size=max_batch_size, dtype=self.model_config.dtype, - device=self.device) + device=self.device, + ) with self.attn_state.graph_capture( - max_batch_size), graph_capture() as graph_capture_context: + max_batch_size + ), graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for virtual_engine in range( @@ -1460,9 +1674,12 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: if self.lora_config: lora_mapping = LoRAMapping( - **dict(index_mapping=[0] * batch_size, - prompt_mapping=[0] * batch_size, - is_prefill=False)) + **dict( + index_mapping=[0] * batch_size, + prompt_mapping=[0] * batch_size, + is_prefill=False, + ) + ) self.set_active_loras(set(), lora_mapping) if self.prompt_adapter_config: @@ -1471,64 +1688,68 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: [-1] * batch_size, ) self.set_active_prompt_adapters( - set(), prompt_adapter_mapping) + set(), prompt_adapter_mapping + ) graph_runner = CUDAGraphRunner( - self.model, self.attn_backend.get_name(), + self.model, + self.attn_backend.get_name(), self.attn_state.graph_clone(batch_size), - self.model_config.is_encoder_decoder) + self.model_config.is_encoder_decoder, + ) capture_inputs = { - "input_ids": - input_tokens[:batch_size], - "positions": - input_positions[..., :batch_size], - "intermediate_inputs": - intermediate_inputs[:batch_size] - if intermediate_inputs is not None else None, - "kv_caches": - kv_caches[virtual_engine], - "attn_metadata": - attn_metadata, - "memory_pool": - self.graph_memory_pool, - "stream": - graph_capture_context.stream + "input_ids": input_tokens[:batch_size], + "positions": input_positions[..., :batch_size], + "intermediate_inputs": intermediate_inputs[:batch_size] + if intermediate_inputs is not None + else None, + "kv_caches": kv_caches[virtual_engine], + "attn_metadata": attn_metadata, + "memory_pool": self.graph_memory_pool, + "stream": graph_capture_context.stream, } if previous_hidden_states is not None: - capture_inputs[ - "previous_hidden_states"] = previous_hidden_states[: - batch_size] + capture_inputs["previous_hidden_states"] = ( + previous_hidden_states[:batch_size] + ) if self.has_inner_state: # Only used by Mamba-based models CUDA graph atm (Jamba) - capture_inputs.update({ - "seqlen_agnostic_capture_inputs": - self.model.get_seqlen_agnostic_capture_inputs( - batch_size) - }) + capture_inputs.update( + { + "seqlen_agnostic_capture_inputs": self.model.get_seqlen_agnostic_capture_inputs( + batch_size + ) + } + ) if self.model_config.is_encoder_decoder: # add the additional inputs to capture for # encoder-decoder models. self._update_inputs_to_capture_for_enc_dec_model( - capture_inputs) + capture_inputs + ) with set_forward_context(attn_metadata, self.vllm_config): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( - graph_runner) + graph_runner + ) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes < 10 seconds. - logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / GiB_bytes) + logger.info( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / GiB_bytes, + ) - def _update_inputs_to_capture_for_enc_dec_model(self, - capture_inputs: Dict[str, - Any]): + def _update_inputs_to_capture_for_enc_dec_model( + self, capture_inputs: Dict[str, Any] + ): """ Updates the set of input tensors needed for CUDA graph capture in an encoder-decoder model. @@ -1540,9 +1761,11 @@ def _update_inputs_to_capture_for_enc_dec_model(self, # During the decode phase encoder_input_ids and encoder_positions are # unset. Do the same thing for graph capture. capture_inputs["encoder_input_ids"] = torch.tensor( - [], dtype=torch.long).cuda() + [], dtype=torch.long + ).cuda() capture_inputs["encoder_positions"] = torch.tensor( - [], dtype=torch.long).cuda() + [], dtype=torch.long + ).cuda() @property def vocab_size(self) -> int: @@ -1553,19 +1776,22 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): """ GPU model runner with sampling step. """ + _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( - ModelInputForGPUWithSamplingMetadata) + ModelInputForGPUWithSamplingMetadata + ) _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any], ) -> ModelInputForGPUWithSamplingMetadata: - model_input = \ + model_input = ( ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, ) + ) return model_input def prepare_model_input( @@ -1588,22 +1814,33 @@ def prepare_model_input( If cuda graph is required, this API automatically pads inputs. """ model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) + seq_group_metadata_list, finished_requests_ids + ) if get_pp_group().is_last_rank: # Sampling metadata is only required for the final pp group generators = self.get_generators(finished_requests_ids) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, model_input.seq_lens, - model_input.query_lens, self.device, self.pin_memory, - generators, self.sampling_metadata_cache) + seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory, + generators, + self.sampling_metadata_cache, + ) else: sampling_metadata = None - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) + is_prompt = ( + seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list + else None + ) + return dataclasses.replace( + model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine, + ) @torch.inference_mode() @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"]) @@ -1620,15 +1857,17 @@ def execute_model( if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) + self.set_active_loras( + model_input.lora_requests, model_input.lora_mapping + ) if self.prompt_adapter_config: assert model_input.prompt_adapter_requests is not None assert model_input.prompt_adapter_mapping is not None self.set_active_prompt_adapters( model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) + model_input.prompt_adapter_mapping, + ) self.attn_state.begin_forward(model_input) @@ -1643,7 +1882,8 @@ def execute_model( assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[virtual_engine][ - graph_batch_size] + graph_batch_size + ] else: model_executable = self.model @@ -1666,12 +1906,18 @@ def execute_model( ) multi_modal_kwargs = model_input.multi_modal_kwargs or {} - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_inner_state else {} - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): + seqlen_agnostic_kwargs = ( + { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } + if self.has_inner_state + else {} + ) + if ( + self.observability_config is not None + and self.observability_config.collect_model_forward_time + ): model_forward_start = torch.cuda.Event(enable_timing=True) model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() @@ -1689,8 +1935,10 @@ def execute_model( device=self.device), **seqlen_agnostic_kwargs) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): + if ( + self.observability_config is not None + and self.observability_config.collect_model_forward_time + ): model_forward_end.record() # Sending KV cache in distributed KV cache transfer setting @@ -1708,25 +1956,32 @@ def execute_model( # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: - if (self.is_driver_worker - and hidden_or_intermediate_states is not None - and isinstance(hidden_or_intermediate_states, - IntermediateTensors) - and self.observability_config is not None - and self.observability_config.collect_model_forward_time): + if ( + self.is_driver_worker + and hidden_or_intermediate_states is not None + and isinstance( + hidden_or_intermediate_states, IntermediateTensors + ) + and self.observability_config is not None + and self.observability_config.collect_model_forward_time + ): model_forward_end.synchronize() model_forward_time = model_forward_start.elapsed_time( - model_forward_end) + model_forward_end + ) orig_model_forward_time = 0.0 if intermediate_tensors is not None: orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() + "model_forward_time", torch.tensor(0.0) + ).item() hidden_or_intermediate_states.tensors["model_forward_time"] = ( - torch.tensor(model_forward_time + orig_model_forward_time)) + torch.tensor(model_forward_time + orig_model_forward_time) + ) return hidden_or_intermediate_states - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) + logits = self.model.compute_logits( + hidden_or_intermediate_states, model_input.sampling_metadata + ) if not self.is_driver_worker: return [] @@ -1739,22 +1994,27 @@ def execute_model( logits=logits, sampling_metadata=model_input.sampling_metadata, ) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time - and output is not None): + if ( + self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None + ): model_forward_end.synchronize() model_forward_time = model_forward_start.elapsed_time( - model_forward_end) + model_forward_end + ) orig_model_forward_time = 0.0 if intermediate_tensors is not None: orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() + "model_forward_time", torch.tensor(0.0) + ).item() # If there are multiple workers, we are still tracking the latency # from the start time of the driver worker to the end time of the # driver worker. The model forward time will then end up covering # the communication time as well. - output.model_forward_time = (orig_model_forward_time + - model_forward_time) + output.model_forward_time = ( + orig_model_forward_time + model_forward_time + ) if self.return_hidden_states: # we only need to pass hidden states of most recent token @@ -1762,10 +2022,11 @@ def execute_model( indices = model_input.sampling_metadata.selected_token_indices if model_input.is_prompt: hidden_states = hidden_or_intermediate_states.index_select( - 0, indices) + 0, indices + ) output.prefill_hidden_states = hidden_or_intermediate_states elif decode_meta.use_cuda_graph: - hidden_states = hidden_or_intermediate_states[:len(indices)] + hidden_states = hidden_or_intermediate_states[: len(indices)] else: hidden_states = hidden_or_intermediate_states @@ -1827,9 +2088,13 @@ def need_send_kv(self, model_input, kv_caches) -> bool: # NOTE: this is nn.Module so the profiler can properly capture/group # kernels calls made within the graph class CUDAGraphRunner(nn.Module): - - def __init__(self, model: nn.Module, backend_name: str, - attn_state: AttentionState, is_encoder_decoder_model: bool): + def __init__( + self, + model: nn.Module, + backend_name: str, + attn_state: AttentionState, + is_encoder_decoder_model: bool, + ): super().__init__() self.model = model self.backend_name = backend_name @@ -1888,15 +2153,17 @@ def capture( if isinstance(output_hidden_or_intermediate_states, torch.Tensor): hidden_or_intermediate_states = weak_ref_tensor( - output_hidden_or_intermediate_states) - elif isinstance(output_hidden_or_intermediate_states, - IntermediateTensors): + output_hidden_or_intermediate_states + ) + elif isinstance( + output_hidden_or_intermediate_states, IntermediateTensors + ): hidden_or_intermediate_states = IntermediateTensors( tensors={ key: weak_ref_tensor(value) - for key, value in - output_hidden_or_intermediate_states.tensors.items() - }) + for key, value in output_hidden_or_intermediate_states.tensors.items() + } + ) del output_hidden_or_intermediate_states # make sure `output_hidden_or_intermediate_states` is deleted @@ -1906,14 +2173,12 @@ def capture( # Save the input and output buffers. self.input_buffers = { - "input_ids": - input_ids, - "positions": - positions, - "kv_caches": - kv_caches, + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, **self.attn_state.get_graph_input_buffers( - attn_metadata, self._is_encoder_decoder_model), + attn_metadata, self._is_encoder_decoder_model + ), **kwargs, } if intermediate_inputs is not None: @@ -1943,29 +2208,36 @@ def forward( if self.backend_name != "NO_ATTENTION": self.input_buffers["slot_mapping"].copy_( - attn_metadata.slot_mapping, non_blocking=True) + attn_metadata.slot_mapping, non_blocking=True + ) self.attn_state.prepare_graph_input_buffers( - self.input_buffers, attn_metadata, self._is_encoder_decoder_model) + self.input_buffers, attn_metadata, self._is_encoder_decoder_model + ) if "seqlen_agnostic_capture_inputs" in self.input_buffers: - self.model.copy_inputs_before_cuda_graphs(self.input_buffers, - **kwargs) + self.model.copy_inputs_before_cuda_graphs( + self.input_buffers, **kwargs + ) if "previous_hidden_states" in self.input_buffers: self.input_buffers["previous_hidden_states"].copy_( - kwargs["previous_hidden_states"], non_blocking=True) + kwargs["previous_hidden_states"], non_blocking=True + ) if intermediate_tensors is not None: for key in intermediate_tensors.tensors: if key != "model_execute_time" and key != "model_forward_time": - self.input_buffers[key].copy_(intermediate_tensors[key], - non_blocking=True) + self.input_buffers[key].copy_( + intermediate_tensors[key], non_blocking=True + ) if self._is_encoder_decoder_model: self.input_buffers["encoder_input_ids"].copy_( - kwargs['encoder_input_ids'], non_blocking=True) + kwargs["encoder_input_ids"], non_blocking=True + ) self.input_buffers["encoder_positions"].copy_( - kwargs['encoder_positions'], non_blocking=True) + kwargs["encoder_positions"], non_blocking=True + ) # Run the graph. self.graph.replay() diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index e08a61e31fe42..91c5ca789c501 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -1,26 +1,47 @@ import dataclasses import functools from dataclasses import dataclass, field -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Union) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, +) import torch from vllm.distributed import get_pp_group from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, - SamplerOutput, - SamplingMetadata, get_logprobs, - get_pythonized_sample_results) -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.model_executor.layers.sampler import ( + PromptLogprobs, + SampleLogprobs, + SamplerOutput, + SamplingMetadata, + get_logprobs, + get_pythonized_sample_results, +) +from vllm.sequence import ( + CompletionSequenceGroupOutput, + IntermediateTensors, + Logprob, + SequenceGroupMetadata, + SequenceOutput, +) from vllm.utils import PyObjectCache, async_tensor_h2d -from vllm.worker.model_runner import (GPUModelRunnerBase, - ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_runner import ( + GPUModelRunnerBase, + ModelInputForGPUWithSamplingMetadata, +) from vllm.worker.model_runner_base import ( - BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, + BroadcastableModelInput, + _init_attn_metadata_from_tensor_dict, _init_frozen_model_input_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) + _init_sampling_metadata_from_tensor_dict, +) from ..model_executor.model_loader.tensorizer import TensorizerConfig @@ -30,12 +51,17 @@ logger = init_logger(__name__) MULTI_STEP_ATTENTION_BACKENDS = [ - "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION" + "FLASH_ATTN", + "ROCM_FLASH", + "FLASHINFER", + "NO_ATTENTION", ] -MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"] +MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"] -def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ - -> List[str]: + +def _get_supported_attention_backends( + chunked_prefill_enabled: bool, +) -> List[str]: if chunked_prefill_enabled: return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS else: @@ -44,8 +70,8 @@ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ def seq_output_builder(): return SequenceOutput( - 0, 0, - {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)}) + 0, 0, {0: Logprob(logprob=float("inf"), rank=None, decoded_token=None)} + ) def completion_seq_group_output_builder(): @@ -54,11 +80,11 @@ def completion_seq_group_output_builder(): # Used by pythonization to reduce python object allocations class PythonizationCache: - def __init__(self): self.cached_seq_output = PyObjectCache(seq_output_builder) self.cached_completion_seq_group_output = PyObjectCache( - completion_seq_group_output_builder) + completion_seq_group_output_builder + ) def reset(self): self.cached_seq_output.reset() @@ -81,6 +107,7 @@ class ModelOutput: 2. The output tensors are not ready and we need to wait for the event to be ready. """ + sampler_output: SamplerOutput sampler_output_ready_event: torch.cuda.Event sampled_token_ids: Optional[torch.Tensor] = None @@ -89,28 +116,38 @@ class ModelOutput: logprobs: Optional["torch.Tensor"] = None pythonization_cache: Optional[PythonizationCache] = None - def pythonize(self, input_metadata: "StatefulModelInput", - copy_stream: torch.cuda.Stream, - pinned_sampled_token_buffer: torch.Tensor) -> None: + def pythonize( + self, + input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor, + ) -> None: """Pythonize the output. Blocking.""" if not self.pythonized: - self._pythonize_sampler_output(input_metadata, copy_stream, - pinned_sampled_token_buffer, True) + self._pythonize_sampler_output( + input_metadata, copy_stream, pinned_sampled_token_buffer, True + ) self.pythonized = True - def maybe_pythonize(self, input_metadata: "StatefulModelInput", - copy_stream: torch.cuda.Stream, - pinned_sampled_token_buffer: torch.Tensor) -> None: + def maybe_pythonize( + self, + input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor, + ) -> None: """Pythonize the output if ready, else return None. Non-blocking.""" if not self.pythonized: self.pythonized = self._pythonize_sampler_output( - input_metadata, copy_stream, pinned_sampled_token_buffer, - False) + input_metadata, copy_stream, pinned_sampled_token_buffer, False + ) - def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", - copy_stream: torch.cuda.Stream, - pinned_sampled_token_buffer: torch.Tensor, - blocking: bool) -> bool: + def _pythonize_sampler_output( + self, + input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor, + blocking: bool, + ) -> bool: """ If blocking is set, will block until the forward pass for the output is ready and pythonize the output. Upon completing Pythonization, erases @@ -124,10 +161,14 @@ def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", if blocking: self.sampler_output_ready_event.synchronize() with torch.cuda.stream(copy_stream): - _pythonize_sampler_output(input_metadata, self.sampler_output, - pinned_sampled_token_buffer, - self.sampled_token_ids, self.logprobs, - self.pythonization_cache) + _pythonize_sampler_output( + input_metadata, + self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids, + self.logprobs, + self.pythonization_cache, + ) # Erase the logprobs GPU-side tensor. # Note that although _pythonize_sampler_output() runs in its @@ -157,7 +198,8 @@ class StatefulModelInput(BroadcastableModelInput): base_output_proc_callback: Optional[Callable] = None # ping-pong data structures for multi-step to wait on the previous step step_cuda_events: List[torch.cuda.Event] = field( - default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) + default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2 + ) num_seqs: int = -1 num_queries: int = -1 num_single_step_prefills: int = 0 @@ -166,14 +208,14 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: assert self.frozen_model_input is not None tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() new_tensor_dict = { - 'last_sampled_token_ids': self.last_sampled_token_ids, - 'current_step': self.current_step, - 'is_multi_step': self.is_multi_step, - 'is_last_step': self.is_last_step, - 'is_first_multi_step': self.is_first_multi_step, - 'num_seqs': self.num_seqs, - 'num_queries': self.num_queries, - 'num_single_step_prefills': self.num_single_step_prefills, + "last_sampled_token_ids": self.last_sampled_token_ids, + "current_step": self.current_step, + "is_multi_step": self.is_multi_step, + "is_last_step": self.is_last_step, + "is_first_multi_step": self.is_first_multi_step, + "num_seqs": self.num_seqs, + "num_queries": self.num_queries, + "num_single_step_prefills": self.num_single_step_prefills, } tensor_dict.update(new_tensor_dict) return tensor_dict @@ -187,9 +229,11 @@ def from_broadcasted_tensor_dict( tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) + attn_backend, tensor_dict + ) tensor_dict = _init_frozen_model_input_from_tensor_dict( - ModelInputForGPUWithSamplingMetadata, tensor_dict) + ModelInputForGPUWithSamplingMetadata, tensor_dict + ) return cls(**tensor_dict) @@ -198,8 +242,9 @@ def record_step_event(self, current_stream: torch.cuda.Stream): # on it. We modulo by 2 to keep the events in a circular buffer and # support any attn backends that may be supported in the future. ie # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU. - self.step_cuda_events[self.current_step & 1] = \ - torch.cuda.Event(blocking=True) + self.step_cuda_events[self.current_step & 1] = torch.cuda.Event( + blocking=True + ) self.step_cuda_events[self.current_step & 1].record(current_stream) def wait_previous_step(self): @@ -213,14 +258,19 @@ def wait_previous_step(self): # backend) self.step_cuda_events[(self.current_step + 1) & 1].wait() - def add_sampler_output(self, - sampler_output: SamplerOutput, - sampled_token_ids: Optional[torch.Tensor] = None): + def add_sampler_output( + self, + sampler_output: SamplerOutput, + sampled_token_ids: Optional[torch.Tensor] = None, + ): self.cached_outputs.append( - ModelOutput(sampler_output=sampler_output, - sampler_output_ready_event=None, - sampled_token_ids=sampled_token_ids, - pythonized=False)) + ModelOutput( + sampler_output=sampler_output, + sampler_output_ready_event=None, + sampled_token_ids=sampled_token_ids, + pythonized=False, + ) + ) def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool): """ @@ -249,11 +299,14 @@ def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool): assert self.frozen_model_input is not None assert self.frozen_model_input.sampling_metadata is not None - self.frozen_model_input.sampling_metadata.selected_token_indices = \ - async_tensor_h2d(list(range(self.num_queries)), - dtype=torch.long, - target_device=device, - pin_memory=pin_memory) + self.frozen_model_input.sampling_metadata.selected_token_indices = ( + async_tensor_h2d( + list(range(self.num_queries)), + dtype=torch.long, + target_device=device, + pin_memory=pin_memory, + ) + ) def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): """ @@ -271,13 +324,14 @@ def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): # Truncate input_tokens assert fmi.input_tokens is not None assert fmi.input_tokens.shape[0] >= self.num_seqs - fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs] + fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[: self.num_seqs] # Update frozen_model_input::input_positons. assert fmi.input_positions is not None assert fmi.input_positions.shape[0] >= self.num_seqs - fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self. - num_seqs] + fmi_new_input_positions: torch.Tensor = fmi.input_positions[ + : self.num_seqs + ] # Assert unsupported assert fmi.lora_mapping is None @@ -293,7 +347,8 @@ def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): self.frozen_model_input = dataclasses.replace( self.frozen_model_input, input_tokens=fmi_new_input_tokens, - input_positions=fmi_new_input_positions) + input_positions=fmi_new_input_positions, + ) self.maybe_advance_sampling_metadata(device, pin_memory) @@ -306,21 +361,25 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): # mypy: enable-error-code=type-var def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): - super().__init__(*args, **kwargs) # Check attention backend support. - supported_attention_backends: List[str] = \ + supported_attention_backends: List[str] = ( _get_supported_attention_backends( - self.scheduler_config.chunked_prefill_enabled) + self.scheduler_config.chunked_prefill_enabled + ) + ) if self.attn_backend.get_name() not in supported_attention_backends: - ms_config_str: str = "Multi-Step + Chunked-Prefill" \ - if self.scheduler_config.chunked_prefill_enabled \ - else "Multi-Step" + ms_config_str: str = ( + "Multi-Step + Chunked-Prefill" + if self.scheduler_config.chunked_prefill_enabled + else "Multi-Step" + ) raise ValueError( f"{ms_config_str} not supported for attention backend: " f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND " - f"to a value from {supported_attention_backends}.") + f"to a value from {supported_attention_backends}." + ) # uses the base model runner to execute the model and wraps it with # multi-step logic @@ -335,8 +394,11 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): # execution, there may be other on-going single-step/multi-step # executions. The current caching implementation does not check # for this. - self.pythonization_cache = PythonizationCache() \ - if self.parallel_config.pipeline_parallel_size == 1 else None + self.pythonization_cache = ( + PythonizationCache() + if self.parallel_config.pipeline_parallel_size == 1 + else None + ) @functools.cached_property def _copy_stream(self): @@ -344,24 +406,25 @@ def _copy_stream(self): return torch.cuda.Stream() def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: - model_input = (StatefulModelInput.from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any] + ) -> StatefulModelInput: + model_input = StatefulModelInput.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, - )) + ) return model_input def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None, ) -> StatefulModelInput: - frozen_model_input: ModelInputForGPUWithSamplingMetadata = \ - self._base_model_runner.prepare_model_input( - seq_group_metadata_list, - virtual_engine, - finished_requests_ids) + frozen_model_input: ModelInputForGPUWithSamplingMetadata = ( + self._base_model_runner.prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids + ) + ) assert frozen_model_input.query_lens is not None assert frozen_model_input.seq_lens is not None @@ -374,12 +437,14 @@ def prepare_model_input( frozen_model_input=frozen_model_input, num_seqs=num_seqs, num_queries=num_queries, - num_single_step_prefills=num_single_step_prefills) + num_single_step_prefills=num_single_step_prefills, + ) return model_input - def _async_process_outputs(self, model_input: StatefulModelInput, - output_proc_callback: Callable): + def _async_process_outputs( + self, model_input: StatefulModelInput, output_proc_callback: Callable + ): # Proceed with pythonization and output_proc in order. # Stop on the first one that fails to pythonize output_proc_callback() @@ -387,8 +452,11 @@ def _async_process_outputs(self, model_input: StatefulModelInput, cont = True for step_num, model_output in enumerate(model_input.cached_outputs): if not model_output.pythonized: - model_output.maybe_pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + model_output.maybe_pythonize( + model_input, + self._copy_stream, + self.pinned_sampled_token_ids, + ) if model_output.pythonized: ctx = output_proc_callback.keywords["ctx"] ctx.append_output( @@ -397,7 +465,8 @@ def _async_process_outputs(self, model_input: StatefulModelInput, scheduler_outputs=ctx.scheduler_outputs, is_async=False, is_last_step=False, - is_first_step_output=step_num == 0) + is_first_step_output=step_num == 0, + ) output_proc_callback() else: @@ -406,8 +475,11 @@ def _async_process_outputs(self, model_input: StatefulModelInput, if not cont: break - def _final_process_outputs(self, model_input: StatefulModelInput, - output_proc_callback: Optional[Callable]): + def _final_process_outputs( + self, + model_input: StatefulModelInput, + output_proc_callback: Optional[Callable], + ): assert model_input.frozen_model_input is not None has_async_callback = output_proc_callback is not None @@ -429,27 +501,34 @@ def _final_process_outputs(self, model_input: StatefulModelInput, # Pythonize if not output.pythonized: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + output.pythonize( + model_input, + self._copy_stream, + self.pinned_sampled_token_ids, + ) # For non last step, add to callback queue to chain # callbacks=>pythonize pairs (for GPU overlap) if not is_last_step: ctx = output_proc_callback.keywords[ # type: ignore - "ctx"] # type: ignore + "ctx" + ] # type: ignore ctx.append_output( outputs=[output.sampler_output], - seq_group_metadata_list=ctx. - seq_group_metadata_list, + seq_group_metadata_list=ctx.seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs, is_async=False, is_last_step=False, - is_first_step_output=step_num == 0) + is_first_step_output=step_num == 0, + ) else: outputs.append(output.sampler_output) else: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + output.pythonize( + model_input, + self._copy_stream, + self.pinned_sampled_token_ids, + ) outputs.append(output.sampler_output) return outputs @@ -462,7 +541,7 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - """ + """ Execute the model for a single step and update multi-step metadata """ @@ -473,7 +552,8 @@ def execute_model( # path for warm up runs if not model_input.is_multi_step: return self._base_model_runner.execute_model( - frozen_model_input, kv_caches, intermediate_tensors, num_steps) + frozen_model_input, kv_caches, intermediate_tensors, num_steps + ) # make sure we skip the sampler on the lask rank and only pythonize # if CPU is ahead. @@ -483,13 +563,16 @@ def execute_model( (self.scheduler_config.max_num_seqs, 1), dtype=torch.long, device="cpu", - pin_memory=True) + pin_memory=True, + ) self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( - True) + True + ) if frozen_model_input.sampling_metadata: frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( - True) + True + ) # some pre-execute model logic for multi-step: # - if it's the first step, we need to reset the sampling tensors @@ -508,7 +591,8 @@ def execute_model( # far ahead if needed) model_input.wait_previous_step() model_input = self._advance_step( - model_input, model_input.cached_outputs[-1].sampler_output) + model_input, model_input.cached_outputs[-1].sampler_output + ) # frozen_model_input may have been updated frozen_model_input = model_input.frozen_model_input @@ -516,36 +600,37 @@ def execute_model( if model_input.base_output_proc_callback is None: assert frozen_model_input is not None - model_input.base_output_proc_callback = \ - frozen_model_input.async_callback + model_input.base_output_proc_callback = ( + frozen_model_input.async_callback + ) if frozen_model_input.async_callback is not None: assert model_input.base_output_proc_callback is not None async_callback = functools.partial( self._async_process_outputs, model_input=model_input, - output_proc_callback=model_input.base_output_proc_callback) + output_proc_callback=model_input.base_output_proc_callback, + ) model_input.frozen_model_input = dataclasses.replace( # type: ignore - model_input.frozen_model_input, - async_callback=async_callback) + model_input.frozen_model_input, async_callback=async_callback + ) # Update the local instance frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None # Execute the model - output = self._base_model_runner.execute_model(frozen_model_input, - kv_caches, - intermediate_tensors, - num_steps=1) + output = self._base_model_runner.execute_model( + frozen_model_input, kv_caches, intermediate_tensors, num_steps=1 + ) # record the event for the current step so that the next step can sync model_input.record_step_event(current_stream) if get_pp_group().is_last_rank and self.is_driver_worker: - assert len( - output - ) == 1, "MultiStepModelRunner requires single-step base_models" + assert ( + len(output) == 1 + ), "MultiStepModelRunner requires single-step base_models" # event for the pythonization so that we only pythonize if the # tensors are ready. May be able to be combined with the step event @@ -553,11 +638,18 @@ def execute_model( output_ready_event.record(current_stream) if self.parallel_config.pipeline_parallel_size > 1: output[0].sampled_token_ids_cpu = output[ - 0].sampled_token_ids.cpu() + 0 + ].sampled_token_ids.cpu() model_input.cached_outputs.append( - ModelOutput(output[0], output_ready_event, - output[0].sampled_token_ids, False, - output[0].logprobs, self.pythonization_cache)) + ModelOutput( + output[0], + output_ready_event, + output[0].sampled_token_ids, + False, + output[0].logprobs, + self.pythonization_cache, + ) + ) # These GPU tensors are not required by multi-step; # erase them to ensure they are not pythonized or @@ -570,9 +662,11 @@ def execute_model( # ready. if frozen_model_input.async_callback is None: for model_output in model_input.cached_outputs: - model_output.maybe_pythonize(model_input, - self._copy_stream, - self.pinned_sampled_token_ids) + model_output.maybe_pythonize( + model_input, + self._copy_stream, + self.pinned_sampled_token_ids, + ) model_input.current_step += 1 @@ -586,7 +680,8 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: outputs = self._final_process_outputs( - model_input, model_input.base_output_proc_callback) + model_input, model_input.base_output_proc_callback + ) if self.pythonization_cache: self.pythonization_cache.reset() return outputs @@ -594,13 +689,12 @@ def execute_model( # should be [SamplerOutput] return output - def _update_sampling_metadata(self, sampling_metadata, num_seqs, - num_queries): - + def _update_sampling_metadata( + self, sampling_metadata, num_seqs, num_queries + ): assert sampling_metadata.num_prompts == 0 assert len(sampling_metadata.seq_groups) == num_queries - assert sampling_metadata.selected_token_indices.shape == ( - num_queries, ) + assert sampling_metadata.selected_token_indices.shape == (num_queries,) # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 # Verify that all sequences are decodes @@ -613,11 +707,12 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, assert seq_group.seq_len is None # Decode assert seq_group.query_len is None # Decode - def _advance_step(self, model_input: StatefulModelInput, - out: SamplerOutput) -> StatefulModelInput: - - model_input.maybe_advance_frozen_model_input(self.device, - self.pin_memory) + def _advance_step( + self, model_input: StatefulModelInput, out: SamplerOutput + ) -> StatefulModelInput: + model_input.maybe_advance_frozen_model_input( + self.device, self.pin_memory + ) frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None assert frozen_model_input.input_tokens is not None @@ -632,15 +727,18 @@ def _advance_step(self, model_input: StatefulModelInput, attn_metadata = frozen_model_input.attn_metadata assert attn_metadata is not None - turn_prefills_into_decodes: bool = model_input.current_step == 1 and \ - model_input.num_single_step_prefills != 0 + turn_prefills_into_decodes: bool = ( + model_input.current_step == 1 + and model_input.num_single_step_prefills != 0 + ) attn_metadata.advance_step( frozen_model_input, sampled_token_ids, self.block_size, num_seqs, num_queries, - turn_prefills_into_decodes=turn_prefills_into_decodes) + turn_prefills_into_decodes=turn_prefills_into_decodes, + ) return model_input @@ -654,10 +752,12 @@ def save_sharded_state( max_size: Optional[int] = None, ) -> None: return self._base_model_runner.save_sharded_state( - path, pattern, max_size) + path, pattern, max_size + ) - def save_tensorized_model(self, - tensorizer_config: TensorizerConfig) -> None: + def save_tensorized_model( + self, tensorizer_config: TensorizerConfig + ) -> None: return self._base_model_runner.save_tensorized_model(tensorizer_config) def profile_run(self) -> None: @@ -674,8 +774,9 @@ def vocab_size(self) -> int: return self._base_model_runner.vocab_size -DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]], - Optional[List[SampleLogprobs]]] +DeferredLogprobsReturnType = Tuple[ + Optional[List[Optional[PromptLogprobs]]], Optional[List[SampleLogprobs]] +] def deferred_pythonize_logprobs( @@ -688,21 +789,22 @@ def deferred_pythonize_logprobs( 1. Pythonize GPU-side sampler result tensors into CPU-side sampler result. 2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists, utilizing the Pythonized sampler result computed in step 1. - + These deferred computations are not required for single-step scheduling or the `profile_run()` phase of multi-step scheduling. Args: output: sampler output (under deferred Pythonization) sampling_metadata - + Returns: prompt_logprobs (CPU), sample_logprobs (CPU) """ # - Deferred pythonization of sample result sampler_result = get_pythonized_sample_results( - output.deferred_sample_results_args) + output.deferred_sample_results_args + ) # - Erase the GPU-side deferred sample_result # computation args to ensure it is never @@ -728,10 +830,10 @@ def _pythonize_sampler_output( logprobs_tensor: Optional[torch.Tensor], cache: Optional[PythonizationCache], ) -> None: - """ This function is only called when the output tensors are ready. - See :class:`ModelOutput`. - - Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, + """This function is only called when the output tensors are ready. + See :class:`ModelOutput`. + + Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, adding a Pythonized output data structure (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`. @@ -742,7 +844,7 @@ def _pythonize_sampler_output( (receives copy of GPU-side token buffer.) sampled_token_ids: GPU-side token buffer - logprobs_tensor: GPU-side tensor containing + logprobs_tensor: GPU-side tensor containing logprobs computed during sampling """ @@ -754,7 +856,7 @@ def _pythonize_sampler_output( # samples generation should have been skipped assert not output.outputs - pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] + pinned_buffer = pinned_sampled_token_buffer[: model_input.num_queries] # We guarantee output tensors are ready, so it is safe to # pythonize the sampler output & obtain CPU-side logprobs. @@ -776,55 +878,67 @@ def _pythonize_sampler_output( # process to accommodate logprobs. seq_groups = sampling_metadata.seq_groups - prompt_logprobs_are_requested_for_prefill = any([ - sg.sampling_params.prompt_logprobs is not None and sg.is_prompt - for sg in seq_groups - ]) + prompt_logprobs_are_requested_for_prefill = any( + [ + sg.sampling_params.prompt_logprobs is not None and sg.is_prompt + for sg in seq_groups + ] + ) any_logprobs_are_requested = ( prompt_logprobs_are_requested_for_prefill - or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) + or any([sg.sampling_params.logprobs is not None for sg in seq_groups]) + ) if prompt_logprobs_are_requested_for_prefill: # CPU GPU sync, after gathering *only* sampled tokens (since # requesting prompt logprobs leads `sampled_token_ids` to # include prompt token ids in addition to sampled token ids.) sample_idx_tensor = torch.tensor( - [sdx for sg in seq_groups for sdx in sg.sample_indices]) + [sdx for sg in seq_groups for sdx in sg.sample_indices] + ) pinned_buffer = pinned_buffer.copy_( - sampled_token_ids[sample_idx_tensor, :], non_blocking=False) + sampled_token_ids[sample_idx_tensor, :], non_blocking=False + ) else: # CPU GPU sync - pinned_buffer = pinned_buffer.copy_(sampled_token_ids, - non_blocking=False) + pinned_buffer = pinned_buffer.copy_( + sampled_token_ids, non_blocking=False + ) # this will not block as the tensors are already on CPU samples_list = pinned_buffer.tolist() skip_sampler_cpu_output = ( - frozen_model_input.sampling_metadata.skip_sampler_cpu_output) + frozen_model_input.sampling_metadata.skip_sampler_cpu_output + ) # *Don't* skip logprobs pythonization *if*: # * Any requests require logprobs to be returned in this # iteration AND # * These requests are being scheduled in a fashion which # defers pythonization (i.e. multi-step scheduling.) - do_pythonize_logprobs = (skip_sampler_cpu_output - and any_logprobs_are_requested) + do_pythonize_logprobs = ( + skip_sampler_cpu_output and any_logprobs_are_requested + ) ( prompt_logprobs, sample_logprobs, - ) = (deferred_pythonize_logprobs(output, sampling_metadata, - logprobs_tensor) - if do_pythonize_logprobs else (None, None)) - - for sgdx, (seq_group, - sample_result) in enumerate(zip(seq_groups, samples_list)): + ) = ( + deferred_pythonize_logprobs(output, sampling_metadata, logprobs_tensor) + if do_pythonize_logprobs + else (None, None) + ) + + for sgdx, (seq_group, sample_result) in enumerate( + zip(seq_groups, samples_list) + ): # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid # (Check for Guided Decoding) if seq_group.sampling_params.logits_processors: - assert len(seq_group.sampling_params.logits_processors) == 0, ( - "Logits Processors are not supported in multi-step decoding") + assert ( + len(seq_group.sampling_params.logits_processors) == 0 + ), "Logits Processors are not supported in multi-step decoding" if do_pythonize_logprobs: assert prompt_logprobs is not None @@ -844,25 +958,30 @@ def _pythonize_sampler_output( ) = ( # profile_run: use already-computed logprobs output.outputs[sgdx].prompt_logprobs, - [sample.logprobs for sample in output.outputs[sgdx].samples]) + [sample.logprobs for sample in output.outputs[sgdx].samples], + ) seq_ids = seq_group.seq_ids next_token_ids = sample_result parent_ids = [0] if cache is not None: - completion_seq_group_output: CompletionSequenceGroupOutput = \ + completion_seq_group_output: CompletionSequenceGroupOutput = ( cache.cached_completion_seq_group_output.get_object() + ) completion_seq_group_output.samples.clear() - seq_outputs: List[ - SequenceOutput] = completion_seq_group_output.samples + seq_outputs: List[SequenceOutput] = ( + completion_seq_group_output.samples + ) else: seq_outputs = [] - for tdx, (parent_id, - next_token_id) in enumerate(zip(parent_ids, next_token_ids)): + for tdx, (parent_id, next_token_id) in enumerate( + zip(parent_ids, next_token_ids) + ): if cache is not None: - seq_output: SequenceOutput = cache.cached_seq_output.get_object( + seq_output: SequenceOutput = ( + cache.cached_seq_output.get_object() ) seq_output.parent_seq_id = seq_ids[parent_id] seq_output.output_token = next_token_id @@ -873,7 +992,7 @@ def _pythonize_sampler_output( logprobs = next(iter(seq_output.logprobs.values())) seq_output.logprobs.clear() - logprobs.logprob = float('inf') + logprobs.logprob = float("inf") logprobs.rank = None logprobs.decoded_token = None @@ -883,22 +1002,37 @@ def _pythonize_sampler_output( else: seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - (group_sample_logprobs[tdx] - if any_logprobs_are_requested else { - next_token_id: - Logprob(logprob=float('inf'), - rank=None, - decoded_token=None) - }))) + SequenceOutput( + seq_ids[parent_id], + next_token_id, + ( + group_sample_logprobs[tdx] + if any_logprobs_are_requested + else { + next_token_id: Logprob( + logprob=float("inf"), + rank=None, + decoded_token=None, + ) + } + ), + ) + ) if cache is not None: - completion_seq_group_output.prompt_logprobs = \ + completion_seq_group_output.prompt_logprobs = ( group_prompt_logprobs if any_logprobs_are_requested else None + ) output.outputs.append(completion_seq_group_output) else: output.outputs.append( CompletionSequenceGroupOutput( - seq_outputs, (group_prompt_logprobs - if any_logprobs_are_requested else None))) + seq_outputs, + ( + group_prompt_logprobs + if any_logprobs_are_requested + else None + ), + ) + ) assert len(output.outputs) > 0