diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 366b030eaa399..fd6564bbfe630 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -5,6 +5,7 @@ import pytest from tests.kernels.utils import override_backend_env_variable +from vllm import SamplingParams, TokensPrompt from ..models.utils import check_outputs_equal @@ -12,6 +13,14 @@ "facebook/opt-125m", ] +UNSTABLE_PROMPT_SEQUENCE = [ + ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1), + ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50), + ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95), + ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174), + ([0] * 588) + ([8] * 1539), +] + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) @@ -57,3 +66,22 @@ def test_mixed_requests( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +def test_unstable_prompt_sequence( + vllm_runner, + backend: str, + monkeypatch, +) -> None: + override_backend_env_variable(monkeypatch, backend) + + with vllm_runner( + "Qwen/Qwen2.5-0.5B-Instruct", + enable_chunked_prefill=True, + enable_prefix_caching=True, + max_model_len=4096, + ) as vllm_model: + for prompt in UNSTABLE_PROMPT_SEQUENCE: + vllm_model.generate(TokensPrompt(prompt_token_ids=prompt), + SamplingParams(max_tokens=1)) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index d1a44f3e8bfa6..32fccd0dfb496 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -138,7 +138,6 @@ def _add_seq_group( chunked_prefill_enabled: bool): is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables - computed_block_nums = inter_data.computed_block_nums for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( @@ -164,10 +163,14 @@ def _add_seq_group( # NOTE: This only works for oooooooxxx style attention. block_table = [] if inter_data.prefix_cache_hit: - block_table = computed_block_nums + block_table = block_tables[seq_id] elif ((chunked_prefill_enabled or not is_prompt) and block_tables is not None): - block_table = block_tables[seq_id][-curr_sliding_window_block:] + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] self.block_tables.append(block_table) # Compute slot mapping.