Implement MetadataBuilder for XFormers
Signed-off-by: sasha0552 <[email protected]>
sasha0552 authored Oct 30, 2024
1 parent c787f2d commit d82d162
Showing 1 changed file with 171 additions and 6 deletions.
177 changes: 171 additions & 6 deletions vllm/attention/backends/xformers.py
@@ -1,6 +1,6 @@
"""Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from xformers import ops as xops
@@ -10,12 +10,20 @@
LowerTriangularMaskWithTensorBias)

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import (CommonAttentionState,
-                                           CommonMetadataBuilder)
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder,
+                                              AttentionType)
+from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
+                                           compute_slot_mapping,
+                                           compute_slot_mapping_start_idx,
+                                           is_block_tables_empty)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger
+from vllm.utils import async_tensor_h2d, make_tensor_with_pad

+if TYPE_CHECKING:
+    from vllm.worker.model_runner import ModelInputForGPUBuilder

logger = init_logger(__name__)

@@ -384,9 +392,166 @@ def _get_seq_len_block_table_args(
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


-class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
-    _metadata_cls = XFormersMetadata
+class XFormersMetadataBuilder(AttentionMetadataBuilder[XFormersMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For xformers, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        prefix_cache_hit = any([
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list
        ])
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )


class XFormersImpl(AttentionImpl[XFormersMetadata]):
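
The slot mapping that _add_seq_group builds above (through compute_slot_mapping) maps every token position of a sequence to a physical slot in the paged KV cache. The following self-contained sketch, using a toy block table rather than the vLLM helper itself, shows the underlying arithmetic:

# Sketch only, not the vLLM helper: slot = block_number * block_size + offset.
block_size = 16
block_table = [7, 3]          # toy mapping: logical block i -> physical block
context_len, seq_len = 0, 20  # a 20-token prefill with no cached prefix

slot_mapping = []
for pos in range(context_len, seq_len):
    block_number = block_table[pos // block_size]
    offset = pos % block_size
    slot_mapping.append(block_number * block_size + offset)

print(slot_mapping[:4], slot_mapping[16:])
# [112, 113, 114, 115] [48, 49, 50, 51]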

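build also materializes query_start_loc and seq_start_loc: cumulative-sum offsets over the query and sequence lengths that mark where each sequence starts inside the flattened token batch. A minimal standalone illustration of the same cumsum pattern (toy lengths, not vLLM code):

# One prefill chunk of 4 tokens followed by two single-token decodes.
import torch

query_lens = torch.tensor([4, 1, 1], dtype=torch.int32)
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
torch.cumsum(query_lens, dim=0, dtype=query_start_loc.dtype,
             out=query_start_loc[1:])
print(query_start_loc)  # tensor([0, 4, 5, 6], dtype=torch.int32)
# Tokens of sequence i occupy positions query_start_loc[i]:query_start_loc[i+1].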

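When cuda_graph_pad_size is not -1, build pads the slot mapping with PAD_SLOT_ID and copies the per-sequence block tables into the runner's preallocated graph_block_tables buffer, so the captured CUDA graph always sees tensors of a fixed shape. A standalone sketch of that copy-into-fixed-buffer step, with made-up sizes:

# Made-up sizes; mirrors the copy into graph_block_tables in build() above.
import numpy as np
import torch

batch_size, max_blocks_per_seq = 4, 8
graph_block_tables = np.zeros((batch_size, max_blocks_per_seq), dtype=np.int32)
block_tables = [[7, 3], [5], [], []]  # per-sequence block tables, padded batch

input_block_tables = graph_block_tables[:batch_size]
for i, block_table in enumerate(block_tables):
    if block_table:
        input_block_tables[i, :len(block_table)] = block_table
block_tables_tensor = torch.from_numpy(input_block_tables)
print(block_tables_tensor.shape)  # torch.Size([4, 8])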