Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix block table for seqs that have prefix cache hits #7018

Merged
merged 5 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions tests/prefix_caching/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
from vllm.core.block_manager_v1 import CachedBlockAllocator
from vllm.utils import Device

from tests.kernels.utils import override_backend_env_variable
from ..models.utils import check_outputs_equal

MODELS = [
"facebook/opt-125m",
]


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
Expand Down Expand Up @@ -76,3 +83,49 @@ def test_eviction(num_blocks: int, ):
assert (realloc_block != new_block)
assert (new_block.block_hash == new_block_hash)
assert (new_block.block_number == 2)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
hf_runner,
vllm_runner,
example_prompts,
model: str,
backend: str,
dtype: str,
max_tokens: int,
use_v2_block_manager: bool,
monkeypatch,
) -> None:
"""
Test the case when some sequences have the prefix cache hit
and the others don't.
"""
override_backend_env_variable(monkeypatch, backend)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

cached_prompt = example_prompts[0]
with vllm_runner(
model,
dtype=dtype,
enable_prefix_caching=True,
use_v2_block_manager=use_v2_block_manager,
) as vllm_model:
# Run the first prompt so the cache is populated
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)

# Run all the promopts
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
4 changes: 3 additions & 1 deletion vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False

self.input_builder = input_builder
self.runner = input_builder.runner
Expand Down Expand Up @@ -252,7 +253,8 @@ def _add_seq_group(
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if inter_data.prefix_cache_hit:
self.has_prefix_cache_hit |= inter_data.prefix_cache_hit
if self.has_prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
Expand Down
Loading