From 654bc5ca49bde0969bc95e4b1dbe7fabbb8f631c Mon Sep 17 00:00:00 2001 From: Yihuan Bu <88394319+kevinbu233@users.noreply.github.com> Date: Sat, 3 Aug 2024 23:12:09 -0400 Subject: [PATCH] Support for guided decoding for offline LLM (#6878) Co-authored-by: Cyrus Leung --- docs/source/conf.py | 1 + tests/entrypoints/{openai => }/conftest.py | 22 ++- tests/entrypoints/llm/test_guided_generate.py | 142 ++++++++++++++++++ vllm/entrypoints/llm.py | 44 +++++- vllm/entrypoints/openai/protocol.py | 26 +++- .../guided_decoding/__init__.py | 26 +++- .../guided_decoding/guided_fields.py | 38 +++++ .../lm_format_enforcer_decoding.py | 39 +++++ .../guided_decoding/outlines_decoding.py | 26 +++- 9 files changed, 352 insertions(+), 12 deletions(-) rename tests/entrypoints/{openai => }/conftest.py (83%) create mode 100644 tests/entrypoints/llm/test_guided_generate.py create mode 100644 vllm/model_executor/guided_decoding/guided_fields.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 1093b30bca11d..f1eb8524d4e9c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -111,6 +111,7 @@ def setup(app): "tqdm", "tensorizer", "pynvml", + "outlines", ] for mock_target in autodoc_mock_imports: diff --git a/tests/entrypoints/openai/conftest.py b/tests/entrypoints/conftest.py similarity index 83% rename from tests/entrypoints/openai/conftest.py rename to tests/entrypoints/conftest.py index 0837644f26bde..e7ef5637c8ccb 100644 --- a/tests/entrypoints/openai/conftest.py +++ b/tests/entrypoints/conftest.py @@ -1,6 +1,26 @@ import pytest +@pytest.fixture +def sample_prompts(): + return [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + +@pytest.fixture +def sample_token_ids(): + return [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], + ] + + @pytest.fixture def sample_regex(): return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" @@ -66,4 +86,4 @@ def sample_sql_statements(): table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") \ No newline at end of file +""") diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py new file mode 100644 index 0000000000000..873e115421257 --- /dev/null +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -0,0 +1,142 @@ +import json +import re +import weakref + +import jsonschema +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams + +from ...conftest import cleanup + +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, max_model_len=1024) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + del llm + cleanup() + + +@pytest.mark.skip_global_cleanup +def test_guided_regex(sample_regex, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + ) + outputs = llm.generate( + prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert re.fullmatch(sample_regex, generated_text) is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_json_completion(sample_json_schema, llm): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + ) + outputs = llm.generate( + prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_json=sample_json_schema)) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.skip_global_cleanup +def test_guided_choice_completion(sample_guided_choice, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + ) + outputs = llm.generate( + prompts="The best language for type-safe systems programming is ", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_choice=sample_guided_choice)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert generated_text in sample_guided_choice + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_grammar(sample_sql_statements, llm): + + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=1000, + ) + outputs = llm.generate( + prompts=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_grammar=sample_sql_statements)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_statements) + parser.parse(generated_text) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1d..262cba79e5712 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -10,6 +10,9 @@ parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + GuidedDecodingRequest, get_local_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -262,6 +265,8 @@ def generate( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -303,6 +308,14 @@ def generate( else: inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + if isinstance(guided_options_request, dict): + if len(guided_options_request) > 1: + raise ValueError( + "You can only use one guided decoding but multiple is " + f"specified: {guided_options_request}") + guided_options_request = GuidedDecodingRequest( + **guided_options_request) + if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() @@ -311,7 +324,8 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + guided_options=guided_options_request) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -508,6 +522,7 @@ def _validate_and_add_requests( Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], + guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -523,6 +538,15 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") + if isinstance(params, list): + params = [ + self._add_guided_processor(param, guided_options) + if isinstance(param, SamplingParams) else param + for param in params + ] + elif isinstance(params, SamplingParams): + params = self._add_guided_processor(params, guided_options) + # Add requests to the engine. for i, request_inputs in enumerate(inputs): self._add_request( @@ -548,6 +572,24 @@ def _add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) + def _add_guided_processor( + self, + params: SamplingParams, + guided_options: Optional[GuidedDecodingRequest] = None): + if guided_options: + if guided_options.guided_decoding_backend is None: + decoding_config = self.llm_engine.get_decoding_config() + guided_options.guided_decoding_backend = ( + decoding_config.guided_decoding_backend) + guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa + guided_options.guided_decoding_backend, guided_options, + self.get_tokenizer()) + if guided_logits_processor: + if params.logits_processors is None: + params.logits_processors = [] + params.logits_processors.append(guided_logits_processor) + return params + def _run_engine( self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3b35ae1ebd705..76318a1271229 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,6 +1,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time +from argparse import Namespace from typing import Any, Dict, List, Literal, Optional, Union import torch @@ -14,6 +15,23 @@ from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid +# torch is mocked during docs generation, +# so we have to provide the values as literals +_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if isinstance(torch, _MockModule): + _LONG_INFO = _MOCK_LONG_INFO + else: + _LONG_INFO = torch.iinfo(torch.long) +except ModuleNotFoundError: + _LONG_INFO = torch.iinfo(torch.long) + +assert _LONG_INFO.min == _MOCK_LONG_INFO.min +assert _LONG_INFO.max == _MOCK_LONG_INFO.max + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields @@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel): n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None - seed: Optional[int] = Field(None, - ge=torch.iinfo(torch.long).min, - le=torch.iinfo(torch.long).max) + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None @@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel): max_tokens: Optional[int] = 16 n: int = 1 presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = Field(None, - ge=torch.iinfo(torch.long).min, - le=torch.iinfo(torch.long).max) + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 50aa3ec379f4a..4a2476dd6314d 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -3,9 +3,10 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, CompletionRequest) -from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( - get_lm_format_enforcer_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( request, tokenizer) @@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor( "Must be one of 'outlines, 'lm-format-enforcer'") +def get_local_guided_decoding_logits_processor( + guided_decoding_backend: str, guided_options: GuidedDecodingRequest, + tokenizer) -> Optional[LogitsProcessor]: + # request = _adapt_request_for_tool_use(request) + + if guided_decoding_backend == 'outlines': + return get_local_outlines_guided_decoding_logits_processor( + guided_options, tokenizer) + if guided_decoding_backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_options, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_decoding_backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer'") + + def _adapt_request_for_tool_use(request: Union[CompletionRequest, ChatCompletionRequest]): # the legacy completion API does not support tool use diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py new file mode 100644 index 0000000000000..3082ac1510ccc --- /dev/null +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, TypedDict, Union + +from pydantic import BaseModel + + +class LLMGuidedOptions(TypedDict, total=False): + guided_json: Union[Dict, BaseModel, str] + guided_regex: str + guided_choice: List[str] + guided_grammar: str + guided_decoding_backend: str + guided_whitespace_pattern: str + guided_json_object: bool + + +@dataclass +class GuidedDecodingRequest: + """One of the fields will be used to retrieve the logit processor.""" + guided_json: Optional[Union[Dict, BaseModel, str]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None + guided_grammar: Optional[str] = None + guided_decoding_backend: Optional[str] = None + guided_whitespace_pattern: Optional[str] = None + guided_json_object: Optional[bool] = None + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.guided_json is not None, self.guided_regex is not None, + self.guided_choice is not None, self.guided_grammar is not None, + self.guided_json_object is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index d0a5ca5592f9d..b2188c9cbc2bb 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -12,7 +12,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( return logits_processor +def get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_options: GuidedDecodingRequest, + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if guided_options.guided_json: + schema = _normalize_json_schema_object(guided_options.guided_json) + character_level_parser = JsonSchemaParser(schema) + elif guided_options.guided_choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in guided_options.guided_choice]) + elif guided_options.guided_regex: + character_level_parser = RegexParser(guided_options.guided_regex) + elif guided_options.guided_grammar: + # CFG grammar not supported by LMFE, revert to outlines + return get_local_outlines_guided_decoding_logits_processor( + guided_options, tokenizer) + elif guided_options.guided_json_object: + # None means any json object + character_level_parser = JsonSchemaParser(None) + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: if isinstance(schema, str): return json_loads(schema) diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 721f7e0530cb7..bc62224dabecf 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -10,6 +10,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) @@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor( mode, request.guided_whitespace_pattern) +def get_local_outlines_guided_decoding_logits_processor( + guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, + None]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + guide, mode = _get_guide_and_mode(guided_options) + if not guide or not mode: + return None + + return _get_logits_processor(guide, tokenizer, mode, + guided_options.guided_whitespace_pattern) + + def _get_guide_and_mode( - request: Union[CompletionRequest, ChatCompletionRequest] + request: Union[CompletionRequest, ChatCompletionRequest, + GuidedDecodingRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: if request.guided_json: @@ -102,7 +123,8 @@ def _get_guide_and_mode( return choices_regex, GuidedDecodingMode.CHOICE elif request.guided_grammar: return request.guided_grammar, GuidedDecodingMode.GRAMMAR - elif (request.response_format is not None + elif (not isinstance(request, GuidedDecodingRequest) + and request.response_format is not None and request.response_format.type == "json_object"): return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR else: