Fix cases where non-cached are after cached
zachzzc committed Aug 2, 2024
1 parent fa0d7ee commit 10b2e04
Showing 2 changed files with 13 additions and 6 deletions.
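For orientation, the behavioral change in vllm/attention/backends/flash_attn.py below boils down to replacing a flag that was OR-accumulated while iterating over sequence groups with one computed once over the whole batch before the loop. The following is a minimal, self-contained sketch of the two patterns, not the vLLM code itself; SeqGroup and the helper names are made up for illustration.

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class SeqGroup:
        # Illustrative stand-in for the builder's per-group InterDataForSeqGroup.
        prefix_cache_hit: bool

    def flags_accumulated(groups: List[SeqGroup]) -> List[bool]:
        # Old pattern: the flag is OR-ed in as each group is added, so a group's
        # treatment depends on whether a cache-hit group happened to come earlier.
        seen_hit = False
        flags = []
        for group in groups:
            seen_hit |= group.prefix_cache_hit
            flags.append(seen_hit)
        return flags

    def flags_batchwide(groups: List[SeqGroup]) -> List[bool]:
        # New pattern: decide once for the whole batch, then apply it uniformly.
        hit = any(group.prefix_cache_hit for group in groups)
        return [hit] * len(groups)

    mixed = [SeqGroup(False), SeqGroup(True), SeqGroup(False)]
    print(flags_accumulated(mixed))  # [False, True, True]  -> position-dependent
    print(flags_batchwide(mixed))    # [True, True, True]   -> uniform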
7 changes: 5 additions & 2 deletions tests/prefix_caching/test_prefix_caching.py
@@ -89,6 +89,7 @@ def test_eviction(num_blocks: int, ):
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
hf_runner,
Expand All @@ -98,19 +99,21 @@ def test_mixed_requests(
backend: str,
dtype: str,
max_tokens: int,
+cached_position: int,
use_v2_block_manager: bool,
monkeypatch,
) -> None:
"""
Test the case when some sequences have the prefix cache hit
-and the others don't.
+and the others don't. The cached position determines where
+the sequence is at among the batch of prefills.
"""
override_backend_env_variable(monkeypatch, backend)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

-cached_prompt = example_prompts[0]
+cached_prompt = example_prompts[cached_position]
with vllm_runner(
model,
dtype=dtype,
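The test diff is truncated above, but the intent of the new cached_position parameter is visible: the prompt whose prefix is expected to be warm in the cache is chosen by position rather than always being example_prompts[0], so mixed prefill batches are exercised both with the cache-hit request first and with a cache-miss request first. A hypothetical illustration of the orderings this covers (the prompts and the helper are made up, not part of the test):

    from typing import List, Tuple

    example_prompts = ["prompt A", "prompt B", "prompt C"]

    def batch_with_warm_prefix(cached_position: int) -> List[Tuple[str, bool]]:
        # Pair each prompt with whether its prefix should hit the cache,
        # assuming only the prompt at `cached_position` was run beforehand.
        return [(p, i == cached_position) for i, p in enumerate(example_prompts)]

    print(batch_with_warm_prefix(0))
    # [('prompt A', True), ('prompt B', False), ('prompt C', False)]
    print(batch_with_warm_prefix(1))
    # [('prompt A', False), ('prompt B', True), ('prompt C', False)]
    # With cached_position == 1, a cache-miss prefill comes before the cache-hit
    # one in the batch, which is the ordering the backend change below must handle.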
12 changes: 8 additions & 4 deletions vllm/attention/backends/flash_attn.py
@@ -220,7 +220,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
-chunked_prefill_enabled: bool):
+chunked_prefill_enabled: bool, prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
@@ -253,8 +253,7 @@ def _add_seq_group(
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
-self.has_prefix_cache_hit |= inter_data.prefix_cache_hit
-if self.has_prefix_cache_hit:
+if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
@@ -283,9 +282,14 @@ def build(self, seq_lens: List[int], query_lens: List[int],
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
+prefix_cache_hit = any([
+    inter_data.prefix_cache_hit
+    for inter_data in self.input_builder.inter_data_list
+])
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
-self.input_builder.chunked_prefill_enabled)
+self.input_builder.chunked_prefill_enabled,
+prefix_cache_hit)

device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
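The net effect of passing the batch-wide flag into _add_seq_group is that every prefill in the batch takes the same block-table branch whenever any of them has a prefix cache hit. A simplified sketch of that per-sequence decision, with made-up stand-ins for the builder's state (it ignores the chunked-prefill/sliding-window branch that the real method also has):

    from typing import Dict, List

    def prefill_block_tables(seq_ids: List[int],
                             cache_hits: List[bool],
                             block_tables: Dict[int, List[int]]) -> List[List[int]]:
        # Decide once for the whole batch, as the new build() does.
        prefix_cache_hit = any(cache_hits)
        out: List[List[int]] = []
        for seq_id in seq_ids:
            if prefix_cache_hit:
                # Mirrors the NOTE in the diff: for flash-attn, the block table
                # should include the entries for the incoming prefill tokens.
                out.append(block_tables[seq_id])
            else:
                out.append([])
        return out

    # A cache-miss prefill (seq 0) ahead of a cache-hit prefill (seq 1): both rows
    # now carry their block tables, rather than only the sequences after the hit.
    print(prefill_block_tables([0, 1], [False, True], {0: [4, 5], 1: [7, 8]}))
    # [[4, 5], [7, 8]]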
