Skip to content

Commit

Permalink
fix tests s.t. they mock forgetting the logits processor
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Lapp committed Feb 5, 2024
1 parent 6b2035e commit e46aae7
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tests/serve/test_vllm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re

import pytest
import torch

from outlines.serve.vllm import RegexLogitsProcessor, _patched_apply_logits_processors
Expand Down Expand Up @@ -41,14 +42,21 @@ def sample_from_logits(logits):
return torch.multinomial(probs, 1).item()


def test_time_regexp():
@pytest.mark.parametrize("forget_logits_processor", [True, False])
def test_time_regexp(forget_logits_processor):
pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"
llm = MockModel()
logits_processor = RegexLogitsProcessor(pattern, llm)

token_ids = []
while True:
random_scores = -10 + 20 * torch.rand(len(llm.tokenizer.vocabulary))

# mock "forgetting" the logits processor behavior in
# vLLM tensor-parallel world size > 1
if forget_logits_processor:
logits_processor = RegexLogitsProcessor(pattern, llm)

logits = logits_processor(
input_ids=token_ids,
scores=random_scores,
Expand Down

0 comments on commit e46aae7

Please sign in to comment.