From 6c1208d083fbaaf89c6d812f4d3424e15182f652 Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Wed, 20 Nov 2024 19:56:47 -0800 Subject: [PATCH] [Core] Add Sliding Window Support with Flashinfer (#10462) Signed-off-by: Pavani Majety --- .../block/e2e/test_correctness_sliding_window.py | 12 ++++++++++-- vllm/attention/backends/flashinfer.py | 13 ++++++++----- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index 9320a9ef62314..415d0bd8237df 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -3,6 +3,7 @@ import pytest +from tests.kernels.utils import override_backend_env_variable from vllm import LLM, SamplingParams from .conftest import get_text_from_llm_generator @@ -28,8 +29,9 @@ @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, - batch_size, seed): + batch_size, seed, backend, monkeypatch): """ The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then asks for value of one of them (which is outside the sliding window). @@ -38,6 +40,8 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, Additionally, we compare the results of the v1 and v2 managers. """ + override_backend_env_variable(monkeypatch, backend) + sampling_params = SamplingParams( max_tokens=1024, ignore_eos=True, @@ -84,7 +88,9 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) -def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, + backend, monkeypatch): """ This is similar to test_sliding_window_retrival, however, it doesn't compare against the v1 block manager since v1 doesn't support @@ -93,6 +99,8 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): The results with and without chunked prefill are not the same due to numerical instabilities. """ + override_backend_env_variable(monkeypatch, backend) + sampling_params = SamplingParams( max_tokens=10, ignore_eos=True, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 107e3bbf79666..b61c660e3e280 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -757,9 +757,8 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - if sliding_window is not None: - raise ValueError("Sliding window is not supported in FlashInfer.") - self.sliding_window = (-1, -1) + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap @@ -865,6 +864,8 @@ def unified_flash_infer( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens + window_left = window_size[0] if window_size is not None else -1 + prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None if prefill_meta := attn_metadata.prefill_metadata: @@ -895,7 +896,8 @@ def unified_flash_infer( logits_soft_cap=logits_soft_cap, causal=True, k_scale=k_scale, - v_scale=v_scale) + v_scale=v_scale, + window_left=window_left) if decode_meta := attn_metadata.decode_metadata: assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata.decode_wrapper is not None @@ -905,7 +907,8 @@ def unified_flash_infer( sm_scale=softmax_scale, logits_soft_cap=logits_soft_cap, k_scale=k_scale, - v_scale=v_scale) + v_scale=v_scale, + window_left=window_left) if prefill_output is None and decode_output is not None: # Decode only batch.