Skip to content

Commit

Permalink
Add test for logit processor
Browse files Browse the repository at this point in the history
  • Loading branch information
mory91 committed Jan 26, 2024
1 parent fdf7397 commit 6b96f38
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
7 changes: 3 additions & 4 deletions outlines/serve/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Callable, DefaultDict, List

import torch
from vllm import LLMEngine

from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
Expand Down Expand Up @@ -113,7 +112,7 @@ def __call__(


class RegexLogitsProcessor(FSMLogitsProcessor):
def __init__(self, regex_string, llm: LLMEngine):
def __init__(self, regex_string, llm):
"""Compile the FSM that drives the regex-guided generation.
Parameters
Expand All @@ -130,7 +129,7 @@ def __init__(self, regex_string, llm: LLMEngine):


class CFGLogitsProcessor(FSMLogitsProcessor):
def __init__(self, cfg_string, llm: LLMEngine):
def __init__(self, cfg_string, llm):
"""Compile the FSM that drives the cfg-guided generation.
Parameters
Expand All @@ -147,7 +146,7 @@ def __init__(self, cfg_string, llm: LLMEngine):


class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema, llm: LLMEngine):
def __init__(self, schema, llm):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
Expand Down
44 changes: 44 additions & 0 deletions tests/test_vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
import torch
from transformers import AutoTokenizer

from outlines.serve.vllm import (
CFGLogitsProcessor,
JSONLogitsProcessor,
RegexLogitsProcessor,
)

# Decimal-number regex fixture: optional sign, integer part, optional
# fractional part, optional signed exponent.  The dot is escaped (\.) so it
# matches a literal decimal point rather than any character — the unescaped
# form would accept strings like "0a5".
TEST_REGEX = r"(-)?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-][0-9]+)?"

# Minimal Lark grammar describing decimal literals (CFGLogitsProcessor fixture).
TEST_CFG = """
start: DECIMAL
DIGIT: "0".."9"
INT: DIGIT+
DECIMAL: INT "." INT? | "." INT
"""

# JSON Schema for a short string (JSONLogitsProcessor fixture).
TEST_SCHEMA = '{"type": "string", "maxLength": 5}'

# (processor class, FSM specification string) pairs driving the parametrized
# test below: each processor is paired with a fixture in its own format
# (Lark grammar, regex, JSON schema).
LOGIT_PROCESSORS = (
    (CFGLogitsProcessor, TEST_CFG),
    (RegexLogitsProcessor, TEST_REGEX),
    (JSONLogitsProcessor, TEST_SCHEMA),
)

# Tiny HF model used only for its tokenizer — keeps the test download small.
TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda available")
@pytest.mark.parametrize("logit_processor, fsm_str", LOGIT_PROCESSORS)
def test_logit_processor(logit_processor, fsm_str: str):
    """Constructing a processor must adapt the engine's tokenizer so that
    ``decode`` returns a list, and doing so twice must be stable."""

    class FakeEngine:
        # Minimal stand-in for a vLLM engine: the processors only touch
        # the ``tokenizer`` attribute.
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

        def __call__(*_):
            # Dummy logits row; the second element mimics vLLM's extra
            # return value.
            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None

    engine = FakeEngine(AutoTokenizer.from_pretrained(TEST_MODEL))
    # Build the processor twice: the second pass checks that the tokenizer
    # adaptation is idempotent (decode keeps returning a list).
    for _ in range(2):
        logit_processor(fsm_str, engine)
        assert isinstance(engine.tokenizer.decode([0, 1, 2, 3]), list)

0 comments on commit 6b96f38

Please sign in to comment.