Commit 743c0f5: address comment

comaniac committed Jul 16, 2024
1 parent cdf10db
Showing 5 changed files with 88 additions and 45 deletions.
24 changes: 19 additions & 5 deletions vllm/attention/backends/abstract.py
@@ -1,11 +1,15 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
 from enum import Enum, auto
-from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
+                    Tuple, Type, TypeVar)
 
 import torch
 
+if TYPE_CHECKING:
+    from vllm.sequence import SequenceGroupMetadata
+    from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
+
 
 class AttentionType(Enum):
     DECODER = auto()  # Decoder attention between previous layer Q/K/V
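
The TYPE_CHECKING guard makes these imports visible to static type checkers only; at runtime they are skipped, which avoids a circular import between the attention backends and vllm.sequence / vllm.worker. Names imported this way must be referenced as string (forward) annotations, as the new add_seq_group and build signatures below do. A minimal sketch of the pattern, with illustrative module and class names that are not vLLM's:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by type checkers (mypy, pyright) only, never at runtime,
        # so importing a module that itself imports this one is safe.
        from heavy_module import HeavyClass

    def process(obj: "HeavyClass") -> None:
        # The quoted annotation defers name resolution, so no runtime
        # import of heavy_module is needed.
        obj.run()
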
@@ -128,12 +132,22 @@ def __init__(self, input_builder) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def add_seq_group(self, *args, **kwargs) -> None:
+    def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
+                      token_lens: List[int], seq_lens: List[int],
+                      curr_seq_lens: List[int], query_lens: List[int],
+                      context_lens: List[int],
+                      curr_sliding_window_blocks: List[int],
+                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+        """Add a sequence group to the metadata and update
+        corresponding fields (in Python objects).
+        """
         raise NotImplementedError
 
     @abstractmethod
-    def build(self, runner, seq_lens, query_lens, use_captured_graph: bool,
-              cuda_graph_pad_size: int, batch_size: int) -> T:
+    def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
+              query_lens: List[int], cuda_graph_pad_size: int,
+              batch_size: int) -> T:
         """Build attention metadata with on-device tensors."""
         raise NotImplementedError
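
For orientation, a minimal, hypothetical subclass sketch showing how the two phases of this interface are meant to be used; the real implementations are the per-backend builders later in this diff, and this toy version assumes the imports and base class defined above:

    class ToyMetadataBuilder(AttentionMetadataBuilder[AttentionMetadata]):

        def __init__(self, input_builder) -> None:
            self.seq_lens: List[int] = []

        def add_seq_group(self, seq_group_metadata, token_lens, seq_lens,
                          curr_seq_lens, query_lens, context_lens,
                          curr_sliding_window_blocks, prefix_cache_hit,
                          chunked_prefill_enabled):
            # Phase 1: cheap Python-side bookkeeping, called once per
            # scheduled sequence group.
            self.seq_lens.extend(seq_lens)

        def build(self, runner, seq_lens, query_lens, cuda_graph_pad_size,
                  batch_size):
            # Phase 2: materialize device tensors once per batch. A real
            # builder returns its backend's AttentionMetadata subclass.
            return torch.tensor(self.seq_lens, device=runner.device)
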


28 changes: 17 additions & 11 deletions vllm/attention/backends/flash_attn.py
@@ -207,7 +207,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
-        self.decode_seq_lens: List[int] = []
+        self.curr_seq_lens: List[int] = []
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -219,17 +219,22 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                       token_lens: List[int], seq_lens: List[int],
-                      decode_seq_lens: List[int], query_lens: List[int],
+                      curr_seq_lens: List[int], query_lens: List[int],
                       context_lens: List[int],
-                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
-                      chunked_prefill_enabled):
+                      curr_sliding_window_blocks: List[int],
+                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+        """Add a sequence group to the metadata. Specifically update/append
+        1. context length.
+        2. block table.
+        3. slot mapping.
+        """
         is_prompt = seq_group_metadata.is_prompt
         block_tables = seq_group_metadata.block_tables
 
-        for (seq_id, token_len, seq_len, decode_seq_len, query_len,
-             context_len, curr_sliding_window_block) in zip(
+        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+             curr_sliding_window_block) in zip(
                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                decode_seq_lens, query_lens, context_lens,
+                curr_seq_lens, query_lens, context_lens,
                 curr_sliding_window_blocks):
             self.context_lens.append(context_len)
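
The "slot mapping" listed in the docstring above is PagedAttention bookkeeping: each new token is assigned a physical slot in the block-structured KV cache, slot = physical_block_number * block_size + in-block offset. A simplified sketch of the idea (the real helper is compute_slot_mapping in vllm/attention/backends/utils.py, which also handles profile runs and sliding-window start offsets):

    # Toy version of slot-mapping computation, for illustration only.
    def toy_slot_mapping(block_table, context_len, seq_len, block_size):
        slots = []
        for pos in range(context_len, seq_len):  # only the new tokens
            physical_block = block_table[pos // block_size]
            offset = pos % block_size
            slots.append(physical_block * block_size + offset)
        return slots

    # block_table=[7, 3], block_size=4, context_len=2, seq_len=6
    # -> positions 2..5 map to slots [30, 31, 12, 13]
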

@@ -242,7 +247,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                     "seq_len: {}, context_len: {}, query_len: {}".format(
                         seq_len, context_len, query_len))
                 self.num_decode_tokens += query_len
-                self.decode_seq_lens.append(decode_seq_len)
+                self.curr_seq_lens.append(curr_seq_len)
 
             # Compute block table.
             # TODO(sang): Combine chunked prefill and prefix caching by
@@ -269,9 +274,10 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                                  seq_group_metadata.block_tables)
 
     def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
-              use_captured_graph: bool, cuda_graph_pad_size: int,
-              batch_size: int):
+              cuda_graph_pad_size: int, batch_size: int):
         """Build attention metadata with on-device tensors."""
         device = runner.device
+        use_captured_graph = cuda_graph_pad_size != -1

         logits_soft_cap = getattr(runner.model_config.hf_config,
                                   "attn_logit_softcapping", None)
@@ -284,7 +290,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
 
         max_query_len = max(query_lens)
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
-        max_decode_seq_len = max(self.decode_seq_lens, default=0)
+        max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
 
         if use_captured_graph:
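
A note on the signature change here (mirrored in the other backends): build no longer receives use_captured_graph and instead re-derives it from cuda_graph_pad_size. The convention assumed in this diff is that the caller passes -1 for eager mode, and any non-negative value is the padding needed to reach a captured batch size; a pad of 0 (the batch already matches a captured size) still uses the graph, which is why the derivation checks != -1 rather than > 0. A hypothetical sketch of the caller side, with assumed names and captured sizes:

    # Illustrative only; names and the captured-size list are assumptions.
    def compute_cuda_graph_pad_size(batch_size, captured_sizes,
                                    use_cuda_graph):
        if not use_cuda_graph:
            return -1  # sentinel: run eagerly, no padding
        # Pad up to the smallest captured batch size that fits.
        graph_batch_size = min(s for s in captured_sizes if s >= batch_size)
        return graph_batch_size - batch_size  # may legitimately be 0

    assert compute_cuda_graph_pad_size(3, [1, 2, 4, 8], True) == 1
    assert compute_cuda_graph_pad_size(4, [1, 2, 4, 8], True) == 0
    assert compute_cuda_graph_pad_size(4, [1, 2, 4, 8], False) == -1
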
29 changes: 19 additions & 10 deletions vllm/attention/backends/flashinfer.py
@@ -210,7 +210,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
-        self.decode_seq_lens: List[int] = []
+        self.curr_seq_lens: List[int] = []
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -239,18 +239,23 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                       token_lens: List[int], seq_lens: List[int],
-                      decode_seq_lens: List[int], query_lens: List[int],
+                      curr_seq_lens: List[int], query_lens: List[int],
                       context_lens: List[int],
-                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
-                      chunked_prefill_enabled):
+                      curr_sliding_window_blocks: List[int],
+                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+        """Add a sequence group to the metadata. Specifically update/append
+        1. context length.
+        2. block table.
+        3. slot mapping.
+        """
         is_prompt = seq_group_metadata.is_prompt
         block_tables = seq_group_metadata.block_tables
         computed_block_nums = seq_group_metadata.computed_block_nums
 
-        for (seq_id, token_len, seq_len, decode_seq_len, query_len,
-             context_len, curr_sliding_window_block) in zip(
+        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+             curr_sliding_window_block) in zip(
                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                decode_seq_lens, query_lens, context_lens,
+                curr_seq_lens, query_lens, context_lens,
                 curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
@@ -262,7 +267,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                     "seq_len: {}, context_len: {}, query_len: {}".format(
                         seq_len, context_len, query_len))
                 self.num_decode_tokens += query_len
-                self.decode_seq_lens.append(decode_seq_len)
+                self.curr_seq_lens.append(curr_seq_len)
 
             # Compute block table.
             # TODO(sang): Combine chunked prefill and prefix caching by
@@ -286,6 +291,10 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                                  seq_len, context_len, start_idx,
                                  self.block_size,
                                  seq_group_metadata.block_tables)
+
+        # It is not necessary to add paged_kv_indices, paged_kv_indptr,
+        # and paged_kv_last_page_len for the profile run because we will
+        # create dummy inputs.
         if is_profile_run:
             return
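
For readers unfamiliar with FlashInfer's paged-KV metadata: the three lists named in the comment above (and appended in the next hunk) form a CSR-style description of each sequence's KV-cache blocks. A small assumed example with block_size = 16 and invented block IDs:

    # Two sequences: seq 0 holds 20 tokens in blocks [3, 7],
    # seq 1 holds 16 tokens in block [5].
    paged_kv_indices = [3, 7, 5]      # all block IDs, concatenated
    paged_kv_indptr = [0, 2, 3]       # seq i owns indices[indptr[i]:indptr[i+1]]
    paged_kv_last_page_len = [4, 16]  # valid tokens in each seq's last block
    # 20 = 16 + 4 -> seq 0's last page is partially filled (4 tokens);
    # 16 tokens exactly fill seq 1's single block.
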

@@ -308,9 +317,9 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
         self.paged_kv_last_page_len.append(last_page_len)
 
     def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
-              use_captured_graph: bool, cuda_graph_pad_size: int,
-              batch_size: int):
+              cuda_graph_pad_size: int, batch_size: int):
         device = runner.device
+        use_captured_graph = cuda_graph_pad_size != -1

         max_query_len = max(query_lens)
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
19 changes: 10 additions & 9 deletions vllm/attention/backends/utils.py
@@ -90,7 +90,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
-        self.decode_seq_lens: List[int] = []
+        self.curr_seq_lens: List[int] = []
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -102,18 +102,18 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                       token_lens: List[int], seq_lens: List[int],
-                      decode_seq_lens: List[int], query_lens: List[int],
+                      curr_seq_lens: List[int], query_lens: List[int],
                       context_lens: List[int],
                       curr_sliding_window_blocks: List[int], prefix_cache_hit,
                       chunked_prefill_enabled):
         is_prompt = seq_group_metadata.is_prompt
         block_tables = seq_group_metadata.block_tables
         computed_block_nums = seq_group_metadata.computed_block_nums
 
-        for (seq_id, token_len, seq_len, decode_seq_len, query_len,
-             context_len, curr_sliding_window_block) in zip(
+        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+             curr_sliding_window_block) in zip(
                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                decode_seq_lens, query_lens, context_lens,
+                curr_seq_lens, query_lens, context_lens,
                 curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
@@ -125,7 +125,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                     "seq_len: {}, context_len: {}, query_len: {}".format(
                         seq_len, context_len, query_len))
                 self.num_decode_tokens += query_len
-                self.decode_seq_lens.append(decode_seq_len)
+                self.curr_seq_lens.append(curr_seq_len)
 
             # Compute block table.
             # TODO(sang): Combine chunked prefill and prefix caching by
@@ -150,9 +150,10 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                                  seq_group_metadata.block_tables)
 
     def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
-              query_lens: List[int], use_captured_graph: bool,
-              cuda_graph_pad_size: int, batch_size: int):
+              query_lens: List[int], cuda_graph_pad_size: int,
+              batch_size: int):
         device = runner.device
+        use_captured_graph = cuda_graph_pad_size != -1

         logits_soft_cap = getattr(runner.model_config.hf_config,
                                   "attn_logit_softcapping", None)
@@ -165,7 +166,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
 
         max_query_len = max(query_lens)
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
-        max_decode_seq_len = max(self.decode_seq_lens, default=0)
+        max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
 
         if use_captured_graph:
33 changes: 23 additions & 10 deletions vllm/worker/model_runner.py
@@ -2,6 +2,7 @@
 import gc
 import time
 import warnings
+import weakref
 from collections import defaultdict
 from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
                     Tuple, Type, TypeVar, Union)
@@ -250,12 +251,20 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
             assert n_seqs == 1
             self.decode_only = False
 
+        # Mapping from request IDs to sequence IDs. Used for Jamba models
+        # that manage the cache by themselves.
         self.request_ids_to_seq_ids[seq_group_metadata.request_id] = []
-        token_lens = []
-        decode_seq_lens = []
-        context_lens = []
-        curr_sliding_window_blocks = []
-        orig_seq_lens = []
+        # The number of input tokens in each sequence.
+        token_lens: List[int] = []
+        # The number of tokens that are already computed.
+        context_lens: List[int] = []
+        # The current sliding window block for each sequence.
+        curr_sliding_window_blocks: List[int] = []
+        # The original sequence length (before applying sliding window)
+        # for each sequence.
+        orig_seq_lens: List[int] = []
+        # The sequence length (may be capped to the sliding window).
+        curr_seq_lens: List[int] = []
         for seq_id in seq_ids:
             seq_data = seq_group_metadata.seq_data[seq_id]
             self.request_ids_to_seq_ids[seq_group_metadata.request_id].append(
@@ -320,12 +329,15 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
             # the attention metadata.
             token_lens.append(len(tokens))
             context_lens.append(context_len)
-            decode_seq_lens.append(sliding_seq_len)
+            curr_seq_lens.append(sliding_seq_len)
             curr_sliding_window_blocks.append(curr_sliding_window_block)
             orig_seq_lens.append(seq_len)
 
+        # Update attention metadata. Note that input builder attributes
+        # (self.xxx) include all added sequences, so we need to slice
+        # the last n_seqs sequences.
         self.attn_metadata_builder.add_seq_group(
-            seq_group_metadata, token_lens, orig_seq_lens, decode_seq_lens,
+            seq_group_metadata, token_lens, orig_seq_lens, curr_seq_lens,
            self.query_lens[-n_seqs:], context_lens,
             curr_sliding_window_blocks, prefix_cache_hit,
             self.chunked_prefill_enabled)
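
To make the per-sequence bookkeeping concrete, a worked example with assumed numbers (the exact capping rules live in the elided code above, so treat this as an illustration):

    # Chunked prefill, chunk size 4, prompt of 10 tokens, second chunk:
    #   context_len  = 4              # tokens already computed
    #   token_len    = 4              # input tokens fed this step
    #   query_len    = 4              # tokens queried this step
    #   orig_seq_len = 4 + 4 = 8      # context plus this chunk
    #   curr_seq_len = 8              # no sliding-window cap applied here
    #
    # Decode step, sliding window of 8, sequence already 20 tokens long:
    #   context_len  = 19, token_len = query_len = 1
    #   orig_seq_len = 20
    #   curr_seq_len = min(20, 8) = 8 # capped to the sliding window
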
@@ -404,8 +416,8 @@ def build(self) -> ModelInputForGPU:
 
         # Attention metadata.
         attn_metadata = self.attn_metadata_builder.build(
-            self.runner, self.seq_lens, self.query_lens, use_captured_graph,
-            cuda_graph_pad_size, batch_size)
+            self.runner, self.seq_lens, self.query_lens, cuda_graph_pad_size,
+            batch_size)
 
         # LoRA data.
         if self.enable_lora:
@@ -649,7 +661,8 @@ def _prepare_model_input_tensors(
         If cuda graph is required, this API automatically pads inputs.
         """
-        builder = ModelInputForGPUBuilder(self, finished_requests_ids)
+        builder = ModelInputForGPUBuilder(weakref.proxy(self),
+                                          finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)
         return builder.build()  # type: ignore
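
Passing weakref.proxy(self) instead of self means the short-lived builder does not hold a strong reference to the long-lived runner, so no runner-builder reference cycle is created and the builder can be reclaimed by plain reference counting. A minimal sketch of the pattern:

    import weakref

    class Runner:
        def make_builder(self):
            # The builder receives a proxy, not a strong reference, so it
            # never keeps the Runner alive on its own.
            return Builder(weakref.proxy(self))

    class Builder:
        def __init__(self, runner):
            # Attribute access passes through to the live Runner; if the
            # Runner has been collected, access raises ReferenceError
            # instead of silently extending its lifetime.
            self.runner = runner
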
