[Core] Tweaks to model runner/input builder developer APIs #6712

Merged · 4 commits · Jul 24, 2024

35 changes: 19 additions & 16 deletions vllm/attention/backends/flashinfer.py
@@ -297,23 +297,26 @@ def _add_seq_group(
        if is_profile_run:
            return

        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        block_table_bound = seq_len // self.block_size + 1 \
            if seq_len % self.block_size != 0 \
            else seq_len // self.block_size
        block_table = block_tables[seq_id]
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)

        last_page_len = seq_len % self.block_size
        if last_page_len == 0:
            last_page_len = self.block_size
        self.paged_kv_last_page_len.append(last_page_len)
        self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        block_table_bound = seq_len // self.block_size + 1 \
            if seq_len % self.block_size != 0 \
            else seq_len // self.block_size
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)

        last_page_len = seq_len % self.block_size
        if last_page_len == 0:
            last_page_len = self.block_size
        self.paged_kv_last_page_len.append(last_page_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
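Aside (not part of the diff): the comments moved into _update_paged_kv_tensors describe a ceil-division on the sequence length plus a last-page-length fixup. A standalone Python sketch of that arithmetic, with paged_kv_block_stats as a hypothetical helper name and illustrative inputs:

def paged_kv_block_stats(seq_len: int, block_size: int):
    """Mirror the block_table_bound / last_page_len arithmetic."""
    # Number of valid blocks: a partially filled last block still counts.
    block_table_bound = (seq_len // block_size + 1
                         if seq_len % block_size != 0
                         else seq_len // block_size)
    # Tokens in the last page; an exactly full page reports block_size.
    last_page_len = seq_len % block_size
    if last_page_len == 0:
        last_page_len = block_size
    return block_table_bound, last_page_len

assert paged_kv_block_stats(16, 16) == (1, 16)  # the seq_len = 16 case above
assert paged_kv_block_stats(15, 16) == (1, 15)  # the seq_len = 15 case above
assert paged_kv_block_stats(17, 16) == (2, 1)   # one token spills into a second block

The asserts reproduce the seq_len = 16 and seq_len = 15 cases called out in the comments, plus one spill-over case.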
4 changes: 3 additions & 1 deletion vllm/worker/embedding_model_runner.py
@@ -11,7 +11,8 @@
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)

logger = init_logger(__name__)
@@ -28,6 +29,7 @@ class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
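Aside (not part of the diff): setting _builder_cls, as EmbeddingModelRunner does above, is the extension point this PR introduces, since GPUModelRunnerBase._prepare_model_input_tensors now instantiates self._builder_cls (see model_runner.py below). A minimal sketch of how an out-of-tree runner might use it; MyInputBuilder and MyModelRunner are hypothetical names:

from typing import Type

from vllm.worker.model_runner import ModelInputForGPUBuilder, ModelRunner

class MyInputBuilder(ModelInputForGPUBuilder):
    """Hypothetical builder with custom per-sequence-group bookkeeping."""

    def add_seq_group(self, seq_group_metadata):
        # Custom logic could run before or after the default per-group fns.
        super().add_seq_group(seq_group_metadata)

class MyModelRunner(ModelRunner):
    # _prepare_model_input_tensors() will instantiate this builder class.
    _builder_cls: Type[ModelInputForGPUBuilder] = MyInputBuilder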
134 changes: 87 additions & 47 deletions vllm/worker/model_runner.py
@@ -3,7 +3,7 @@
import time
import warnings
import weakref
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
                    Tuple, Type, TypeVar, Union)
@@ -171,48 +171,83 @@ def from_broadcasted_tensor_dict(
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
    """Build ModelInputForGPU from SequenceGroupMetadata."""

    @dataclass
    # Note: ideally we would be using a dataclass(kw_only=True)
    # here, so that this can be subclassed easily,
    # but kw_only is not supported in python<3.10.
    class InterDataForSeqGroup:
        """Intermediate data for the current sequence group."""
        # From sequence group metadata.
        request_id: str
        seq_ids: List[int]
        is_prompt: bool
        block_tables: Optional[Dict[int, List[int]]]
        computed_block_nums: List[int]
        n_seqs: int = 0

        # Input tokens and positions.
        input_tokens: List[List[int]] = field(default_factory=list)
        input_positions: List[List[int]] = field(default_factory=list)

        # The sequence length (may be capped to the sliding window).
        seq_lens: List[int] = field(default_factory=list)
        # The original sequence length (before applying sliding window).
        # This is used to compute slot mapping.
        orig_seq_lens: List[int] = field(default_factory=list)
        # The query length.
        query_lens: List[int] = field(default_factory=list)
        # The number of tokens that are already computed.
        context_lens: List[int] = field(default_factory=list)
        # The current sliding window block.
        curr_sliding_window_blocks: List[int] = field(default_factory=list)

        # LoRA inputs.
        lora_index_mapping: List[List[int]] = field(default_factory=list)
        lora_prompt_mapping: List[List[int]] = field(default_factory=list)
        lora_requests: Set[LoRARequest] = field(default_factory=set)

        # Prompt adapter inputs.
        prompt_adapter_index_mapping: List[int] = field(default_factory=list)
        prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
        prompt_adapter_request: Optional[PromptAdapterRequest] = None

        # Multi-modal inputs.
        multi_modal_inputs: Optional[MultiModalInputs] = None

        # Whether the prefix cache is hit (prefill only).
        prefix_cache_hit: bool = False

        def __init__(
            self,
            *,
            # From sequence group metadata.
            request_id: str,
            seq_ids: List[int],
            is_prompt: bool,
            block_tables: Optional[Dict[int, List[int]]],
            computed_block_nums: List[int],
            n_seqs: int = 0,

            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            input_positions: Optional[List[List[int]]] = None,

            # The sequence length (may be capped to the sliding window).
            seq_lens: Optional[List[int]] = None,
            # The original sequence length (before applying sliding window).
            # This is used to compute slot mapping.
            orig_seq_lens: Optional[List[int]] = None,
            # The query length.
            query_lens: Optional[List[int]] = None,
            # The number of tokens that are already computed.
            context_lens: Optional[List[int]] = None,
            # The current sliding window block.
            curr_sliding_window_blocks: Optional[List[int]] = None,

            # LoRA inputs.
            lora_index_mapping: Optional[List[List[int]]] = None,
            lora_prompt_mapping: Optional[List[List[int]]] = None,
            lora_requests: Optional[Set[LoRARequest]] = None,

            # Prompt adapter inputs.
            prompt_adapter_index_mapping: Optional[List[int]] = None,
            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,

            # Multi-modal inputs.
            multi_modal_inputs: Optional[MultiModalInputs] = None,

            # Whether the prefix cache is hit (prefill only).
            prefix_cache_hit: bool = False,
        ):
            self.request_id = request_id
            self.seq_ids = seq_ids
            self.is_prompt = is_prompt
            self.block_tables = block_tables
            self.computed_block_nums = computed_block_nums
            self.n_seqs = n_seqs
            self.input_tokens = input_tokens or []
            self.input_positions = input_positions or []
            self.seq_lens = seq_lens or []
            self.orig_seq_lens = orig_seq_lens or []
            self.query_lens = query_lens or []
            self.context_lens = context_lens or []
            self.curr_sliding_window_blocks = curr_sliding_window_blocks or []

            self.lora_index_mapping = lora_index_mapping or []
            self.lora_prompt_mapping = lora_prompt_mapping or []
            self.lora_requests = lora_requests or set()

            self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
                                                 or [])
            self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
                                                  or [])
            self.prompt_adapter_request = prompt_adapter_request

            self.multi_modal_inputs = multi_modal_inputs
            self.prefix_cache_hit = prefix_cache_hit

            self.__post_init__()

        def __post_init__(self):
            self.n_seqs = len(self.seq_ids)
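Aside (not part of the diff): the keyword-only __init__ replaces the dataclass field list so that InterDataForSeqGroup can be subclassed on Python < 3.10, per the note above. A hedged sketch of what that enables; MyInterData and extra_data are hypothetical names:

from typing import List, Optional

from vllm.worker.model_runner import ModelInputForGPUBuilder

class MyInterData(ModelInputForGPUBuilder.InterDataForSeqGroup):
    """Hypothetical per-sequence-group data carrying one extra field."""

    def __init__(self, *, extra_data: Optional[List[int]] = None, **kwargs):
        super().__init__(**kwargs)  # also runs the base __post_init__
        self.extra_data = extra_data or []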
@@ -457,6 +492,12 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        for per_seq_group_fn in self.per_seq_group_compute_fns:
            per_seq_group_fn(inter_data, seq_group_metadata)

    def _use_captured_graph(self, batch_size: int,
                            max_decode_seq_len: int) -> bool:
        return (self.decode_only and not self.runner.model_config.enforce_eager
                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                and max_decode_seq_len <= self.runner.max_seq_len_to_capture)

    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and
        create on-device tensors.
@@ -491,10 +532,8 @@ def build(self) -> ModelInputForGPU:
        }

        batch_size = len(input_tokens)
        use_captured_graph = (
            self.decode_only and not self.runner.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
        use_captured_graph = self._use_captured_graph(batch_size,
                                                      max_decode_seq_len)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
@@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
    Helper class for shared methods between GPU model runners.
    """
    _model_input_cls: Type[TModelInputForGPU]
    _builder_cls: Type[ModelInputForGPUBuilder]

    def __init__(
        self,
@@ -794,8 +834,7 @@ def _prepare_model_input_tensors(

        If cuda graph is required, this API automatically pads inputs.
        """
        builder = ModelInputForGPUBuilder(weakref.proxy(self),
                                          finished_requests_ids)
        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)
        return builder.build()  # type: ignore
@@ -1191,6 +1230,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

def make_model_input_from_broadcasted_tensor_dict(
self,
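Aside (not part of the diff): factoring the CUDA-graph decision into _use_captured_graph gives builder subclasses a single method to override instead of editing build(). A minimal sketch under that assumption; EagerOnlyInputBuilder is a hypothetical name:

from vllm.worker.model_runner import ModelInputForGPUBuilder

class EagerOnlyInputBuilder(ModelInputForGPUBuilder):
    """Hypothetical builder that never pads a batch for CUDA graph replay."""

    def _use_captured_graph(self, batch_size: int,
                            max_decode_seq_len: int) -> bool:
        # Ignore the capture-size and seq-len checks; always take the eager path.
        return False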