[Core] Add Sliding Window Support with Flashinfer (vllm-project#10462)
Signed-off-by: Pavani Majety <[email protected]>
pavanimajety authored Nov 21, 2024
1 parent 388ee3d commit 6c1208d
Showing 2 changed files with 18 additions and 7 deletions.
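Before this change the FlashInfer backend rejected models that use sliding-window attention; with it, such models can run on FlashInfer. A minimal usage sketch, not part of this commit: the model name, its 4096-token window, and local availability are assumptions for illustration.

# Hypothetical usage sketch: serve a sliding-window model on the FlashInfer backend.
# Assumes vLLM selects its attention backend via the VLLM_ATTENTION_BACKEND
# environment variable and that the Mistral checkpoint is available locally.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1")  # uses a 4096-token sliding window
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)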
12 changes: 10 additions & 2 deletions tests/core/block/e2e/test_correctness_sliding_window.py
@@ -3,6 +3,7 @@

 import pytest

+from tests.kernels.utils import override_backend_env_variable
 from vllm import LLM, SamplingParams

 from .conftest import get_text_from_llm_generator
@@ -28,8 +29,9 @@
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
 @pytest.mark.parametrize("batch_size", [5])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
 def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
-                                 batch_size, seed):
+                                 batch_size, seed, backend, monkeypatch):
     """
     The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
     asks for value of one of them (which is outside the sliding window).
@@ -38,6 +40,8 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,

     Additionally, we compare the results of the v1 and v2 managers.
     """
+    override_backend_env_variable(monkeypatch, backend)
+
     sampling_params = SamplingParams(
         max_tokens=1024,
         ignore_eos=True,
@@ -84,7 +88,9 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
 @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}])
 @pytest.mark.parametrize("batch_size", [5])
 @pytest.mark.parametrize("seed", [1])
-def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
+                                        backend, monkeypatch):
     """
     This is similar to test_sliding_window_retrival, however, it doesn't
     compare against the v1 block manager since v1 doesn't support
@@ -93,6 +99,8 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
     The results with and without chunked prefill are not the same due to
     numerical instabilities.
     """
+    override_backend_env_variable(monkeypatch, backend)
+
     sampling_params = SamplingParams(
         max_tokens=10,
         ignore_eos=True,
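Both tests are now parametrized over the FLASH_ATTN, FLASHINFER, and XFORMERS backends and force the chosen backend through override_backend_env_variable (imported from tests/kernels/utils.py). A minimal sketch of what such a helper can look like, assuming backend selection is driven by the VLLM_ATTENTION_BACKEND environment variable; the real helper's details may differ.

# Sketch of a backend-override helper for pytest, assuming the backend is
# chosen from the VLLM_ATTENTION_BACKEND environment variable.
import pytest


def override_backend_env_variable(monkeypatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    # monkeypatch.setenv is undone automatically when the test finishes,
    # so each parametrized run sees only its own backend.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)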
13 changes: 8 additions & 5 deletions vllm/attention/backends/flashinfer.py
@@ -757,9 +757,8 @@ def __init__(
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
-        if sliding_window is not None:
-            raise ValueError("Sliding window is not supported in FlashInfer.")
-        self.sliding_window = (-1, -1)
+        self.sliding_window = ((sliding_window - 1,
+                                0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap

@@ -865,6 +864,8 @@ def unified_flash_infer(
     assert query.shape[0] == num_prefill_tokens
     assert decode_query.shape[0] == num_decode_tokens

+    window_left = window_size[0] if window_size is not None else -1
+
     prefill_output: Optional[torch.Tensor] = None
     decode_output: Optional[torch.Tensor] = None
     if prefill_meta := attn_metadata.prefill_metadata:
@@ -895,7 +896,8 @@
                 logits_soft_cap=logits_soft_cap,
                 causal=True,
                 k_scale=k_scale,
-                v_scale=v_scale)
+                v_scale=v_scale,
+                window_left=window_left)
     if decode_meta := attn_metadata.decode_metadata:
         assert attn_metadata.decode_metadata is not None
         assert attn_metadata.decode_metadata.decode_wrapper is not None
@@ -905,7 +907,8 @@
             sm_scale=softmax_scale,
             logits_soft_cap=logits_soft_cap,
             k_scale=k_scale,
-            v_scale=v_scale)
+            v_scale=v_scale,
+            window_left=window_left)

     if prefill_output is None and decode_output is not None:
         # Decode only batch.
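The backend now stores the window as a (window_left, window_right) pair, with -1 meaning unbounded on that side, and unified_flash_infer forwards only the left extent to the FlashInfer prefill and decode wrappers as window_left. A short sketch of that convention, with an illustrative 4096-token window; the helper name below is made up for the example.

# Sketch of the (window_left, window_right) convention used in this diff:
# a sliding window of N tokens lets a query attend to itself plus the
# previous N - 1 tokens; -1 on a side means "unbounded".
from typing import Optional, Tuple


def to_window_pair(sliding_window: Optional[int]) -> Tuple[int, int]:
    return ((sliding_window - 1, 0) if sliding_window is not None else (-1, -1))


assert to_window_pair(4096) == (4095, 0)  # current token + 4095 previous tokens
assert to_window_pair(None) == (-1, -1)   # no window: full causal attention

# unified_flash_infer then passes only the left extent on to FlashInfer:
window_size = to_window_pair(4096)
window_left = window_size[0] if window_size is not None else -1  # 4095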
