Support for guided decoding for offline LLM #6878

Merged 13 commits on Aug 4, 2024
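
This PR adds guided (constrained) decoding to the offline `LLM.generate()` entrypoint through a new `guided_options_request` argument. A minimal usage sketch, pieced together from the tests added in this PR (the model name and regex are just the fixture values, not requirements):

from vllm import LLM, SamplingParams

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Exactly one guided_* option (regex, json, choice, or grammar) may be set.
ipv4_regex = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
outputs = llm.generate(
    prompts=f"Give an example IPv4 address with this regex: {ipv4_regex}",
    sampling_params=sampling_params,
    guided_options_request=dict(guided_regex=ipv4_regex))
print(outputs[0].outputs[0].text)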
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -111,6 +111,7 @@ def setup(app):
"tqdm",
"tensorizer",
"pynvml",
"outlines",
]

for mock_target in autodoc_mock_imports:
@@ -1,6 +1,26 @@
import pytest


@pytest.fixture
def sample_prompts():
return [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]


@pytest.fixture
def sample_token_ids():
return [
[0],
[0, 1],
[0, 2, 1],
[0, 3, 1, 2],
]


@pytest.fixture
def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
@@ -66,4 +86,4 @@ def sample_sql_statements():
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
""")
""")
142 changes: 142 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -0,0 +1,142 @@
import json
import re
import weakref

import jsonschema
import pytest

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

from ...conftest import cleanup

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME, max_model_len=1024)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()


@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert re.fullmatch(sample_regex, generated_text) is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
)
outputs = llm.generate(
prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_json=sample_json_schema))

assert outputs is not None

for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_choice=sample_guided_choice))

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
)
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_grammar=sample_sql_statements))

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_statements)
parser.parse(generated_text)

# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")

assert generated_text.strip() == ground_truth

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
44 changes: 43 additions & 1 deletion vllm/entrypoints/llm.py
@@ -10,6 +10,9 @@
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -262,6 +265,8 @@ def generate(
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.

@@ -303,6 +308,14 @@
else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
raise ValueError(
"You can only use one guided decoding but multiple is "
f"specified: {guided_options_request}")
guided_options_request = GuidedDecodingRequest(
**guided_options_request)

if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
@@ -311,7 +324,8 @@
inputs=inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request)

outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
@@ -508,6 +522,7 @@ def _validate_and_add_requests(
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None,
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
@@ -523,6 +538,15 @@
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")

if isinstance(params, list):
params = [
self._add_guided_processor(param, guided_options)
if isinstance(param, SamplingParams) else param
for param in params
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)

# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
self._add_request(
@@ -548,6 +572,24 @@
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

def _add_guided_processor(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options:
if guided_options.guided_decoding_backend is None:
decoding_config = self.llm_engine.get_decoding_config()
guided_options.guided_decoding_backend = (
decoding_config.guided_decoding_backend)
guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
guided_options.guided_decoding_backend, guided_options,
self.get_tokenizer())
if guided_logits_processor:
if params.logits_processors is None:
params.logits_processors = []
params.logits_processors.append(guided_logits_processor)
return params

def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
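
As the diff above shows, a dict passed as `guided_options_request` is validated to contain a single option and converted into a `GuidedDecodingRequest`; a request object can also be constructed directly. A sketch, assuming the dataclass accepts the same field names as the dict keys used in the tests (`guided_regex`, `guided_json`, `guided_choice`, `guided_grammar`, plus `guided_decoding_backend`), and reusing `llm` and `sampling_params` from the first sketch:

from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest)

# Field names assumed to mirror the guided_* dict keys above; when
# guided_decoding_backend is left as None, _add_guided_processor falls back
# to the engine's DecodingConfig.
choice_request = GuidedDecodingRequest(
    guided_choice=["Python", "Rust", "C++"],
    guided_decoding_backend="outlines")

outputs = llm.generate(
    prompts="The best language for type-safe systems programming is ",
    sampling_params=sampling_params,
    guided_options_request=choice_request)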
26 changes: 20 additions & 6 deletions vllm/entrypoints/openai/protocol.py
@@ -1,6 +1,7 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union

import torch
@@ -14,6 +15,23 @@
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.utils import random_uuid

# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)

try:
from sphinx.ext.autodoc.mock import _MockModule

if isinstance(torch, _MockModule):
_LONG_INFO = _MOCK_LONG_INFO
else:
_LONG_INFO = torch.iinfo(torch.long)
except ModuleNotFoundError:
_LONG_INFO = torch.iinfo(torch.long)

assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max


class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
@@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
@@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens: Optional[int] = 16
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
26 changes: 24 additions & 2 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -3,9 +3,10 @@
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor

@@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)

@@ -28,6 +31,25 @@
"Must be one of 'outlines, 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
tokenizer) -> Optional[LogitsProcessor]:
# request = _adapt_request_for_tool_use(request)

if guided_decoding_backend == 'outlines':
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options, tokenizer)

raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
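
For completeness, the new module-level helper can also be called directly; this sketch mirrors what `LLM._add_guided_processor` does in the llm.py diff above, again reusing objects from the earlier sketches:

from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)

# Returns a LogitsProcessor (or None) for the chosen backend; attaching it to
# SamplingParams is what constrains generation to the requested format.
processor = get_local_guided_decoding_logits_processor(
    "outlines", choice_request, llm.get_tokenizer())
if processor is not None:
    if sampling_params.logits_processors is None:
        sampling_params.logits_processors = []
    sampling_params.logits_processors.append(processor)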