From e0c15758b85827dcac78379e60ea975ebc0ec795 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 22 Jul 2024 17:45:24 -0700 Subject: [PATCH] [Core] Modulize prepare input and attention metadata builder (#6596) --- vllm/attention/backends/abstract.py | 20 +- vllm/attention/backends/flash_attn.py | 43 ++- vllm/attention/backends/flashinfer.py | 60 +-- vllm/attention/backends/utils.py | 49 +-- vllm/utils.py | 5 + vllm/worker/model_runner.py | 530 ++++++++++++++++---------- 6 files changed, 409 insertions(+), 298 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 191c6ff000c85..106b00cc1014c 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -7,7 +7,6 @@ import torch if TYPE_CHECKING: - from vllm.sequence import SequenceGroupMetadata from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase @@ -128,25 +127,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]): """Abstract class for attention metadata builders.""" @abstractmethod - def __init__(self, input_builder) -> None: + def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: raise NotImplementedError @abstractmethod - def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata", - token_lens: List[int], seq_lens: List[int], - curr_seq_lens: List[int], query_lens: List[int], - context_lens: List[int], - curr_sliding_window_blocks: List[int], - prefix_cache_hit: bool, chunked_prefill_enabled: bool): - """Add a sequence group to the metadata and update - corresponding fields (in Python objects). - """ - raise NotImplementedError - - @abstractmethod - def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int], - query_lens: List[int], cuda_graph_pad_size: int, - batch_size: int) -> T: + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> T: """Build attention metadata with on-device tensors.""" raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index cad3181d3edb7..b16a204c8f44e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -13,12 +13,10 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from vllm.sequence import SequenceGroupMetadata from vllm.utils import make_tensor_with_pad if TYPE_CHECKING: - from vllm.worker.model_runner import (GPUModelRunnerBase, - ModelInputForGPUBuilder) + from vllm.worker.model_runner import ModelInputForGPUBuilder class FlashAttentionBackend(AttentionBackend): @@ -212,30 +210,30 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 + self.input_builder = input_builder + self.runner = input_builder.runner self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size self.use_v2_block_manager = ( input_builder.scheduler_config.use_v2_block_manager) - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, - token_lens: List[int], seq_lens: List[int], - curr_seq_lens: List[int], query_lens: List[int], - context_lens: List[int], - curr_sliding_window_blocks: List[int], - prefix_cache_hit: bool, chunked_prefill_enabled: bool): + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): """Add a sequence group to the metadata. 
Specifically update/append 1. context length. 2. block table. 3. slot mapping. """ - is_prompt = seq_group_metadata.is_prompt - block_tables = seq_group_metadata.block_tables + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( - seq_group_metadata.seq_data.keys(), token_lens, seq_lens, - curr_seq_lens, query_lens, context_lens, - curr_sliding_window_blocks): + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: @@ -254,7 +252,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. block_table = [] - if prefix_cache_hit: + if inter_data.prefix_cache_hit: # NOTE(woosuk): For flash-attn, the block table should # include the entries for the incoming prefill tokens. block_table = block_tables[seq_id] @@ -270,16 +268,19 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, self.use_v2_block_manager) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, - self.block_size, - seq_group_metadata.block_tables) + self.block_size, inter_data.block_tables) - def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, + def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): """Build attention metadata with on-device tensors.""" - device = runner.device + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(runner.model_config.hf_config, + logits_soft_cap = getattr(self.runner.model_config.hf_config, "attn_logit_softcapping", None) if logits_soft_cap is not None: raise ValueError( @@ -300,7 +301,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, # The shape of graph_block_tables is # [max batch size, max context len // block size]. 
- input_block_tables = runner.graph_block_tables[:batch_size] + input_block_tables = self.runner.graph_block_tables[:batch_size] for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index eb8b1f0fcfb39..9dac12d3b906d 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -21,12 +21,10 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.sequence import SequenceGroupMetadata from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad if TYPE_CHECKING: - from vllm.worker.model_runner import (GPUModelRunnerBase, - ModelInputForGPUBuilder) + from vllm.worker.model_runner import ModelInputForGPUBuilder class FlashInferBackend(AttentionBackend): @@ -216,6 +214,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size self.use_v2_block_manager = ( @@ -238,26 +239,24 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): # paged_kv_last_page_len is the length of the last page of each request self.paged_kv_last_page_len: List[int] = [] - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, - token_lens: List[int], seq_lens: List[int], - curr_seq_lens: List[int], query_lens: List[int], - context_lens: List[int], - curr_sliding_window_blocks: List[int], - prefix_cache_hit: bool, chunked_prefill_enabled: bool): + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. 3. slot mapping. """ - is_prompt = seq_group_metadata.is_prompt - block_tables = seq_group_metadata.block_tables - computed_block_nums = seq_group_metadata.computed_block_nums + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( - seq_group_metadata.seq_data.keys(), token_lens, seq_lens, - curr_seq_lens, query_lens, context_lens, - curr_sliding_window_blocks): + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: self.num_prefills += 1 @@ -275,7 +274,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. 
block_table = [] - if prefix_cache_hit: + if inter_data.prefix_cache_hit: block_table = computed_block_nums elif ((chunked_prefill_enabled or not is_prompt) and block_tables is not None): @@ -290,8 +289,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, self.use_v2_block_manager) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, - self.block_size, - seq_group_metadata.block_tables) + self.block_size, inter_data.block_tables) # It is not necessary to add paged_kv_indices, paged_kv_indptr, # and paged_kv_last_page_len for profile run because we will @@ -317,9 +315,13 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, last_page_len = self.block_size self.paged_kv_last_page_len.append(last_page_len) - def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, + def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): - device = runner.device + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) @@ -333,7 +335,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, # The shape of graph_block_tables is # [max batch size, max context len // block size]. - input_block_tables = runner.graph_block_tables[:batch_size] + input_block_tables = self.runner.graph_block_tables[:batch_size] for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table @@ -377,7 +379,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, dtype=torch.long, device=device) - logits_soft_cap = getattr(runner.model_config.hf_config, + logits_soft_cap = getattr(self.runner.model_config.hf_config, "attn_logit_softcapping", None) if len(self.paged_kv_indptr) > 0: @@ -394,8 +396,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None - kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype, - runner.model_config.dtype) + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) return FlashInferMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -406,11 +408,11 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, - num_qo_heads=runner.model_config.get_num_attention_heads( - runner.parallel_config), - num_kv_heads=runner.model_config.get_num_kv_heads( - runner.parallel_config), - head_dim=runner.model_config.get_head_size(), + num_qo_heads=self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + num_kv_heads=self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config), + head_dim=self.runner.model_config.get_head_size(), page_size=self.block_size, seq_start_loc=seq_start_loc, query_start_loc=query_start_loc, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 0706e2d3a48b7..5877712b9b7d3 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -4,7 +4,6 @@ import torch from vllm.attention import AttentionMetadata, AttentionMetadataBuilder -from vllm.sequence import SequenceGroupMetadata from 
vllm.utils import make_tensor_with_pad # Error string(s) for encoder/decoder @@ -15,8 +14,7 @@ PAD_SLOT_ID = -1 if TYPE_CHECKING: - from vllm.worker.model_runner import (GPUModelRunnerBase, - ModelInputForGPUBuilder) + from vllm.worker.model_runner import ModelInputForGPUBuilder def is_block_tables_empty(block_tables: Union[None, Dict]): @@ -95,26 +93,27 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size self.use_v2_block_manager = ( input_builder.scheduler_config.use_v2_block_manager) - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, - token_lens: List[int], seq_lens: List[int], - curr_seq_lens: List[int], query_lens: List[int], - context_lens: List[int], - curr_sliding_window_blocks: List[int], prefix_cache_hit, - chunked_prefill_enabled): - is_prompt = seq_group_metadata.is_prompt - block_tables = seq_group_metadata.block_tables - computed_block_nums = seq_group_metadata.computed_block_nums + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( - seq_group_metadata.seq_data.keys(), token_lens, seq_lens, - curr_seq_lens, query_lens, context_lens, - curr_sliding_window_blocks): + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: self.num_prefills += 1 @@ -132,7 +131,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. block_table = [] - if prefix_cache_hit: + if inter_data.prefix_cache_hit: block_table = computed_block_nums elif ((chunked_prefill_enabled or not is_prompt) and block_tables is not None): @@ -146,16 +145,18 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, self.use_v2_block_manager) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, - self.block_size, - seq_group_metadata.block_tables) + self.block_size, inter_data.block_tables) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) - def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int], - query_lens: List[int], cuda_graph_pad_size: int, - batch_size: int): - device = runner.device + device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(runner.model_config.hf_config, + logits_soft_cap = getattr(self.runner.model_config.hf_config, "attn_logit_softcapping", None) if logits_soft_cap is not None: raise ValueError( @@ -176,7 +177,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int], # The shape of graph_block_tables is # [max batch size, max context len // block size]. 
- input_block_tables = runner.graph_block_tables[:batch_size] + input_block_tables = self.runner.graph_block_tables[:batch_size] for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table diff --git a/vllm/utils.py b/vllm/utils.py index 9e222772eb5b9..83605631b5bd6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -719,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]], return dict(merged_dict) +def flatten_2d_lists(lists: List[List[T]]) -> List[T]: + """Flatten a list of lists to a single list.""" + return [item for sublist in lists for item in sublist] + + def init_cached_hf_modules() -> None: """ Lazy initialization of the Hugging Face modules. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 31e9fc1eed548..12650f0b22780 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -3,7 +3,7 @@ import time import warnings import weakref -from collections import defaultdict +from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union) @@ -49,7 +49,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, +from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists, + get_kv_cache_torch_dtype, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -76,7 +77,7 @@ TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") -@dataclasses.dataclass(frozen=True) +@dataclass(frozen=True) class ModelInputForGPU(ModelRunnerInputBase): """ This base class contains metadata needed for the base model forward pass @@ -126,7 +127,7 @@ def from_broadcasted_tensor_dict( return cls(**tensor_dict) -@dataclasses.dataclass(frozen=True) +@dataclass(frozen=True) class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): """ Used by the ModelRunner. @@ -168,12 +169,84 @@ def from_broadcasted_tensor_dict( class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): - """TBA""" + """Build ModelInputForGPU from SequenceGroupMetadata.""" + + @dataclass + class InterDataForSeqGroup: + """Intermediate data for the current sequence group.""" + # From sequence group metadata. + request_id: str + seq_ids: List[int] + is_prompt: bool + block_tables: Optional[Dict[int, List[int]]] + computed_block_nums: List[int] + n_seqs: int = 0 + + # Input tokens and positions. + input_tokens: List[List[int]] = field(default_factory=list) + input_positions: List[List[int]] = field(default_factory=list) + + # The sequence length (may be capped to the sliding window). + seq_lens: List[int] = field(default_factory=list) + # The original sequence length (before applying sliding window). + # This is used to compute slot mapping. + orig_seq_lens: List[int] = field(default_factory=list) + # The query length. + query_lens: List[int] = field(default_factory=list) + # The number of tokens that are already computed. + context_lens: List[int] = field(default_factory=list) + # The current sliding window block. + curr_sliding_window_blocks: List[int] = field(default_factory=list) + + # LoRA inputs. 
+ lora_index_mapping: List[List[int]] = field(default_factory=list) + lora_prompt_mapping: List[List[int]] = field(default_factory=list) + lora_requests: Set[LoRARequest] = field(default_factory=set) + + # Prompt adapter inputs. + prompt_adapter_index_mapping: List[int] = field(default_factory=list) + prompt_adapter_prompt_mapping: List[int] = field(default_factory=list) + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + # Multi-modal inputs. + multi_modal_inputs: Optional[MultiModalInputs] = None + + # Whether the prefix cache is hit (prefill only). + prefix_cache_hit: bool = False + + def __post_init__(self): + self.n_seqs = len(self.seq_ids) + + self.input_tokens = [[] for _ in range(self.n_seqs)] + self.input_positions = [[] for _ in range(self.n_seqs)] + self.seq_lens = [0] * self.n_seqs + self.orig_seq_lens = [0] * self.n_seqs + self.query_lens = [0] * self.n_seqs + self.context_lens = [0] * self.n_seqs + self.curr_sliding_window_blocks = [0] * self.n_seqs + + self.lora_index_mapping = [[] for _ in range(self.n_seqs)] + self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)] def __init__(self, runner: "GPUModelRunnerBase", finished_requests_ids: Optional[List[str]] = None): super().__init__() + # Compute functions for each sequence in a sequence group. + # WARNING: The order of the functions matters! + self.per_seq_compute_fns = [ + self._compute_lens, + self._compute_for_prefix_cache_hit, + self._compute_for_sliding_window, + self._compute_lora_input, + ] + # Compute functions for each sequence group. + # WARNING: The order of the functions matters! + self.per_seq_group_compute_fns = [ + self._compute_prompt_adapter_input, + self._compute_multi_modal_input, + ] + self.runner = runner self.model_input_cls = self.runner._model_input_cls self.attn_backend = self.runner.attn_backend @@ -187,30 +260,14 @@ def __init__(self, self.finished_requests_ids = finished_requests_ids self.decode_only = True - # Common inputs. - self.input_tokens: List[int] = [] - self.input_positions: List[int] = [] - self.seq_lens: List[int] = [] - self.query_lens: List[int] = [] - self.max_decode_seq_len: int = 0 - self.request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list) - - # LoRA inputs. - self.lora_index_mapping: List[int] = [] - self.lora_prompt_mapping: List[int] = [] - self.lora_requests: Set[LoRARequest] = set() - - # Prompt adapter inputs. - self.prompt_adapter_index_mapping: List[int] = [] - self.prompt_adapter_prompt_mapping: List[int] = [] - self.prompt_adapter_requests: Set[PromptAdapterRequest] = set() - - # Multi-modal inputs. - self.multi_modal_inputs_list: List[MultiModalInputs] = [] + # Intermediate data (data in CPU before going to GPU) for + # the current sequence group. + self.inter_data_list: List[ + ModelInputForGPUBuilder.InterDataForSeqGroup] = [] # Attention metadata inputs. self.attn_metadata_builder = self.attn_backend.make_metadata_builder( - self) + weakref.proxy(self)) # Engine/Model configurations. self.chunked_prefill_enabled = ( @@ -222,175 +279,222 @@ def __init__(self, self.block_aligned_sliding_window = \ self.sliding_window_blocks * self.block_size - def _compute_len_for_sliding_window(self, seq_len: int): - curr_sliding_window_blocks = 0 - sliding_seq_len = seq_len + def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Compute context length, sequence length and tokens + for the given sequence data. 
+ """ + seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]] + token_chunk_size = seq_group_metadata.token_chunk_size - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - if self.sliding_window is not None: - curr_sliding_window_blocks = self.sliding_window_blocks + # Compute context length (the number of tokens that are + # already computed) and sequence length (total number of tokens). + seq_len = seq_data.get_len() + if inter_data.is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_len - 1 + seq_len = min(seq_len, context_len + token_chunk_size) + + # Compute tokens. + if inter_data.is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + inter_data.seq_lens[seq_idx] = seq_len + inter_data.orig_seq_lens[seq_idx] = seq_len + inter_data.context_lens[seq_idx] = context_len + inter_data.input_tokens[seq_idx] = tokens + inter_data.input_positions[seq_idx] = list(range(context_len, seq_len)) + inter_data.query_lens[ + seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 + + def _compute_for_prefix_cache_hit( + self, inter_data: InterDataForSeqGroup, seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Check if hit prefix cache (i.e., some blocks are already computed). + If hit, update input tokens and positions to only compute the + remaining blocks. + """ + computed_block_nums = inter_data.computed_block_nums + + # Note that prefix caching does not support sliding window. + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and inter_data.is_prompt) + inter_data.prefix_cache_hit = prefix_cache_hit + if self.chunked_prefill_enabled and prefix_cache_hit: + raise RuntimeError( + "chunked prefill cannot be used with prefix caching now.") + + # If prefix cache is hit, advance context length to bypass + # hit blocks. Accordingly, input tokens, position and query length + # have to be updated. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ + seq_idx][context_len:] + inter_data.input_positions[seq_idx] = inter_data.input_positions[ + seq_idx][context_len:] + inter_data.context_lens[seq_idx] = context_len + inter_data.query_lens[ + seq_idx] = inter_data.seq_lens[seq_idx] - context_len + + def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Update seq_len and curr_sliding_window_block for the given + sequence data (only required by decoding) if sliding window is enabled. + """ + curr_sliding_window_block = 0 + sliding_seq_len = inter_data.seq_lens[seq_idx] + if not inter_data.is_prompt and self.sliding_window is not None: + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. 
+ curr_sliding_window_block = self.sliding_window_blocks if self.scheduler_config.use_v2_block_manager: # number of elements in last block - suff_len = seq_len % self.block_size + suff_len = inter_data.seq_lens[seq_idx] % self.block_size sliding_seq_len = min( - seq_len, self.block_aligned_sliding_window + suff_len) + inter_data.seq_lens[seq_idx], + self.block_aligned_sliding_window + suff_len) if suff_len > 0: - curr_sliding_window_blocks += 1 + curr_sliding_window_block += 1 else: - sliding_seq_len = min(seq_len, self.sliding_window) - return curr_sliding_window_blocks, sliding_seq_len + sliding_seq_len = min(inter_data.seq_lens[seq_idx], + self.sliding_window) + + inter_data.curr_sliding_window_blocks[ + seq_idx] = curr_sliding_window_block + inter_data.seq_lens[seq_idx] = sliding_seq_len + + def _compute_lora_input(self, inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """If LoRA is enabled, compute LoRA index and prompt mapping.""" + if not self.enable_lora: + return + + lora_id = seq_group_metadata.lora_int_id + if lora_id > 0: + inter_data.lora_requests.add(seq_group_metadata.lora_request) + query_len = inter_data.query_lens[seq_idx] + inter_data.lora_index_mapping.append([lora_id] * query_len) + inter_data.lora_prompt_mapping.append( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs is not None + else 1)) + + def _compute_prompt_adapter_input( + self, inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata): + """If prompt adapter is enabled, compute index and prompt mapping. + """ + # Note that when is_prompt=True, we expect only one sequence + # in the group. + if not self.enable_prompt_adapter: + return + + prompt_adapter_id = seq_group_metadata.prompt_adapter_id + if prompt_adapter_id <= 0 or not inter_data.is_prompt: + return + + # We expect only one sequence in the group when is_prompt=True. + assert inter_data.n_seqs == 1 + query_len = inter_data.query_lens[0] + inter_data.prompt_adapter_request = ( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens + inter_data.prompt_adapter_index_mapping = [ + prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * ( + query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs else 1) + + def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata): + """If multi-modal data is given, add it to the input.""" + mm_data = seq_group_metadata.multi_modal_data + if not mm_data: + return + + mm_kwargs = self.multi_modal_input_mapper(mm_data) + inter_data.multi_modal_inputs = mm_kwargs def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + """Add a sequence group to the builder.""" seq_ids = list(seq_group_metadata.seq_data.keys()) n_seqs = len(seq_ids) is_prompt = seq_group_metadata.is_prompt - token_chunk_size = seq_group_metadata.token_chunk_size if is_prompt: assert n_seqs == 1 self.decode_only = False - # Mapping from request IDs to sequence IDs. Used for Jamba models - # that manages the cache by itself. - self.request_ids_to_seq_ids[seq_group_metadata.request_id] = [] - # The number of input tokens in each sequence. - token_lens: List[int] = [] - # The number of tokens that are already computed. 
- context_lens: List[int] = [] - # The current sliding window block for each sequence. - curr_sliding_window_blocks: List[int] = [] - # The original sequence length (before applying sliding window) - # for each sequence. - orig_seq_lens: List[int] = [] - # The sequence length (may be capped to the sliding window). - curr_seq_lens: List[int] = [] - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - self.request_ids_to_seq_ids[seq_group_metadata.request_id].append( - seq_id) - computed_block_nums = seq_group_metadata.computed_block_nums - - # Check if hit prefix cache (i.e., some blocks are already computed) - # Note that prefix caching does not support sliding window. - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None and is_prompt) - if self.chunked_prefill_enabled and prefix_cache_hit: - raise RuntimeError( - "chunked prefill cannot be used with prefix caching now.") - - # Compute context length (the number of tokens that are - # already computed) and sequence length (total number of tokens). - seq_len = seq_data.get_len() - if is_prompt: - context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_len - 1 - seq_len = min(seq_len, context_len + token_chunk_size) - - # Compute tokens. - if is_prompt: - tokens = seq_data.get_token_ids()[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. - tokens = [seq_data.get_last_token_id()] - if prefix_cache_hit: - assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size - tokens = tokens[context_len:] - - # These are seq_len/context_len capped to the sliding window. - # They are passed to decode kernel. - # We still need original seq_len/context_len to compute slot - # mapping (and input position) below. - if is_prompt: - curr_sliding_window_block = 0 - sliding_seq_len = seq_len - query_len = seq_len - context_len - else: - curr_sliding_window_block, sliding_seq_len = ( - self._compute_len_for_sliding_window(seq_len)) - query_len = 1 - - self.seq_lens.append(sliding_seq_len) - if not is_prompt: - self.max_decode_seq_len = max(self.max_decode_seq_len, - sliding_seq_len) - self.query_lens.append(query_len) - self.input_tokens.extend(tokens) - self.input_positions.extend(list(range(context_len, seq_len))) - - # Intermediate data of the current sequence group for - # the attention metadata. - token_lens.append(len(tokens)) - context_lens.append(context_len) - curr_seq_lens.append(sliding_seq_len) - curr_sliding_window_blocks.append(curr_sliding_window_block) - orig_seq_lens.append(seq_len) - - # Update attention metadata. Note that input builder attributes - # (self.xxx) include all added sequences, so we need to slice - # the last n_seqs sequences. - self.attn_metadata_builder.add_seq_group( - seq_group_metadata, token_lens, orig_seq_lens, curr_seq_lens, - self.query_lens[-n_seqs:], context_lens, - curr_sliding_window_blocks, prefix_cache_hit, - self.chunked_prefill_enabled) - - # LoRA data. 
- if self.enable_lora: - lora_id = seq_group_metadata.lora_int_id - for query_len in self.query_lens[-n_seqs:]: - if lora_id > 0: - self.lora_requests.add(seq_group_metadata.lora_request) - self.lora_index_mapping += [lora_id] * query_len - self.lora_prompt_mapping.extend( - [lora_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - is not None else 1)) - - # Prompt adapter data. Note that when is_prompt=True, - # we expect only one sequence in the group. - if self.enable_prompt_adapter: - prompt_adapter_id = seq_group_metadata.prompt_adapter_id - if prompt_adapter_id > 0 and is_prompt: - query_len = self.query_lens[-1] - self.prompt_adapter_requests.add( - seq_group_metadata.prompt_adapter_request) - - num_tokens = seq_group_metadata.\ - prompt_adapter_num_virtual_tokens - pm = [prompt_adapter_id - ] * num_tokens + [0] * (query_len - num_tokens) - self.prompt_adapter_index_mapping += pm - self.prompt_adapter_prompt_mapping.extend( - [prompt_adapter_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - else 1)) + inter_data = self.InterDataForSeqGroup( + request_id=seq_group_metadata.request_id, + seq_ids=seq_ids, + is_prompt=is_prompt, + block_tables=seq_group_metadata.block_tables, + computed_block_nums=seq_group_metadata.computed_block_nums) + self.inter_data_list.append(inter_data) - # Multi-modal data. - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper(mm_data) - self.multi_modal_inputs_list.append(mm_kwargs) + for seq_idx in range(n_seqs): + for per_seq_fn in self.per_seq_compute_fns: + per_seq_fn(inter_data, seq_idx, seq_group_metadata) + for per_seq_group_fn in self.per_seq_group_compute_fns: + per_seq_group_fn(inter_data, seq_group_metadata) def build(self) -> ModelInputForGPU: - if not self.input_tokens: + """Finalize the builder intermediate data and + create on-device tensors. + """ + # Combine and flatten intermediate data. + input_tokens = flatten_2d_lists([ + flatten_2d_lists(inter_data.input_tokens) + for inter_data in self.inter_data_list + ]) + if not input_tokens: + # This may happen when all prefill requests hit + # prefix caching and there is no decode request. return self.model_input_cls() + input_positions = flatten_2d_lists([ + flatten_2d_lists(inter_data.input_positions) + for inter_data in self.inter_data_list + ]) + seq_lens = [] + max_decode_seq_len = 0 + for inter_data in self.inter_data_list: + seq_lens.extend(inter_data.seq_lens) + if not inter_data.is_prompt: + max_decode_seq_len = max(max_decode_seq_len, + max(inter_data.seq_lens)) + query_lens = flatten_2d_lists( + [inter_data.query_lens for inter_data in self.inter_data_list]) + # Mapping from request IDs to sequence IDs. Used for Jamba models + # that manages the cache by itself. + request_ids_to_seq_ids = { + data.request_id: data.seq_ids + for data in self.inter_data_list + } - batch_size = len(self.input_tokens) + batch_size = len(input_tokens) use_captured_graph = ( self.decode_only and not self.runner.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and self.max_decode_seq_len <= self.runner.max_seq_len_to_capture) + and max_decode_seq_len <= self.runner.max_seq_len_to_capture) # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. @@ -403,60 +507,84 @@ def build(self) -> ModelInputForGPU: batch_size = graph_batch_size # Tokens and positions. 
- self.input_tokens.extend([0] * cuda_graph_pad_size) - self.input_positions.extend([0] * cuda_graph_pad_size) - input_tokens_tensor = torch.tensor(self.input_tokens, + input_tokens.extend([0] * cuda_graph_pad_size) + input_positions.extend([0] * cuda_graph_pad_size) + input_tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, device=self.runner.device) - input_positions_tensor = torch.tensor(self.input_positions, + input_positions_tensor = torch.tensor(input_positions, dtype=torch.long, device=self.runner.device) # Sequence and query lengths. - self.seq_lens.extend([1] * cuda_graph_pad_size) + seq_lens.extend([1] * cuda_graph_pad_size) # Attention metadata. attn_metadata = self.attn_metadata_builder.build( - self.runner, self.seq_lens, self.query_lens, cuda_graph_pad_size, - batch_size) + seq_lens, query_lens, cuda_graph_pad_size, batch_size) # LoRA data. + lora_requests = set() + lora_mapping = None if self.enable_lora: - self.lora_index_mapping.extend([0] * cuda_graph_pad_size) + lora_requests = set(r for data in self.inter_data_list + for r in data.lora_requests) + lora_index_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_index_mapping) + for inter_data in self.inter_data_list + ]) + lora_index_mapping.extend([0] * cuda_graph_pad_size) + lora_prompt_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_prompt_mapping) + for inter_data in self.inter_data_list + ]) lora_mapping = LoRAMapping( - self.lora_index_mapping, - self.lora_prompt_mapping, + lora_index_mapping, + lora_prompt_mapping, ) - else: - lora_mapping = None # Prompt adapter data. + prompt_adapter_requests: Set[PromptAdapterRequest] = set() + prompt_adapter_mapping = None if self.enable_prompt_adapter: - self.prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size) + prompt_adapter_requests = set( + data.prompt_adapter_request for data in self.inter_data_list + if data.prompt_adapter_request is not None) + prompt_adapter_index_mapping = flatten_2d_lists([ + inter_data.prompt_adapter_index_mapping + for inter_data in self.inter_data_list + ]) + prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size) + prompt_adapter_prompt_mapping = flatten_2d_lists([ + inter_data.prompt_adapter_prompt_mapping + for inter_data in self.inter_data_list + ]) prompt_adapter_mapping = PromptAdapterMapping( - self.prompt_adapter_index_mapping, - self.prompt_adapter_prompt_mapping, + prompt_adapter_index_mapping, + prompt_adapter_prompt_mapping, ) - else: - prompt_adapter_mapping = None # Multi-modal data. 
- multi_modal_kwargs = MultiModalInputs.batch( - self.multi_modal_inputs_list, device=self.runner.device) + multi_modal_inputs_list = [ + data.multi_modal_inputs for data in self.inter_data_list + if data.multi_modal_inputs is not None + ] + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, + device=self.runner.device) return self.model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, attn_metadata=attn_metadata, - seq_lens=self.seq_lens, - query_lens=self.query_lens, + seq_lens=seq_lens, + query_lens=query_lens, lora_mapping=lora_mapping, - lora_requests=self.lora_requests, + lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=self.request_ids_to_seq_ids, + request_ids_to_seq_ids=request_ids_to_seq_ids, finished_requests_ids=self.finished_requests_ids, prompt_adapter_mapping=prompt_adapter_mapping, - prompt_adapter_requests=self.prompt_adapter_requests) + prompt_adapter_requests=prompt_adapter_requests) class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): @@ -1393,15 +1521,3 @@ def _get_graph_batch_size(batch_size: int) -> int: else: return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) - - -def _is_block_tables_empty(block_tables: Union[None, Dict]): - """ - Check if block_tables is None or a dictionary with all None values. - """ - if block_tables is None: - return True - if isinstance(block_tables, dict) and all( - value is None for value in block_tables.values()): - return True - return False
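Below is a minimal, self-contained sketch of the pattern this change introduces: per-sequence compute functions fill an intermediate dataclass for each sequence group, and the attention metadata builder pulls that intermediate data from the input builder (held as a weakref.proxy) inside build(), rather than receiving it through a long add_seq_group() argument list. InterDataForSeqGroup, inter_data_list, per_seq_compute_fns, and the weakref.proxy wiring are taken from the diff above; every other name is a simplified stand-in, not vLLM API.

    import weakref
    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class InterData:  # stand-in for ModelInputForGPUBuilder.InterDataForSeqGroup
        seq_ids: List[int]
        input_tokens: List[List[int]] = field(default_factory=list)
        seq_lens: List[int] = field(default_factory=list)

        def __post_init__(self):
            n_seqs = len(self.seq_ids)
            self.input_tokens = [[] for _ in range(n_seqs)]
            self.seq_lens = [0] * n_seqs


    class MetadataBuilder:  # stand-in for an AttentionMetadataBuilder subclass
        def __init__(self, input_builder):
            self.input_builder = input_builder

        def build(self) -> dict:
            # Pull the per-group intermediate data from the input builder
            # instead of having it pushed in through add_seq_group() arguments.
            seq_lens: List[int] = []
            for inter_data in self.input_builder.inter_data_list:
                seq_lens.extend(inter_data.seq_lens)
            return {"seq_lens": seq_lens}


    class InputBuilder:  # stand-in for ModelInputForGPUBuilder
        def __init__(self):
            self.inter_data_list: List[InterData] = []
            # Ordered per-sequence compute functions, mirroring per_seq_compute_fns.
            self.per_seq_compute_fns = [self._compute_lens]
            # A weak proxy avoids a reference cycle between the two builders.
            self.attn_metadata_builder = MetadataBuilder(weakref.proxy(self))

        def _compute_lens(self, inter_data: InterData, seq_idx: int, tokens):
            inter_data.input_tokens[seq_idx] = list(tokens)
            inter_data.seq_lens[seq_idx] = len(tokens)

        def add_seq_group(self, seq_ids, tokens_per_seq):
            inter_data = InterData(seq_ids=list(seq_ids))
            self.inter_data_list.append(inter_data)
            for seq_idx, tokens in enumerate(tokens_per_seq):
                for fn in self.per_seq_compute_fns:
                    fn(inter_data, seq_idx, tokens)


    builder = InputBuilder()
    builder.add_seq_group(seq_ids=[0], tokens_per_seq=[[11, 12, 13]])
    builder.add_seq_group(seq_ids=[1], tokens_per_seq=[[21, 22]])
    print(builder.attn_metadata_builder.build())  # {'seq_lens': [3, 2]}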
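In the same illustrative spirit, the flattening step in ModelInputForGPUBuilder.build(): nested per-group, per-sequence token lists are collapsed with the new flatten_2d_lists helper and, when CUDA graphs are captured, padded with dummy entries. flatten_2d_lists is quoted from the vllm/utils.py hunk above; the sample data and pad size are invented.

    from typing import List, TypeVar

    T = TypeVar("T")


    def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
        """Flatten a list of lists to a single list."""
        return [item for sublist in lists for item in sublist]


    # One inner list of per-sequence token lists per sequence group.
    per_group_input_tokens = [[[1, 2, 3]], [[4, 5], [6]]]

    input_tokens = flatten_2d_lists(
        [flatten_2d_lists(group) for group in per_group_input_tokens])
    assert input_tokens == [1, 2, 3, 4, 5, 6]

    # build() pads with dummy tokens for CUDA graph capture; the pad size of 2
    # here is arbitrary, standing in for cuda_graph_pad_size.
    cuda_graph_pad_size = 2
    input_tokens.extend([0] * cuda_graph_pad_size)
    print(input_tokens)  # [1, 2, 3, 4, 5, 6, 0, 0]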
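A hedged worked example of the arithmetic in _compute_for_prefix_cache_hit: on a prefill whose prefix blocks are already cached (and with sliding window disabled), the context length jumps to len(computed_block_nums) * block_size and the cached prefix is sliced off the input tokens and positions, so only the remainder is recomputed. The block size and token values below are made up.

    block_size = 16
    computed_block_nums = [0, 1]        # two blocks already in the prefix cache
    seq_len = 40
    input_tokens = list(range(1000, 1000 + seq_len))
    input_positions = list(range(seq_len))

    context_len = len(computed_block_nums) * block_size  # 32
    input_tokens = input_tokens[context_len:]             # last 8 uncached tokens
    input_positions = input_positions[context_len:]       # positions 32..39
    query_len = seq_len - context_len                      # 8
    print(context_len, query_len, input_positions[0])      # 32 8 32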
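Finally, a sketch of the decode-time capping done in _compute_for_sliding_window under the v2 block manager: the effective sequence length is clipped to the block-aligned window plus the partially filled last block, and one extra block is counted when that last block is non-empty. The ceil-division used for sliding_window_blocks is an assumption (it is defined outside this diff); the concrete numbers are invented.

    block_size = 16
    sliding_window = 128
    # Assumed to match the builder's definition outside this diff.
    sliding_window_blocks = (sliding_window + block_size - 1) // block_size  # 8
    block_aligned_sliding_window = sliding_window_blocks * block_size        # 128

    seq_len = 200                                   # decode request, long history
    curr_sliding_window_block = sliding_window_blocks
    suff_len = seq_len % block_size                 # 8 tokens in the last block
    sliding_seq_len = min(seq_len, block_aligned_sliding_window + suff_len)  # 136
    if suff_len > 0:
        curr_sliding_window_block += 1              # 9 blocks visible to attention

    print(sliding_seq_len, curr_sliding_window_block)  # 136 9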