Commit
Add multi-step chunked-prefill support where prefill turns into decode after the first single step.
elfiegg committed Nov 27, 2024
1 parent b98c62b commit ad48534
Showing 3 changed files with 64 additions and 8 deletions.
19 changes: 17 additions & 2 deletions tests/multi_step/test_correctness_llm.py
@@ -5,6 +5,8 @@

import pytest

from tests.kernels.utils import override_backend_env_variable

from ..models.utils import check_logprobs_close, check_outputs_equal

MODELS = [
@@ -19,10 +21,11 @@
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_multi_step_llm(
hf_runner,
vllm_runner,
Expand All @@ -36,6 +39,8 @@ def test_multi_step_llm(
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
@@ -63,6 +68,7 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned.
"""
override_backend_env_variable(monkeypatch, attention_backend)

prompts = example_prompts
if len(prompts) < num_prompts:
@@ -110,10 +116,11 @@ def test_multi_step_llm(
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_multi_step_llm_w_prompt_logprobs(
vllm_runner,
example_prompts,
@@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
num_prompts: int,
num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
@@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the
OpenAI completions endpoint.
"""
override_backend_env_variable(monkeypatch, attention_backend)

prompts = example_prompts
if len(prompts) < num_prompts:
@@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs(
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
def test_multi_step_llm_chunked_prefill_prefix_cache(
vllm_runner,
example_prompts,
@@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
@@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
#
# The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`.
override_backend_env_variable(monkeypatch, attention_backend)

assert len(example_prompts) >= 2
challenge_prompts = copy.deepcopy(example_prompts)
challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
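The new attention_backend parametrization relies on override_backend_env_variable from tests/kernels/utils.py. As a point of reference, a minimal sketch of that pattern is shown below; it assumes the backend is selected through the VLLM_ATTENTION_BACKEND environment variable, and the helper and test names here are illustrative rather than the real test-suite code.

import pytest

def override_backend_env_variable(monkeypatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    # Assumption: vLLM chooses its attention backend from this variable.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)

@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_backend_override_sketch(attention_backend, monkeypatch):
    override_backend_env_variable(monkeypatch, attention_backend)
    # ... construct the vLLM runner here; it picks up the variable at init.

Because monkeypatch restores the environment after each test, every parametrized case starts from a clean backend selection.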
51 changes: 46 additions & 5 deletions vllm/attention/backends/flashinfer.py
@@ -256,7 +256,7 @@ def prepare_graph_input_buffers(self,
def begin_forward(self, model_input):
assert not self._is_graph_capturing
state = self
if model_input.attn_metadata.use_cuda_graph:
if model_input.attn_metadata.use_cuda_graph and model_input.attn_metadata.num_prefills == 0:

Check failure on line 259 (GitHub Actions / ruff 3.12): E501 Line too long (100 > 80) at vllm/attention/backends/flashinfer.py:259:81
batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state)
@@ -429,10 +429,24 @@ def advance_step(self,
Update metadata in-place to advance one decode step.
"""

assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with flashinfer yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
# Flashinfer doesn't support speculative decoding + chunked-prefill
# + multi-step scheduling yet.
assert self.decode_query_len == 1
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1

self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens_tensor is not None

assert num_seqs > 0
assert num_queries > 0
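For readers skimming the diff, the standalone sketch below restates what the new turn_prefills_into_decodes branch does on the first multi-step iteration: every scheduled prefill is reclassified as a decode, so later steps run a pure-decode batch. The ToyAttnMetadata class is illustrative only and not vLLM's real metadata type.

from dataclasses import dataclass

@dataclass
class ToyAttnMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    max_prefill_seq_len: int
    max_query_len: int

    def turn_prefills_into_decodes(self) -> None:
        # After the first step each partially prefilled sequence emits one
        # token per step, so it is accounted for as a decode from now on.
        self.num_decode_tokens += self.num_prefills
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.max_prefill_seq_len = 0
        self.max_query_len = 1

meta = ToyAttnMetadata(num_prefills=2, num_prefill_tokens=512,
                       num_decode_tokens=3, max_prefill_seq_len=256,
                       max_query_len=256)
meta.turn_prefills_into_decodes()
assert meta.num_prefills == 0 and meta.num_decode_tokens == 5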
@@ -895,3 +909,30 @@ def forward(
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)


def unified_flash_infer_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()


direct_register_custom_op(

Check failures on line 933 (GitHub Actions / ruff 3.12 and mypy 3.9-3.12): F821 / name-defined: Name `direct_register_custom_op` is not defined (vllm/attention/backends/flashinfer.py:933:1)
op_name="unified_flash_infer",
op_func=unified_flash_infer,

Check failures on line 935 (GitHub Actions / ruff 3.12 and mypy 3.9-3.12): F821 / name-defined: Name `unified_flash_infer` is not defined (vllm/attention/backends/flashinfer.py:935:13)
mutates_args=["kv_cache"],
fake_impl=unified_flash_infer_fake,
)
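The CI annotations above report direct_register_custom_op and unified_flash_infer as undefined at the registration site, which suggests an import (and possibly the op function's definition) is missing from this commit. The sketch below shows the same registration pattern with toy functions; the vllm.utils import location is an assumption, and _toy_op/_toy_op_fake are hypothetical names used only for illustration.

import torch

from vllm.utils import direct_register_custom_op  # assumed import location


def _toy_op(x: torch.Tensor) -> torch.Tensor:
    # Real implementation: runs on actual tensors.
    return x * 2


def _toy_op_fake(x: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation: only shape and dtype matter, so that
    # torch.compile can trace through the custom op without executing it.
    return torch.empty_like(x)


direct_register_custom_op(
    op_name="toy_op",
    op_func=_toy_op,
    mutates_args=[],
    fake_impl=_toy_op_fake,
)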
2 changes: 1 addition & 1 deletion vllm/worker/multi_step_model_runner.py
@@ -30,7 +30,7 @@
logger = init_logger(__name__)

MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]

def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
-> List[str]:
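The one-line change above widens which backends are accepted when multi-step scheduling is combined with chunked prefill. A minimal sketch of how _get_supported_attention_backends plausibly uses these lists is shown below; the function body is inferred from its signature and is not part of this diff.

from typing import List

MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]


def _get_supported_attention_backends(
        chunked_prefill_enabled: bool) -> List[str]:
    if chunked_prefill_enabled:
        # After this commit, FLASHINFER is accepted here as well.
        return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
    return MULTI_STEP_ATTENTION_BACKENDS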

0 comments on commit ad48534
