Skip to content

Commit

Permalink
[Core] Tweaks to model runner/input builder developer APIs (vllm-proj…
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 authored and phil committed Aug 6, 2024
1 parent 3f1a6f2 commit 0f8250d
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 64 deletions.
35 changes: 19 additions & 16 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,23 +297,26 @@ def _add_seq_group(
if is_profile_run:
return

# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
block_table = block_tables[seq_id]
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)

last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
self.paged_kv_last_page_len.append(last_page_len)
self._update_paged_kv_tensors(block_table, seq_len)

def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)

last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
self.paged_kv_last_page_len.append(last_page_len)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
Expand Down
4 changes: 3 additions & 1 deletion vllm/worker/embedding_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
ModelInputForGPUBuilder)

logger = init_logger(__name__)

Expand All @@ -28,6 +29,7 @@ class EmbeddingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

def __init__(
self,
Expand Down
134 changes: 87 additions & 47 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import warnings
import weakref
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)

Expand Down Expand Up @@ -171,48 +171,83 @@ def from_broadcasted_tensor_dict(
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""Build ModelInputForGPU from SequenceGroupMetadata."""

@dataclass
# Note: ideally we would be using a dataclass(kw_only=True)
# here, so that this can be subclassed easily,
# but kw_only is not supported in python<3.10.
class InterDataForSeqGroup:
"""Intermediate data for the current sequence group."""
# From sequence group metadata.
request_id: str
seq_ids: List[int]
is_prompt: bool
block_tables: Optional[Dict[int, List[int]]]
computed_block_nums: List[int]
n_seqs: int = 0

# Input tokens and positions.
input_tokens: List[List[int]] = field(default_factory=list)
input_positions: List[List[int]] = field(default_factory=list)

# The sequence length (may be capped to the sliding window).
seq_lens: List[int] = field(default_factory=list)
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: List[int] = field(default_factory=list)
# The query length.
query_lens: List[int] = field(default_factory=list)
# The number of tokens that are already computed.
context_lens: List[int] = field(default_factory=list)
# The current sliding window block.
curr_sliding_window_blocks: List[int] = field(default_factory=list)

# LoRA inputs.
lora_index_mapping: List[List[int]] = field(default_factory=list)
lora_prompt_mapping: List[List[int]] = field(default_factory=list)
lora_requests: Set[LoRARequest] = field(default_factory=set)

# Prompt adapter inputs.
prompt_adapter_index_mapping: List[int] = field(default_factory=list)
prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
prompt_adapter_request: Optional[PromptAdapterRequest] = None

# Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None

# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False

def __init__(
self,
*,
# From sequence group metadata.
request_id: str,
seq_ids: List[int],
is_prompt: bool,
block_tables: Optional[Dict[int, List[int]]],
computed_block_nums: List[int],
n_seqs: int = 0,

# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
input_positions: Optional[List[List[int]]] = None,

# The sequence length (may be capped to the sliding window).
seq_lens: Optional[List[int]] = None,
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: Optional[List[int]] = None,
# The query length.
query_lens: Optional[List[int]] = None,
# The number of tokens that are already computed.
context_lens: Optional[List[int]] = None,
# The current sliding window block.
curr_sliding_window_blocks: Optional[List[int]] = None,

# LoRA inputs.
lora_index_mapping: Optional[List[List[int]]] = None,
lora_prompt_mapping: Optional[List[List[int]]] = None,
lora_requests: Optional[Set[LoRARequest]] = None,

# Prompt adapter inputs.
prompt_adapter_index_mapping: Optional[List[int]] = None,
prompt_adapter_prompt_mapping: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,

# Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None,

# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False,
):
self.request_id = request_id
self.seq_ids = seq_ids
self.is_prompt = is_prompt
self.block_tables = block_tables
self.computed_block_nums = computed_block_nums
self.n_seqs = n_seqs
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = curr_sliding_window_blocks or []

self.lora_index_mapping = lora_index_mapping or []
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()

self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
or [])
self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
or [])
self.prompt_adapter_request = prompt_adapter_request

self.multi_modal_inputs = multi_modal_inputs
self.prefix_cache_hit = prefix_cache_hit

self.__post_init__()

def __post_init__(self):
self.n_seqs = len(self.seq_ids)
Expand Down Expand Up @@ -457,6 +492,12 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
for per_seq_group_fn in self.per_seq_group_compute_fns:
per_seq_group_fn(inter_data, seq_group_metadata)

def _use_captured_graph(self, batch_size: int,
max_decode_seq_len: int) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)

def build(self) -> ModelInputForGPU:
"""Finalize the builder intermediate data and
create on-device tensors.
Expand Down Expand Up @@ -491,10 +532,8 @@ def build(self) -> ModelInputForGPU:
}

batch_size = len(input_tokens)
use_captured_graph = (
self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
use_captured_graph = self._use_captured_graph(batch_size,
max_decode_seq_len)

# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
Expand Down Expand Up @@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
Helper class for shared methods between GPU model runners.
"""
_model_input_cls: Type[TModelInputForGPU]
_builder_cls: Type[ModelInputForGPUBuilder]

def __init__(
self,
Expand Down Expand Up @@ -794,8 +834,7 @@ def _prepare_model_input_tensors(
If cuda graph is required, this API automatically pads inputs.
"""
builder = ModelInputForGPUBuilder(weakref.proxy(self),
finished_requests_ids)
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
return builder.build() # type: ignore
Expand Down Expand Up @@ -1191,6 +1230,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

def make_model_input_from_broadcasted_tensor_dict(
self,
Expand Down

0 comments on commit 0f8250d

Please sign in to comment.