diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py
index 7985001d34eb1..9821dbd066a59 100644
--- a/tests/prefix_caching/test_prefix_caching.py
+++ b/tests/prefix_caching/test_prefix_caching.py
@@ -6,10 +6,17 @@
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
 from vllm.block import PhysicalTokenBlock
 from vllm.core.block_manager_v1 import CachedBlockAllocator
 from vllm.utils import Device
 
+from ..models.utils import check_outputs_equal
+
+MODELS = [
+    "facebook/opt-125m",
+]
+
 
 @pytest.mark.parametrize("block_size", [16])
 @pytest.mark.parametrize("num_blocks", [16])
@@ -76,3 +83,52 @@ def test_eviction(num_blocks: int, ):
     assert (realloc_block != new_block)
     assert (new_block.block_hash == new_block_hash)
     assert (new_block.block_number == 2)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("cached_position", [0, 1])
+@pytest.mark.parametrize("use_v2_block_manager", [False, True])
+def test_mixed_requests(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    backend: str,
+    dtype: str,
+    max_tokens: int,
+    cached_position: int,
+    use_v2_block_manager: bool,
+    monkeypatch,
+) -> None:
+    """
+    Test the case when some sequences have a prefix cache hit and
+    others don't. The cached position determines where the cached
+    sequence sits in the batch of prefills.
+    """
+    override_backend_env_variable(monkeypatch, backend)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    cached_prompt = example_prompts[cached_position]
+    with vllm_runner(
+        model,
+        dtype=dtype,
+        enable_prefix_caching=True,
+        use_v2_block_manager=use_v2_block_manager,
+    ) as vllm_model:
+        # Run the first prompt so the cache is populated
+        vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
+
+        # Run all the prompts
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 7d7aff9dc3cdc..58100d6db2ae6 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -209,6 +209,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
+        self.has_prefix_cache_hit = False
 
         self.input_builder = input_builder
         self.runner = input_builder.runner
@@ -219,7 +220,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
 
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
-            chunked_prefill_enabled: bool):
+            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
         """Add a sequence group to the metadata. Specifically update/append
         1. context length.
         2. block table.
@@ -252,7 +253,7 @@ def _add_seq_group(
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
-            if inter_data.prefix_cache_hit:
+            if prefix_cache_hit:
                 # NOTE(woosuk): For flash-attn, the block table should
                 # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
@@ -281,9 +282,14 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                                 -1 if cuda graph is not used.
             batch_size: The maybe padded batch size.
         """
+        prefix_cache_hit = any([
+            inter_data.prefix_cache_hit
+            for inter_data in self.input_builder.inter_data_list
+        ])
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
-                                self.input_builder.chunked_prefill_enabled)
+                                self.input_builder.chunked_prefill_enabled,
+                                prefix_cache_hit)
 
         device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1
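In the build() change above, prefix_cache_hit is computed once for the whole batch: if any sequence group hit the prefix cache, every group is added through the cache-hit path so that all block tables include entries for the incoming prefill tokens. A minimal sketch of that aggregation, using a hypothetical simplified InterData stand-in rather than the real ModelInputForGPUBuilder.InterDataForSeqGroup:

from dataclasses import dataclass
from typing import List


@dataclass
class InterData:
    # Hypothetical, simplified stand-in for
    # ModelInputForGPUBuilder.InterDataForSeqGroup: only the flag used here.
    prefix_cache_hit: bool


def batch_has_prefix_cache_hit(inter_data_list: List[InterData]) -> bool:
    # Mirrors the any(...) scan added in build(): one cached sequence group
    # switches the whole batch of prefills onto the cache-hit block-table path.
    return any(inter_data.prefix_cache_hit for inter_data in inter_data_list)


# A mixed batch like the one exercised by test_mixed_requests: one prefill
# with a prefix cache hit and one without.
mixed_batch = [InterData(prefix_cache_hit=True),
               InterData(prefix_cache_hit=False)]
assert batch_has_prefix_cache_hit(mixed_batch)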