diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 9dac12d3b906d..d46e87a3e6afe 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -297,23 +297,26 @@ def _add_seq_group( if is_profile_run: return - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size block_table = block_tables[seq_id] - self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) - - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - self.paged_kv_last_page_len.append(last_page_len) + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a333e6634a41f..e919dbd18d9df 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -11,7 +11,8 @@ from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) -from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU +from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU, + ModelInputForGPUBuilder) logger = init_logger(__name__) @@ -28,6 +29,7 @@ class EmbeddingModelRunner( GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( ModelInputForGPUWithPoolingMetadata) + _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder def __init__( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e63be184af16a..83f62a730bc8a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -3,7 +3,7 @@ import time import warnings import weakref -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union) @@ -171,48 +171,83 @@ def from_broadcasted_tensor_dict( class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): """Build ModelInputForGPU from SequenceGroupMetadata.""" - @dataclass + # Note: ideally we would be using a dataclass(kw_only=True) + # here, so that this can be subclassed easily, + # but kw_only is not supported in python<3.10. class InterDataForSeqGroup: """Intermediate data for the current sequence group.""" - # From sequence group metadata. - request_id: str - seq_ids: List[int] - is_prompt: bool - block_tables: Optional[Dict[int, List[int]]] - computed_block_nums: List[int] - n_seqs: int = 0 - - # Input tokens and positions. - input_tokens: List[List[int]] = field(default_factory=list) - input_positions: List[List[int]] = field(default_factory=list) - - # The sequence length (may be capped to the sliding window). - seq_lens: List[int] = field(default_factory=list) - # The original sequence length (before applying sliding window). - # This is used to compute slot mapping. - orig_seq_lens: List[int] = field(default_factory=list) - # The query length. - query_lens: List[int] = field(default_factory=list) - # The number of tokens that are already computed. - context_lens: List[int] = field(default_factory=list) - # The current sliding window block. - curr_sliding_window_blocks: List[int] = field(default_factory=list) - - # LoRA inputs. - lora_index_mapping: List[List[int]] = field(default_factory=list) - lora_prompt_mapping: List[List[int]] = field(default_factory=list) - lora_requests: Set[LoRARequest] = field(default_factory=set) - - # Prompt adapter inputs. - prompt_adapter_index_mapping: List[int] = field(default_factory=list) - prompt_adapter_prompt_mapping: List[int] = field(default_factory=list) - prompt_adapter_request: Optional[PromptAdapterRequest] = None - - # Multi-modal inputs. - multi_modal_inputs: Optional[MultiModalInputs] = None - - # Whether the prefix cache is hit (prefill only). - prefix_cache_hit: bool = False + + def __init__( + self, + *, + # From sequence group metadata. + request_id: str, + seq_ids: List[int], + is_prompt: bool, + block_tables: Optional[Dict[int, List[int]]], + computed_block_nums: List[int], + n_seqs: int = 0, + + # Input tokens and positions. + input_tokens: Optional[List[List[int]]] = None, + input_positions: Optional[List[List[int]]] = None, + + # The sequence length (may be capped to the sliding window). + seq_lens: Optional[List[int]] = None, + # The original sequence length (before applying sliding window). + # This is used to compute slot mapping. + orig_seq_lens: Optional[List[int]] = None, + # The query length. + query_lens: Optional[List[int]] = None, + # The number of tokens that are already computed. + context_lens: Optional[List[int]] = None, + # The current sliding window block. + curr_sliding_window_blocks: Optional[List[int]] = None, + + # LoRA inputs. + lora_index_mapping: Optional[List[List[int]]] = None, + lora_prompt_mapping: Optional[List[List[int]]] = None, + lora_requests: Optional[Set[LoRARequest]] = None, + + # Prompt adapter inputs. + prompt_adapter_index_mapping: Optional[List[int]] = None, + prompt_adapter_prompt_mapping: Optional[List[int]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + + # Multi-modal inputs. + multi_modal_inputs: Optional[MultiModalInputs] = None, + + # Whether the prefix cache is hit (prefill only). + prefix_cache_hit: bool = False, + ): + self.request_id = request_id + self.seq_ids = seq_ids + self.is_prompt = is_prompt + self.block_tables = block_tables + self.computed_block_nums = computed_block_nums + self.n_seqs = n_seqs + self.input_tokens = input_tokens or [] + self.input_positions = input_positions or [] + self.seq_lens = seq_lens or [] + self.orig_seq_lens = orig_seq_lens or [] + self.query_lens = query_lens or [] + self.context_lens = context_lens or [] + self.curr_sliding_window_blocks = curr_sliding_window_blocks or [] + + self.lora_index_mapping = lora_index_mapping or [] + self.lora_prompt_mapping = lora_prompt_mapping or [] + self.lora_requests = lora_requests or set() + + self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping + or []) + self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping + or []) + self.prompt_adapter_request = prompt_adapter_request + + self.multi_modal_inputs = multi_modal_inputs + self.prefix_cache_hit = prefix_cache_hit + + self.__post_init__() def __post_init__(self): self.n_seqs = len(self.seq_ids) @@ -457,6 +492,12 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): for per_seq_group_fn in self.per_seq_group_compute_fns: per_seq_group_fn(inter_data, seq_group_metadata) + def _use_captured_graph(self, batch_size: int, + max_decode_seq_len: int) -> bool: + return (self.decode_only and not self.runner.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.runner.max_seq_len_to_capture) + def build(self) -> ModelInputForGPU: """Finalize the builder intermediate data and create on-device tensors. @@ -491,10 +532,8 @@ def build(self) -> ModelInputForGPU: } batch_size = len(input_tokens) - use_captured_graph = ( - self.decode_only and not self.runner.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_decode_seq_len <= self.runner.max_seq_len_to_capture) + use_captured_graph = self._use_captured_graph(batch_size, + max_decode_seq_len) # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. @@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): Helper class for shared methods between GPU model runners. """ _model_input_cls: Type[TModelInputForGPU] + _builder_cls: Type[ModelInputForGPUBuilder] def __init__( self, @@ -794,8 +834,7 @@ def _prepare_model_input_tensors( If cuda graph is required, this API automatically pads inputs. """ - builder = ModelInputForGPUBuilder(weakref.proxy(self), - finished_requests_ids) + builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: builder.add_seq_group(seq_group_metadata) return builder.build() # type: ignore @@ -1191,6 +1230,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): """ _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( ModelInputForGPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder def make_model_input_from_broadcasted_tensor_dict( self,