diff --git a/tests/serve/test_vllm.py b/tests/serve/test_vllm.py index 8c6f1ca02..004f1f46f 100644 --- a/tests/serve/test_vllm.py +++ b/tests/serve/test_vllm.py @@ -1,5 +1,6 @@ import re +import pytest import torch from outlines.serve.vllm import RegexLogitsProcessor, _patched_apply_logits_processors @@ -41,7 +42,8 @@ def sample_from_logits(logits): return torch.multinomial(probs, 1).item() -def test_time_regexp(): +@pytest.mark.parametrize("forget_logits_processor", [True, False]) +def test_time_regexp(forget_logits_processor): pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?" llm = MockModel() logits_processor = RegexLogitsProcessor(pattern, llm) @@ -49,6 +51,12 @@ def test_time_regexp(): token_ids = [] while True: random_scores = -10 + 20 * torch.rand(len(llm.tokenizer.vocabulary)) + + # mock "forgetting" the logits processor behavior in + # vLLM tensor-parallel world size > 1 + if forget_logits_processor: + logits_processor = RegexLogitsProcessor(pattern, llm) + logits = logits_processor( input_ids=token_ids, scores=random_scores,