[Core] Modulize prepare input and attention metadata builder (#6596)
comaniac authored Jul 23, 2024
1 parent bdf5fd1 commit e0c1575
Showing 6 changed files with 409 additions and 298 deletions.
20 changes: 3 additions & 17 deletions vllm/attention/backends/abstract.py
@@ -7,7 +7,6 @@
import torch

if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase


@@ -128,25 +127,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""

@abstractmethod
def __init__(self, input_builder) -> None:
def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
raise NotImplementedError

@abstractmethod
def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise NotImplementedError

@abstractmethod
def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int) -> T:
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError

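Taken together, the new abstract interface hands each backend builder a reference to the input builder at construction time and expects it to pull per-sequence state from there, instead of receiving the runner and a long list of per-sequence arguments on every call. A minimal sketch of a conforming builder follows; DummyMetadataBuilder and its fields are illustrative only and are not part of this commit.

from typing import List

from vllm.attention.backends.abstract import AttentionMetadataBuilder


class DummyMetadataBuilder(AttentionMetadataBuilder):
    """Illustrative builder following the new interface (not in vLLM)."""

    def __init__(self, input_builder) -> None:
        # Keep handles to the input builder and its runner here, instead
        # of receiving the runner later as an argument to build().
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.seq_ids: List[int] = []

    def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
        # Consume one InterDataForSeqGroup collected by the input builder.
        self.seq_ids.extend(inter_data.seq_ids)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        # First fold in every sequence group the input builder prepared,
        # then assemble the backend-specific metadata (omitted here).
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)
        ...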
43 changes: 22 additions & 21 deletions vllm/attention/backends/flash_attn.py
@@ -13,12 +13,10 @@
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
from vllm.worker.model_runner import ModelInputForGPUBuilder


class FlashAttentionBackend(AttentionBackend):
@@ -212,30 +210,30 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)

if is_prompt:
@@ -254,7 +252,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
if inter_data.prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
@@ -270,16 +268,19 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
self.block_size, inter_data.block_tables)

def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors."""
device = runner.device
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)

device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1

logits_soft_cap = getattr(runner.model_config.hf_config,
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
@@ -300,7 +301,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
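The tail of this builder copies variable-length block tables into the runner's pre-allocated graph_block_tables so that CUDA-graph capture always sees a fixed shape. A standalone sketch of that padding step, with NumPy standing in for the runner's array and made-up block ids:

import numpy as np

# Stand-in for runner.graph_block_tables: [max batch size, max blocks/seq].
graph_block_tables = np.zeros((4, 8), dtype=np.int32)

# Per-request block ids (illustrative); an empty list means no blocks yet.
block_tables = [[7, 3], [], [5, 9, 11]]
batch_size = len(block_tables)

# Copy each table into a row of the fixed-shape array, leaving the rest
# of the row as padding, as the builder does before CUDA-graph capture.
input_block_tables = graph_block_tables[:batch_size]
for i, block_table in enumerate(block_tables):
    if block_table:
        input_block_tables[i, :len(block_table)] = block_table

print(input_block_tables)  # rows padded beyond each table's length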
60 changes: 31 additions & 29 deletions vllm/attention/backends/flashinfer.py
@@ -21,12 +21,10 @@
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
from vllm.worker.model_runner import ModelInputForGPUBuilder


class FlashInferBackend(AttentionBackend):
@@ -216,6 +214,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.input_builder = input_builder
self.runner = input_builder.runner

self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
@@ -238,26 +239,24 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
# paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = []

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
@@ -275,7 +274,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
if inter_data.prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
@@ -290,8 +289,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
self.block_size, inter_data.block_tables)

# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
@@ -317,9 +315,13 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
last_page_len = self.block_size
self.paged_kv_last_page_len.append(last_page_len)

def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
device = runner.device
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)

device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1

max_query_len = max(query_lens)
@@ -333,7 +335,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
@@ -377,7 +379,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
dtype=torch.long,
device=device)

logits_soft_cap = getattr(runner.model_config.hf_config,
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)

if len(self.paged_kv_indptr) > 0:
Expand All @@ -394,8 +396,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None

kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype,
runner.model_config.dtype)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
return FlashInferMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
@@ -406,11 +408,11 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=runner.model_config.get_num_attention_heads(
runner.parallel_config),
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_dim=runner.model_config.get_head_size(),
num_qo_heads=self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config),
num_kv_heads=self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config),
head_dim=self.runner.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=seq_start_loc,
query_start_loc=query_start_loc,
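The paged-KV bookkeeping above reduces to block arithmetic: a request of seq_len tokens occupies ceil(seq_len / block_size) pages, and its last page holds seq_len % block_size tokens unless that remainder is zero, in which case the last page is full. A standalone sketch of that computation (function name and example values are illustrative):

def paged_kv_layout(seq_lens, block_size):
    """Per-request page-count prefix sums and last-page lengths, mirroring
    what the FlashInfer builder accumulates in paged_kv_indptr and
    paged_kv_last_page_len."""
    indptr = [0]
    last_page_lens = []
    for seq_len in seq_lens:
        num_pages = (seq_len + block_size - 1) // block_size  # ceil division
        indptr.append(indptr[-1] + num_pages)
        last_page_len = seq_len % block_size
        if last_page_len == 0:
            last_page_len = block_size  # a full final page, not an empty one
        last_page_lens.append(last_page_len)
    return indptr, last_page_lens


# block_size=16, requests of 33 and 32 tokens -> ([0, 3, 5], [1, 16]).
print(paged_kv_layout([33, 32], block_size=16))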
49 changes: 25 additions & 24 deletions vllm/attention/backends/utils.py
@@ -4,7 +4,6 @@
import torch

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

# Error string(s) for encoder/decoder
@@ -15,8 +14,7 @@
PAD_SLOT_ID = -1

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
from vllm.worker.model_runner import ModelInputForGPUBuilder


def is_block_tables_empty(block_tables: Union[None, Dict]):
@@ -95,26 +93,27 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.input_builder = input_builder
self.runner = input_builder.runner

self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int], prefix_cache_hit,
chunked_prefill_enabled):
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
@@ -132,7 +131,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
if inter_data.prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
@@ -146,16 +145,18 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
self.block_size, inter_data.block_tables)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)

def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int):
device = runner.device
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1

logits_soft_cap = getattr(runner.model_config.hf_config,
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
@@ -176,7 +177,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
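This shared implementation mirrors the flash-attn builder almost line for line; the intent is that simpler backends reuse it instead of re-implementing _add_seq_group and build. Assuming this class is the CommonMetadataBuilder helper and that it exposes a _metadata_cls hook naming the concrete metadata type (as later vLLM versions do; both assumptions, not confirmed by this diff), a backend-specific builder can be as small as:

from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.xformers import XFormersMetadata


class XFormersLikeMetadataBuilder(CommonMetadataBuilder):
    # _add_seq_group and build are inherited; the subclass only names the
    # metadata class the shared build() should instantiate.
    _metadata_cls = XFormersMetadata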
5 changes: 5 additions & 0 deletions vllm/utils.py
@@ -719,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
return dict(merged_dict)


def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]


def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
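flatten_2d_lists is a single-level flatten via a list comprehension; a quick usage check:

from vllm.utils import flatten_2d_lists

# Flattens one level only and preserves inner ordering.
assert flatten_2d_lists([[1, 2], [], [3]]) == [1, 2, 3]
assert flatten_2d_lists([]) == []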
(The diff for the sixth changed file was not rendered on this page.)