Add integration tests for the vLLM integration
rlouf committed Apr 10, 2024
1 parent aacc633 commit ae9ae50
Showing 3 changed files with 126 additions and 5 deletions.
2 changes: 2 additions & 0 deletions outlines/models/vllm.py
@@ -74,6 +74,8 @@ def generate(
         if max_tokens is not None:
             sampling_params.max_tokens = max_tokens
         if stop_at is not None:
+            if isinstance(stop_at, str):
+                stop_at = [stop_at]
             sampling_params.stop = stop_at
         if seed is not None:
             sampling_params.seed = seed
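The fix above normalizes a single stop string into a list before it is handed to vLLM's SamplingParams.stop. A minimal sketch of the two call styles this supports (the model name is a placeholder, not part of the commit):

# Minimal sketch, not part of the commit. Both spellings of `stop_at` now
# reach vLLM as a list of stop strings. The model name is an assumption.
import outlines.generate as generate
import outlines.models as models

model = models.vllm("mistralai/Mistral-7B-v0.1")  # placeholder model
generator = generate.text(model)

sequence = generator("Write a short sentence: ", stop_at="news")          # bare string
sequence = generator("Write a short sentence: ", stop_at=["news", "\n"])  # list form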
2 changes: 1 addition & 1 deletion tests/generate/test_integration_llamacpp.py
@@ -181,7 +181,7 @@ def test_llamacpp_date(model):
     prompt = (
         "<|im_start|>user\nWhat day is it today?<|im_end|>\n<|im_start|>assistant\n"
     )
-    sequence = generate.format(model, datetime.date)(prompt, max_tokens=10)
+    sequence = generate.format(model, datetime.date)(prompt, max_tokens=20, seed=10)
     assert isinstance(sequence, datetime.date)


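For context, generate.format constrains generation to text that parses as the requested Python type and returns the parsed value, which is why the test asserts isinstance(sequence, datetime.date); the added seed pins the sampler so the run is reproducible. A minimal usage sketch (the model checkpoint is a placeholder, not part of the commit):

# Minimal sketch, not part of the commit: format-constrained generation
# returns a parsed Python value, not raw text. The checkpoint is an assumption.
import datetime

import outlines.generate as generate
import outlines.models as models

model = models.llamacpp("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf")  # placeholder
generator = generate.format(model, datetime.date)
date = generator("What day is it today? ", max_tokens=20, seed=10)
assert isinstance(date, datetime.date)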
127 changes: 123 additions & 4 deletions tests/generate/test_integration_vllm.py
@@ -1,8 +1,13 @@
+import datetime
+import re
 
 import pytest
 import torch
+from pydantic import BaseModel, constr
+from vllm.sampling_params import SamplingParams
 
 import outlines.generate as generate
+import outlines.grammars as grammars
 import outlines.models as models
 import outlines.samplers as samplers
@@ -114,7 +119,121 @@ def test_vllm_beam_search(model):
     assert res[0] != res[1]
 
 
-@pytest.mark.xfail(reason="CFG logits processor not available for vLLM")
-def test_cfg_simple(model):
-    generator = generate.cfg(model)
-    _ = generator("test")
+def test_vllm_text_stop(model):
+    prompt = "Write a short sentence containing 'You': "
+    sequence = generate.text(model)(prompt, max_tokens=100, seed=10)
+    assert sequence.find("news") != -1
+
+    sequence = generate.text(model)(prompt, stop_at="news", max_tokens=100, seed=10)
+    assert isinstance(sequence, str)
+    assert sequence.find("news") == -1
+
+
+def test_vllm_regex(model):
+    prompt = "Write an email address: "
+    regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})"
+    generator = generate.regex(model, regex_str)
+
+    # One prompt
+    sequence = generator(prompts=prompt)
+    assert isinstance(sequence, str)
+    assert re.fullmatch(pattern=regex_str, string=sequence) is not None
+
+
+def test_vllm_integer(model):
+    prompt = "Give me an integer: "
+    sequence = generate.format(model, int)(prompt, max_tokens=10)
+    assert isinstance(sequence, int)
+    assert sequence != ""
+    int(sequence)
+
+
+def test_vllm_float(model):
+    prompt = "Give me a floating-point number: "
+    sequence = generate.format(model, float)(prompt, max_tokens=10)
+    assert isinstance(sequence, float)
+
+    assert sequence != ""
+    float(sequence)
+
+
+def test_vllm_bool(model):
+    prompt = "Is this True or False? "
+    sequence = generate.format(model, bool)(prompt, max_tokens=10)
+    assert isinstance(sequence, bool)
+
+    assert sequence != ""
+    bool(sequence)
+
+
+def test_vllm_date(model):
+    prompt = "What day is it today? "
+    sequence = generate.format(model, datetime.date)(prompt, max_tokens=10)
+    assert isinstance(sequence, datetime.date)
+
+
+def test_vllm_time(model):
+    prompt = "What time is it? "
+    sequence = generate.format(model, datetime.time)(prompt, max_tokens=10)
+    assert isinstance(sequence, datetime.time)
+
+
+def test_vllm_datetime(model):
+    prompt = "What time is it? "
+    sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20)
+    assert isinstance(sequence, datetime.datetime)
+
+
+def test_vllm_choice(model):
+    prompt = "Which one between 'test' and 'choice'? "
+    sequence = generate.choice(model, ["test", "choice"])(prompt)
+    assert sequence == "test" or sequence == "choice"
+
+
+def test_vllm_json_basic(model):
+    prompt = "Output some JSON. "
+
+    class Spam(BaseModel):
+        spam: constr(max_length=10)
+        fuzz: bool
+
+    sampling_params = SamplingParams(temperature=0)
+    result = generate.json(model, Spam, whitespace_pattern="")(
+        prompt, max_tokens=100, seed=1, sampling_params=sampling_params
+    )
+    assert isinstance(result, BaseModel)
+    assert isinstance(result.spam, str)
+    assert isinstance(result.fuzz, bool)
+    assert len(result.spam) <= 10
+
+
+def test_vllm_json_schema(model):
+    prompt = "Output some JSON. "
+
+    schema = """{
+        "title": "spam",
+        "type": "object",
+        "properties": {
+            "foo" : {"type": "boolean"},
+            "bar": {"type": "string", "maxLength": 4}
+        },
+        "required": ["foo", "bar"]
+    }
+    """
+
+    sampling_params = SamplingParams(temperature=0)
+    result = generate.json(model, schema, whitespace_pattern="")(
+        prompt, max_tokens=100, seed=10, sampling_params=sampling_params
+    )
+    assert isinstance(result, dict)
+    assert isinstance(result["foo"], bool)
+    assert isinstance(result["bar"], str)
+
+
+@pytest.mark.xfail(
+    reason="The CFG logits processor for vLLM has not been implemented yet."
+)
+def test_vllm_cfg(model):
+    prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n"
+    result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11)
+    assert isinstance(result, str)
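The two json tests above show the contract: a Pydantic class in gives a validated BaseModel instance back, while a raw JSON Schema string gives a plain dict. An extra local check of that dict against the schema could look like this (not part of the commit; assumes the third-party jsonschema package is installed):

# Not part of the commit: validate the dict returned by generate.json against
# the schema string from test_vllm_json_schema. Assumes `jsonschema` is installed.
import json

from jsonschema import validate

schema_doc = json.loads(schema)  # `schema` is the JSON Schema string above
validate(instance=result, schema=schema_doc)  # raises ValidationError on mismatch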
