Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Support for guided decoding in offline interface #4130

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a5aeec2
first commit, added extra parameters for offline LLM, issue with work…
kevinbu233 Apr 16, 2024
dcbfc69
clean up code
kevinbu233 Apr 16, 2024
bd4d84c
add skip cleanup
simon-mo Apr 17, 2024
76c0924
Changed model and sampling parameters, cleaned up naming
kevinbu233 Apr 17, 2024
4709a98
cleaned up code and created helper functions
kevinbu233 Apr 18, 2024
aab046d
first merge resolved
kevinbu233 Apr 18, 2024
84b1442
fix format
kevinbu233 Apr 19, 2024
1d6abe4
fix merge conflict
kevinbu233 Apr 19, 2024
daa2c8f
added docstrings for sampling params guided options
kevinbu233 Apr 23, 2024
35c73ee
fix merge conflict with main
kevinbu233 Apr 23, 2024
9e454c5
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Apr 23, 2024
ef4cf6f
fixed support for multiple sampling params for LLM
kevinbu233 Apr 23, 2024
945125d
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Apr 23, 2024
1c0769d
added noqa for extra long line
kevinbu233 Apr 24, 2024
062f0bc
Update tests/entrypoints/test_local_LLM.py
simon-mo May 1, 2024
05a2512
Merge branch 'main' of github.com:vllm-project/vllm into yihuan_issue…
simon-mo May 2, 2024
438ab37
fix typing
simon-mo May 2, 2024
06c2205
fix test and more refactoring
simon-mo May 3, 2024
4158d78
use x2
simon-mo May 3, 2024
0d9e5a5
lint
simon-mo May 3, 2024
ff9ba7f
fix isort
simon-mo May 3, 2024
bbb59bf
merge with main
kevinbu233 May 27, 2024
d779f86
fixing merge issues
kevinbu233 May 30, 2024
f923677
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 May 30, 2024
a15b511
fixed merge
kevinbu233 May 30, 2024
292264a
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 May 30, 2024
77d42a8
format
kevinbu233 May 30, 2024
60ab6f6
finished merge first draft
kevinbu233 Jun 14, 2024
ae23772
merged main
kevinbu233 Jun 14, 2024
3fb6258
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 19, 2024
75cf9a7
fixed merge conflict and fixed suggestions
kevinbu233 Jun 19, 2024
c429ef8
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 21, 2024
da89c1b
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 22, 2024
6c8b82a
fix test_openai error
kevinbu233 Jun 23, 2024
4e759d9
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 24, 2024
4ac8abb
fixing response_format test
kevinbu233 Jun 24, 2024
2eedeba
temporary push
kevinbu233 Jun 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
LlavaConfig, LlavaForConditionalGeneration)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
from vllm.entrypoints.llm import LLM
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData

logger = init_logger(__name__)
Expand Down
72 changes: 72 additions & 0 deletions tests/entrypoints/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest


@pytest.fixture
def sample_regex():
    """Return the regex pattern used by the guided-regex tests.

    The tests that consume this fixture prompt the model for an IPv4
    address, so the pattern is three "octet followed by a dot" groups and
    one final octet.
    """
    # One octet (0-255) followed by a literal dot, repeated three times.
    dotted_octets = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
    # The trailing octet carries no dot.
    final_octet = r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
    return dotted_octets + final_octet


@pytest.fixture
def sample_json_schema():
    """Return a JSON schema describing an employee profile.

    The schema is rebuilt on every call, so each test gets its own dict.
    Key order is kept stable because the tests interpolate the dict into
    prompt strings.
    """
    # Each skill is a short string; at least three must be provided.
    skills_schema = {
        "type": "array",
        "items": {
            "type": "string",
            "maxLength": 10
        },
        "minItems": 3
    }

    # One entry of the work history; "duration" is optional.
    job_schema = {
        "type": "object",
        "properties": {
            "company": {
                "type": "string"
            },
            "duration": {
                "type": "number"
            },
            "position": {
                "type": "string"
            }
        },
        "required": ["company", "position"]
    }

    profile_properties = {
        "name": {
            "type": "string"
        },
        "age": {
            "type": "integer"
        },
        "skills": skills_schema,
        "work history": {
            "type": "array",
            "items": job_schema
        }
    }

    return {
        "type": "object",
        "properties": profile_properties,
        "required": ["name", "age", "skills", "work history"]
    }


@pytest.fixture
def sample_guided_choice():
    """Return the list of allowed answers for the guided-choice tests."""
    languages = (
        "Python",
        "Java",
        "JavaScript",
        "C++",
        "C#",
        "PHP",
        "TypeScript",
        "Ruby",
        "Swift",
        "Kotlin",
    )
    # Hand back a fresh list on every call, as the literal did.
    return list(languages)


@pytest.fixture
def sample_sql_statements():
    """Return the grammar text used by the guided-grammar tests.

    The grammar describes a single fixed-shape SELECT statement over two
    columns, two tables and the numbers 1 and 2.
    """
    # The grammar lines intentionally start at column 0: they are part of
    # the string value, not of the Python indentation.
    grammar = """
start: select_statement

select_statement: "SELECT" column "from" table "where" condition

column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number

number: "1" | "2"
"""
    return grammar
83 changes: 19 additions & 64 deletions tests/entrypoints/test_guided_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,68 +4,23 @@
import torch
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
GuidedDecodingFields, get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)

TEST_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "string"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

pytestmark = pytest.mark.openai


def test_guided_logits_processors():
def test_guided_logits_processors(sample_regex, sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
json_LP = JSONLogitsProcessor(TEST_SCHEMA,
regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
json_LP = JSONLogitsProcessor(sample_json_schema,
tokenizer,
whitespace_pattern=None)

regex_LP.init_state()
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
regex_LP(token_ids, tensor)
Expand All @@ -74,7 +29,8 @@ def test_guided_logits_processors():

json_LP.init_state()
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
json_LP(token_ids, tensor)
Expand All @@ -84,15 +40,15 @@ def test_guided_logits_processors():

@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
async def test_guided_logits_processor_black_box(backend: str):
async def test_guided_logits_processor_black_box(sample_regex,
sample_json_schema,
backend: str):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
regex_request = CompletionRequest(model='test',
prompt=token_ids,
guided_regex=TEST_REGEX)
regex_lp = await get_guided_decoding_logits_processor(
backend, regex_request, tokenizer)
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_lp = get_guided_decoding_logits_processor(
GuidedDecodingFields(guided_regex=sample_regex,
guided_decoding_backend=backend), tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
Expand All @@ -101,12 +57,11 @@ async def test_guided_logits_processor_black_box(backend: str):
assert not torch.allclose(tensor, original_tensor)

token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
json_request = CompletionRequest(model='test',
prompt=token_ids,
guided_json=TEST_SCHEMA)
json_lp = await get_guided_decoding_logits_processor(
backend, json_request, tokenizer)
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_lp = get_guided_decoding_logits_processor(
GuidedDecodingFields(guided_json=sample_json_schema,
guided_decoding_backend=backend), tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
Expand Down
141 changes: 131 additions & 10 deletions tests/entrypoints/test_llm_generate.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import json
import re
import weakref
from typing import List

import jsonschema
import pytest

from vllm import LLM, RequestOutput, SamplingParams
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

from ..conftest import cleanup

MODEL_NAME = "facebook/opt-125m"
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

TOKEN_IDS = [
[0],
[0, 1],
Expand All @@ -30,11 +34,7 @@
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
llm = LLM(model=MODEL_NAME, max_model_len=1024)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
Expand Down Expand Up @@ -119,6 +119,13 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):

@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

sampling_params = [
SamplingParams(temperature=0.01, top_p=0.95),
SamplingParams(temperature=0.3, top_p=0.95),
Expand All @@ -140,5 +147,119 @@ def test_multiple_sampling_params(llm: LLM):
assert len(PROMPTS) == len(outputs)

# sampling_params is None, default params should be applied
outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs)
outputs = llm.generate(prompts, sampling_params=None)
assert len(prompts) == len(outputs)


@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    """Check that regex-guided generation yields text fully matching the regex.

    The same prompt is submitted twice with a single ``SamplingParams``
    carrying ``guided_regex``; every returned completion must match
    ``sample_regex`` in full.
    """
    request = f"Give an example IPv4 address with this regex: {sample_regex}"
    params = SamplingParams(temperature=0.8,
                            top_p=0.95,
                            guided_options={"guided_regex": sample_regex})

    results = llm.generate(
        prompts=[request, request],
        sampling_params=params,
        use_tqdm=True,
    )
    assert results is not None

    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt
        generated_text = result.outputs[0].text
        assert generated_text is not None
        # fullmatch: the whole completion must match, not just a prefix.
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    """Check that JSON-guided generation yields schema-valid JSON.

    Each completion is parsed with ``json.loads`` and validated against
    ``sample_json_schema``; a parse or validation error fails the test.
    """
    request = (f"Give an example JSON for an employee profile "
               f"that fits this schema: {sample_json_schema}")
    params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_options={"guided_json": sample_json_schema})

    results = llm.generate(
        prompts=[request, request],
        sampling_params=params,
        use_tqdm=True,
    )
    assert results is not None

    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt

        generated_text = result.outputs[0].text
        assert generated_text is not None
        # Print before parsing so the raw text is visible if parsing fails.
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        parsed = json.loads(generated_text)
        jsonschema.validate(instance=parsed, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    """Check that choice-guided generation returns one of the allowed options.

    A single prompt string is submitted; every completion must be exactly
    one of the entries in ``sample_guided_choice``.
    """
    params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_options={"guided_choice": sample_guided_choice})

    results = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=params,
        use_tqdm=True,
    )
    assert results is not None

    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt
        generated_text = result.outputs[0].text
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    """Check that grammar-guided generation produces the expected statement.

    Every completion must (a) parse under the grammar given by
    ``sample_sql_statements`` and (b) equal the expected SELECT statement
    once surrounding whitespace is stripped.
    """
    # Local import: lark is only needed by this test.
    from lark import Lark

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_options=dict(guided_grammar=sample_sql_statements))
    outputs = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None

    # Loop-invariant setup: build the parser and the expected text once
    # instead of once per output.
    parser = Lark(sample_sql_statements)
    # The expected text has its spaces stripped; this presumes the grammar
    # fixture emits tokens with no whitespace between them.
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
        " ", "")

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # Parsing raises if the output is not valid under the grammar.
        parser.parse(generated_text)

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Loading
Loading