[Misc] Add chunked-prefill support on FlashInfer. (vllm-project#9781)
elfiegg authored Oct 30, 2024
1 parent 81f09cf commit 9ff4511
Showing 2 changed files with 72 additions and 28 deletions.
12 changes: 12 additions & 0 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -11,6 +11,8 @@

import pytest

from tests.kernels.utils import override_backend_env_variable

from ..models.utils import check_logprobs_close, check_outputs_equal
from ..utils import multi_gpu_test

@@ -28,6 +30,7 @@
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
def test_models(
hf_runner,
vllm_runner,
@@ -38,11 +41,15 @@ def test_models(
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
attention_backend: str,
monkeypatch,
) -> None:
"""
Checks exact match decode between huggingface model and vllm runner with
chunked prefill.
"""
override_backend_env_variable(monkeypatch, attention_backend)

max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size

@@ -71,13 +78,18 @@
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
def test_models_distributed(
hf_runner,
vllm_runner,
example_prompts,
model: str,
distributed_executor_backend: str,
attention_backend: str,
monkeypatch,
) -> None:
override_backend_env_variable(monkeypatch, attention_backend)

if (model == "meta-llama/Llama-2-7b-hf"
and distributed_executor_backend == "ray"):
# test ray adag
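Note: the test changes above rely on override_backend_env_variable to route the parametrized backend name into vLLM's backend selection. A minimal sketch of that pattern, assuming the helper simply sets the VLLM_ATTENTION_BACKEND environment variable that vLLM consults when picking an attention backend (the real helper lives in tests/kernels/utils.py):

import pytest


def override_backend_env_variable(monkeypatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    # Sketch only: each parametrized run ("FLASHINFER" or "FLASH_ATTN") points
    # vLLM at the requested attention backend for the duration of the test;
    # monkeypatch restores the environment afterwards.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)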
88 changes: 60 additions & 28 deletions vllm/attention/backends/flashinfer.py
@@ -268,6 +268,11 @@ class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = 1

use_cuda_graph: bool = True

@@ -335,6 +340,7 @@ def begin_forward(self):
assert self.paged_kv_last_page_len is not None
assert self.block_table_bound is not None
assert self.seq_lens_tensor is not None
self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0
# We will use flash attention for profiling to
@@ -349,11 +355,13 @@ def begin_forward(self):
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.query_start_loc,
self.paged_kv_indptr[:self.num_prefills + 1],
self.paged_kv_indices,
self.paged_kv_last_page_len[:self.num_prefills],
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if self.num_decode_tokens > 0:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
@@ -370,9 +378,9 @@ def begin_forward(self):
assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indptr[self.num_prefills:],
self.paged_kv_indices,
self.paged_kv_last_page_len,
self.paged_kv_last_page_len[self.num_prefills:],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
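The begin_forward slicing above carves the shared CSR-style page-table arrays into a prefill view and a decode view by request index, since requests are laid out prefill-first. A standalone sketch of that slicing, with made-up tensor values for illustration:

import torch


def split_paged_kv_indptr(paged_kv_indptr: torch.Tensor, num_prefills: int):
    # Entries i and i+1 of the indptr bracket request i's slice of
    # paged_kv_indices, so the first num_prefills + 1 entries describe the
    # prefill requests and the remaining entries describe the decode requests.
    prefill_indptr = paged_kv_indptr[:num_prefills + 1]
    decode_indptr = paged_kv_indptr[num_prefills:]
    return prefill_indptr, decode_indptr


# Example: 2 prefill requests using 3 and 2 pages, then 2 decode requests
# using 1 page each.
indptr = torch.tensor([0, 3, 5, 6, 7], dtype=torch.int32)
prefill_indptr, decode_indptr = split_paged_kv_indptr(indptr, num_prefills=2)
# prefill_indptr -> tensor([0, 3, 5]); decode_indptr -> tensor([5, 6, 7])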
@@ -397,21 +405,14 @@ def asdict_zerocopy(self,

@property
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self

return None
if self.num_prefills == 0:
return None
return self

@property
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if self.num_decode_tokens == 0:
return None

return self

def advance_step(self,
@@ -599,11 +600,12 @@ def build(self, seq_lens: List[int], query_lens: List[int],

max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
decode_query_len = max(query_lens[self.num_prefills:], default=1)

if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size
num_decode_tokens = batch_size - self.num_prefill_tokens

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
@@ -689,6 +691,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
self.runner.kv_cache_dtype, self.runner.model_config.dtype)

return FlashInferMetadata(
decode_query_len=decode_query_len,
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
@@ -811,12 +814,6 @@ def unified_flash_infer(
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)

if attn_metadata.num_prefill_tokens > 0:
assert attn_metadata.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
@@ -836,14 +833,33 @@
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)

num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous() # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]

key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]

assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens

prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
output = flash_attn_varlen_func(
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
@@ -859,18 +875,34 @@
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(
prefill_output = prefill_meta.prefill_wrapper.forward(
query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
else:
if decode_meta := attn_metadata.decode_metadata:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
decode_output = attn_metadata.decode_metadata.decode_wrapper.forward(
decode_query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale)

if prefill_output is None and decode_output is not None:
# Decode only batch.
output, num_tokens = decode_output, num_decode_tokens
elif decode_output is None and prefill_output is not None:
# Prefill only batch.
output, num_tokens = prefill_output, num_prefill_tokens
else:
# Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1.
assert prefill_output is not None
assert decode_output is not None
assert decode_meta is not None
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)


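To summarize the new control flow in unified_flash_infer: a mixed batch is laid out with prefill tokens first, each slice is handled by its own wrapper, and the two outputs are concatenated back in order. A self-contained sketch of that pattern, where run_prefill and run_decode are placeholders standing in for the FlashInfer prefill and decode wrapper calls:

from typing import Callable, Optional

import torch


def split_and_merge_attention(
    query: torch.Tensor,  # [num_prefill_tokens + num_decode_tokens, ...]
    num_prefill_tokens: int,
    num_decode_tokens: int,
    run_prefill: Callable[[torch.Tensor], torch.Tensor],
    run_decode: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    assert num_prefill_tokens + num_decode_tokens > 0
    assert query.shape[0] == num_prefill_tokens + num_decode_tokens

    # Tokens are ordered prefill-first, so plain slicing separates the phases.
    prefill_query = query[:num_prefill_tokens]
    decode_query = query[num_prefill_tokens:]

    prefill_output: Optional[torch.Tensor] = None
    decode_output: Optional[torch.Tensor] = None
    if num_prefill_tokens > 0:
        prefill_output = run_prefill(prefill_query)
    if num_decode_tokens > 0:
        decode_output = run_decode(decode_query)

    if decode_output is None:
        return prefill_output  # prefill-only batch
    if prefill_output is None:
        return decode_output   # decode-only batch
    # Chunked-prefill batch: keep the original token order.
    return torch.cat([prefill_output, decode_output], dim=0)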
