Skip to content

Commit c05b894

Browse files
sasha0552whyiug
authored andcommitted
[Bugfix] Fix illegal memory access error with chunked prefill, prefix caching, block manager v2 and xformers enabled together (vllm-project#9532)
Signed-off-by: sasha0552 <[email protected]>
1 parent ea25a5f commit c05b894

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

tests/prefix_caching/test_prefix_caching.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,21 @@
1111
from vllm.block import PhysicalTokenBlock
1212
from vllm.core.block_manager_v1 import CachedBlockAllocator
1313
from vllm.utils import Device
14-
14+
from vllm import SamplingParams, TokensPrompt
1515
from ..models.utils import check_outputs_equal
1616

1717
MODELS = [
1818
"facebook/opt-125m",
1919
]
2020

21+
UNSTABLE_PROMPT_SEQUENCE = [
22+
([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1),
23+
([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50),
24+
([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95),
25+
([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174),
26+
([0] * 588) + ([8] * 1539),
27+
]
28+
2129

2230
@pytest.fixture(scope="module", autouse=True)
2331
def check_deprecated_block_manager():
@@ -146,3 +154,22 @@ def test_mixed_requests(
146154
name_0="hf",
147155
name_1="vllm",
148156
)
157+
158+
159+
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
160+
def test_unstable_prompt_sequence(
161+
vllm_runner,
162+
backend: str,
163+
monkeypatch,
164+
) -> None:
165+
override_backend_env_variable(monkeypatch, backend)
166+
167+
with vllm_runner(
168+
"Qwen/Qwen2.5-0.5B-Instruct",
169+
enable_chunked_prefill=True,
170+
enable_prefix_caching=True,
171+
max_model_len=4096,
172+
) as vllm_model:
173+
for prompt in UNSTABLE_PROMPT_SEQUENCE:
174+
vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
175+
SamplingParams(max_tokens=1))

vllm/attention/backends/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ def _add_seq_group(
146146
chunked_prefill_enabled: bool):
147147
is_prompt = inter_data.is_prompt
148148
block_tables = inter_data.block_tables
149-
computed_block_nums = inter_data.computed_block_nums
150149

151150
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
152151
curr_sliding_window_block) in zip(
@@ -172,10 +171,14 @@ def _add_seq_group(
172171
# NOTE: This only works for oooooooxxx style attention.
173172
block_table = []
174173
if inter_data.prefix_cache_hit:
175-
block_table = computed_block_nums
174+
block_table = block_tables[seq_id]
176175
elif ((chunked_prefill_enabled or not is_prompt)
177176
and block_tables is not None):
178-
block_table = block_tables[seq_id][-curr_sliding_window_block:]
177+
if curr_sliding_window_block == 0:
178+
block_table = block_tables[seq_id]
179+
else:
180+
block_table = block_tables[seq_id][
181+
-curr_sliding_window_block:]
179182
self.block_tables.append(block_table)
180183

181184
# Compute slot mapping.

0 commit comments

Comments
 (0)