 import pytest

 from tests.kernels.utils import override_backend_env_variable
+from vllm import SamplingParams, TokensPrompt

 from ..models.utils import check_outputs_equal

 MODELS = [
     "facebook/opt-125m",
 ]

+UNSTABLE_PROMPT_SEQUENCE = [
+    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1),
+    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50),
+    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95),
+    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174),
+    ([0] * 588) + ([8] * 1539),
+]
+

 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@@ -57,3 +66,22 @@ def test_mixed_requests(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+def test_unstable_prompt_sequence(
+    vllm_runner,
+    backend: str,
+    monkeypatch,
+) -> None:
+    override_backend_env_variable(monkeypatch, backend)
+
+    with vllm_runner(
+            "Qwen/Qwen2.5-0.5B-Instruct",
+            enable_chunked_prefill=True,
+            enable_prefix_caching=True,
+            max_model_len=4096,
+    ) as vllm_model:
+        for prompt in UNSTABLE_PROMPT_SEQUENCE:
+            vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
+                                SamplingParams(max_tokens=1))
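
For reference, a minimal standalone sketch of the scenario the new test exercises, assuming the public `vllm.LLM` entry point instead of the `vllm_runner` fixture; the model name, engine flags, and prompt sequence are taken from the diff above, everything else is illustrative.

```python
# Hypothetical standalone reproduction of the test_unstable_prompt_sequence
# scenario, assuming the public vllm.LLM API (the test uses the vllm_runner
# fixture instead).
from vllm import LLM, SamplingParams, TokensPrompt

# Token prompts that share common prefixes of varying lengths, so later
# requests partially hit the prefix cache while prefill is chunked.
UNSTABLE_PROMPT_SEQUENCE = [
    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1),
    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50),
    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95),
    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174),
    ([0] * 588) + ([8] * 1539),
]

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
    max_model_len=4096,
)

for prompt in UNSTABLE_PROMPT_SEQUENCE:
    # Generate a single token per prompt; the interesting work happens
    # during the (chunked, partially cached) prefill.
    llm.generate(TokensPrompt(prompt_token_ids=prompt),
                 SamplingParams(max_tokens=1))
```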