From decffe4eb14988bb54abadf0d06dbba6102a54d3 Mon Sep 17 00:00:00 2001
From: kevinbu233
Date: Sun, 28 Jul 2024 22:41:20 +0000
Subject: [PATCH 01/12] local guided decoding

---
 tests/entrypoints/llm/conftest.py             |  69 ++++++++
 tests/entrypoints/llm/test_guided_generate.py | 154 ++++++++++++++++++
 vllm/entrypoints/llm.py                       |  47 +++++-
 .../guided_decoding/__init__.py               |  26 ++-
 .../guided_decoding/guided_fields.py          |  29 ++++
 .../lm_format_enforcer_decoding.py            |  39 +++++
 .../guided_decoding/outlines_decoding.py      |  26 ++-
 7 files changed, 384 insertions(+), 6 deletions(-)
 create mode 100644 tests/entrypoints/llm/conftest.py
 create mode 100644 tests/entrypoints/llm/test_guided_generate.py
 create mode 100644 vllm/model_executor/guided_decoding/guided_fields.py

diff --git a/tests/entrypoints/llm/conftest.py b/tests/entrypoints/llm/conftest.py
new file mode 100644
index 0000000000000..0837644f26bde
--- /dev/null
+++ b/tests/entrypoints/llm/conftest.py
@@ -0,0 +1,69 @@
+import pytest
+
+
+@pytest.fixture
+def sample_regex():
+    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
+            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
+
+
+@pytest.fixture
+def sample_json_schema():
+    return {
+        "type": "object",
+        "properties": {
+            "name": {
+                "type": "string"
+            },
+            "age": {
+                "type": "integer"
+            },
+            "skills": {
+                "type": "array",
+                "items": {
+                    "type": "string",
+                    "maxLength": 10
+                },
+                "minItems": 3
+            },
+            "work_history": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "company": {
+                            "type": "string"
+                        },
+                        "duration": {
+                            "type": "number"
+                        },
+                        "position": {
+                            "type": "string"
+                        }
+                    },
+                    "required": ["company", "position"]
+                }
+            }
+        },
+        "required": ["name", "age", "skills", "work_history"]
+    }
+
+
+@pytest.fixture
+def sample_guided_choice():
+    return [
+        "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
+        "Ruby", "Swift", "Kotlin"
+    ]
+
+
+@pytest.fixture
+def sample_sql_statements():
+    return ("""
+start: select_statement
+select_statement: "SELECT" column "from" table "where" condition
+column: "col_1" | "col_2"
+table: "table_1" | "table_2"
+condition: column "=" number
+number: "1" | "2"
+""")
\ No newline at end of file
diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py
new file mode 100644
index 0000000000000..5fe69d9c88aa8
--- /dev/null
+++ b/tests/entrypoints/llm/test_guided_generate.py
@@ -0,0 +1,154 @@
+import json
+import re
+import weakref
+from typing import List
+
+import jsonschema
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+
+from ...conftest import cleanup
+
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+PROMPTS = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+TOKEN_IDS = [
+    [0],
+    [0, 1],
+    [0, 2, 1],
+    [0, 3, 1, 2],
+]
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+        del llm
+    cleanup()
+
+@pytest.mark.skip_global_cleanup
+def test_guided_regex(sample_regex, llm):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+    )
+    outputs = llm.generate(prompts=[
+        f"Give an example IPv4 address with this regex: {sample_regex}"
+    ] * 2,
+                           sampling_params=sampling_params,
+                           use_tqdm=True,
+                           
guided_options=dict(guided_regex=sample_regex)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert re.fullmatch(sample_regex, generated_text) is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_json_completion(sample_json_schema, llm): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + ) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + guided_options=dict(guided_json=sample_json_schema)) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.skip_global_cleanup +def test_guided_choice_completion(sample_guided_choice, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + ) + outputs = llm.generate( + prompts="The best language for type-safe systems programming is ", + sampling_params=sampling_params, + use_tqdm=True, + guided_options=dict(guided_choice=sample_guided_choice)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert generated_text in sample_guided_choice + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_grammar(sample_sql_statements, llm): + + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + ) + outputs = llm.generate( + prompts=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + sampling_params=sampling_params, + use_tqdm=True, + guided_options=dict(guided_grammar=sample_sql_statements)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_statements) + parser.parse(generated_text) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1d..8a90b2596555f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,6 @@ from contextlib import contextmanager -from typing import ClassVar, List, Optional, Sequence, Union, cast, overload +from typing import (ClassVar, Dict, List, Optional, Sequence, Union, cast, + overload) from tqdm import tqdm from transformers import 
PreTrainedTokenizer, PreTrainedTokenizerFast @@ -10,6 +11,8 @@ parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + GuidedDecodingFields, get_local_guided_decoding_logits_processor) from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -262,6 +265,7 @@ def generate( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options: Optional[Union[Dict, GuidedDecodingFields]] = None ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -303,6 +307,13 @@ def generate( else: inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + if isinstance(guided_options, Dict): + if len(guided_options) > 1: + raise ValueError( + "You can only use one guided decoding but multiple is " + f"specified: {self.__dict__}") + guided_options = GuidedDecodingFields(**guided_options) + if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() @@ -311,7 +322,8 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + guided_options=guided_options) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -508,6 +520,7 @@ def _validate_and_add_requests( Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], + guided_options: Optional[GuidedDecodingFields] = None, ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -523,6 +536,18 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") + if isinstance(params, list): + if len(params) != num_requests: + raise ValueError("The lengths of prompts and params " + "must be the same.") + + params = [ + self._add_guided_processor(param, guided_options) + for param in params if isinstance(param, SamplingParams) + ] + elif isinstance(params, SamplingParams): + params = self._add_guided_processor(params, guided_options) + # Add requests to the engine. 
for i, request_inputs in enumerate(inputs): self._add_request( @@ -548,6 +573,24 @@ def _add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) + def _add_guided_processor( + self, + params: SamplingParams, + guided_options: Optional[GuidedDecodingFields] = None): + if guided_options: + if guided_options.guided_decoding_backend is None: + decoding_config = self.llm_engine.get_decoding_config() + guided_options.guided_decoding_backend = ( + decoding_config.guided_decoding_backend) + guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa + guided_options.guided_decoding_backend, guided_options, + self.get_tokenizer()) + if guided_logits_processor: + if params.logits_processors is None: + params.logits_processors = [] + params.logits_processors.append(guided_logits_processor) + return params + def _run_engine( self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 50aa3ec379f4a..4c846bcfddc7c 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -3,9 +3,10 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, CompletionRequest) -from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( - get_lm_format_enforcer_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingFields) from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( request, tokenizer) @@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor( "Must be one of 'outlines, 'lm-format-enforcer'") +def get_local_guided_decoding_logits_processor( + guided_decoding_backend: str, guided_options: GuidedDecodingFields, + tokenizer) -> Optional[LogitsProcessor]: + # request = _adapt_request_for_tool_use(request) + + if guided_decoding_backend == 'outlines': + return get_local_outlines_guided_decoding_logits_processor( + guided_options, tokenizer) + if guided_decoding_backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_options, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_decoding_backend}'. 
" + "Must be one of 'outlines, 'lm-format-enforcer'") + + def _adapt_request_for_tool_use(request: Union[CompletionRequest, ChatCompletionRequest]): # the legacy completion API does not support tool use diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py new file mode 100644 index 0000000000000..63d699ca21add --- /dev/null +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +from pydantic import BaseModel + + +@dataclass +class GuidedDecodingFields: + """One of the fields will be used to retrieve the logit processor.""" + guided_json: Optional[Union[Dict, BaseModel, str]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None + guided_grammar: Optional[str] = None + guided_decoding_backend: Optional[str] = None + guided_whitespace_pattern: Optional[str] = None + guided_json_object: Optional[bool] = None + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.guided_json is not None, + self.guided_regex is not None, + self.guided_choice is not None, + self.guided_grammar is not None, + self.guided_json_object is not None, ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple is " + f"specified: {self.__dict__}") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index d0a5ca5592f9d..40906354f62c2 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -12,7 +12,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingFields) from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( return logits_processor +def get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_options: GuidedDecodingFields, + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. 
diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
index d0a5ca5592f9d..40906354f62c2 100644
--- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
+++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -12,7 +12,10 @@
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest)
+from vllm.model_executor.guided_decoding.guided_fields import (
+    GuidedDecodingFields)
 from vllm.model_executor.guided_decoding.outlines_decoding import (
+    get_local_outlines_guided_decoding_logits_processor,
     get_outlines_guided_decoding_logits_processor)
 from vllm.sampling_params import LogitsProcessor
 
@@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
     return logits_processor
 
 
+def get_local_lm_format_enforcer_guided_decoding_logits_processor(
+        guided_options: GuidedDecodingFields,
+        tokenizer) -> Optional[LogitsProcessor]:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    The lm-format-enforcer tokenizer data is cached per tokenizer,
+    so it is only built once.
+    """
+
+    tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
+        tokenizer)
+    character_level_parser: CharacterLevelParser
+    if guided_options.guided_json:
+        schema = _normalize_json_schema_object(guided_options.guided_json)
+        character_level_parser = JsonSchemaParser(schema)
+    elif guided_options.guided_choice:
+        character_level_parser = UnionParser(
+            [StringParser(choice) for choice in guided_options.guided_choice])
+    elif guided_options.guided_regex:
+        character_level_parser = RegexParser(guided_options.guided_regex)
+    elif guided_options.guided_grammar:
+        # CFG grammar not supported by LMFE, revert to outlines
+        return get_local_outlines_guided_decoding_logits_processor(
+            guided_options, tokenizer)
+    elif guided_options.guided_json_object:
+        # None means any json object
+        character_level_parser = JsonSchemaParser(None)
+    else:
+        return None
+
+    logits_processor = build_vllm_logits_processor(tokenizer_data,
+                                                   character_level_parser)
+    return logits_processor
+
+
 def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
     if isinstance(schema, str):
         return json_loads(schema)
diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
index 721f7e0530cb7..487f205dae67f 100644
--- a/vllm/model_executor/guided_decoding/outlines_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -10,6 +10,8 @@
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest)
+from vllm.model_executor.guided_decoding.guided_fields import (
+    GuidedDecodingFields)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
 
@@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor(
         mode, request.guided_whitespace_pattern)
 
 
+def get_local_outlines_guided_decoding_logits_processor(
+    guided_options: GuidedDecodingFields, tokenizer: PreTrainedTokenizerBase
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
+           None]:
+    """
+    Given a GuidedDecodingFields instance, check for guided decoding
+    options and get the necessary logits processor for the given guide.
+    We cache logit processors by (guide, tokenizer), and on cache hit
+    we make a shallow copy to reuse the same underlying FSM.
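+
+    A rough usage sketch; the regex value here is illustrative only:
+
+        processor = get_local_outlines_guided_decoding_logits_processor(
+            GuidedDecodingFields(guided_regex=r"\d+"), tokenizer)
+        if processor is not None:
+            sampling_params.logits_processors = [processor]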
+ """ + guide, mode = _get_guide_and_mode(guided_options) + if not guide or not mode: + return None + + return _get_logits_processor(guide, tokenizer, mode, + guided_options.guided_whitespace_pattern) + + def _get_guide_and_mode( - request: Union[CompletionRequest, ChatCompletionRequest] + request: Union[CompletionRequest, ChatCompletionRequest, + GuidedDecodingFields] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: if request.guided_json: @@ -102,7 +123,8 @@ def _get_guide_and_mode( return choices_regex, GuidedDecodingMode.CHOICE elif request.guided_grammar: return request.guided_grammar, GuidedDecodingMode.GRAMMAR - elif (request.response_format is not None + elif (not isinstance(request, GuidedDecodingFields) + and request.response_format is not None and request.response_format.type == "json_object"): return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR else: From c3b01bb288f4fc7e402437b7425230791065011d Mon Sep 17 00:00:00 2001 From: kevinbu233 Date: Tue, 30 Jul 2024 06:46:32 +0000 Subject: [PATCH 02/12] fixes --- tests/entrypoints/llm/test_guided_generate.py | 1 - vllm/entrypoints/llm.py | 8 ++------ vllm/model_executor/guided_decoding/guided_fields.py | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 5fe69d9c88aa8..79761d577fd2a 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -1,7 +1,6 @@ import json import re import weakref -from typing import List import jsonschema import pytest diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8a90b2596555f..c42dc039927b6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -307,11 +307,11 @@ def generate( else: inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) - if isinstance(guided_options, Dict): + if isinstance(guided_options, dict): if len(guided_options) > 1: raise ValueError( "You can only use one guided decoding but multiple is " - f"specified: {self.__dict__}") + f"specified: {guided_options}") guided_options = GuidedDecodingFields(**guided_options) if sampling_params is None: @@ -537,10 +537,6 @@ def _validate_and_add_requests( "must be the same.") if isinstance(params, list): - if len(params) != num_requests: - raise ValueError("The lengths of prompts and params " - "must be the same.") - params = [ self._add_guided_processor(param, guided_options) for param in params if isinstance(param, SamplingParams) diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index 63d699ca21add..b48ef24e76b3c 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -22,8 +22,8 @@ def __post_init__(self): self.guided_regex is not None, self.guided_choice is not None, self.guided_grammar is not None, - self.guided_json_object is not None, ]) + self.guided_json_object is not None]) if guide_count > 1: raise ValueError( - "You can only use one kind of guided decoding but multiple is " + "You can only use one kind of guided decoding but multiple are " f"specified: {self.__dict__}") From e3b07e0ee09304944eec5ca2e1f7ead3c5586f99 Mon Sep 17 00:00:00 2001 From: kevinbu233 Date: Wed, 31 Jul 2024 06:54:37 +0000 Subject: [PATCH 03/12] fix naming and typos --- tests/entrypoints/{llm => }/conftest.py | 18 +++++ tests/entrypoints/llm/test_guided_generate.py | 49 +++++-------- 
tests/entrypoints/openai/conftest.py | 69 ------------------- vllm/entrypoints/llm.py | 20 +++--- .../guided_decoding/__init__.py | 4 +- .../guided_decoding/guided_fields.py | 11 ++- .../lm_format_enforcer_decoding.py | 4 +- .../guided_decoding/outlines_decoding.py | 8 +-- 8 files changed, 61 insertions(+), 122 deletions(-) rename tests/entrypoints/{llm => }/conftest.py (84%) delete mode 100644 tests/entrypoints/openai/conftest.py diff --git a/tests/entrypoints/llm/conftest.py b/tests/entrypoints/conftest.py similarity index 84% rename from tests/entrypoints/llm/conftest.py rename to tests/entrypoints/conftest.py index 0837644f26bde..738ebc056deda 100644 --- a/tests/entrypoints/llm/conftest.py +++ b/tests/entrypoints/conftest.py @@ -1,5 +1,23 @@ import pytest +@pytest.fixture +def sample_prompts(): + return [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + +@pytest.fixture +def sample_token_ids(): + return [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], + ] + @pytest.fixture def sample_regex(): diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 79761d577fd2a..977aca0441bb6 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -11,22 +11,8 @@ from ...conftest import cleanup - MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -PROMPTS = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -TOKEN_IDS = [ - [0], - [0, 1], - [0, 2, 1], - [0, 3, 1, 2], -] @pytest.fixture(scope="module") def llm(): @@ -39,18 +25,20 @@ def llm(): del llm cleanup() + @pytest.mark.skip_global_cleanup def test_guided_regex(sample_regex, llm): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, ) - outputs = llm.generate(prompts=[ - f"Give an example IPv4 address with this regex: {sample_regex}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options=dict(guided_regex=sample_regex)) + outputs = llm.generate( + prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) assert outputs is not None for output in outputs: @@ -70,13 +58,14 @@ def test_guided_json_completion(sample_json_schema, llm): temperature=1.0, max_tokens=1000, ) - outputs = llm.generate(prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options=dict(guided_json=sample_json_schema)) + outputs = llm.generate( + prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_json=sample_json_schema)) assert outputs is not None @@ -102,7 +91,7 @@ def test_guided_choice_completion(sample_guided_choice, llm): prompts="The best language for type-safe systems programming is ", sampling_params=sampling_params, use_tqdm=True, - guided_options=dict(guided_choice=sample_guided_choice)) + guided_options_request=dict(guided_choice=sample_guided_choice)) assert outputs is not None for output in outputs: @@ -128,7 +117,7 @@ def test_guided_grammar(sample_sql_statements, llm): "table_1 where it is equals to 1"), sampling_params=sampling_params, use_tqdm=True, - 
guided_options=dict(guided_grammar=sample_sql_statements)) + guided_options_request=dict(guided_grammar=sample_sql_statements)) assert outputs is not None for output in outputs: @@ -150,4 +139,4 @@ def test_guided_grammar(sample_sql_statements, llm): assert generated_text.strip() == ground_truth - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/entrypoints/openai/conftest.py b/tests/entrypoints/openai/conftest.py deleted file mode 100644 index 0837644f26bde..0000000000000 --- a/tests/entrypoints/openai/conftest.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest - - -@pytest.fixture -def sample_regex(): - return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" - r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") - - -@pytest.fixture -def sample_json_schema(): - return { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, - "skills": { - "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 - }, - "work_history": { - "type": "array", - "items": { - "type": "object", - "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } - }, - "required": ["company", "position"] - } - } - }, - "required": ["name", "age", "skills", "work_history"] - } - - -@pytest.fixture -def sample_guided_choice(): - return [ - "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", - "Ruby", "Swift", "Kotlin" - ] - - -@pytest.fixture -def sample_sql_statements(): - return (""" -start: select_statement -select_statement: "SELECT" column "from" table "where" condition -column: "col_1" | "col_2" -table: "table_1" | "table_2" -condition: column "=" number -number: "1" | "2" -""") \ No newline at end of file diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c42dc039927b6..5930e81b94c8b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -12,7 +12,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( - GuidedDecodingFields, get_local_guided_decoding_logits_processor) + GuidedDecodingRequest, get_local_guided_decoding_logits_processor) from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -265,7 +265,8 @@ def generate( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - guided_options: Optional[Union[Dict, GuidedDecodingFields]] = None + guided_options_request: Optional[Union[Dict, + GuidedDecodingRequest]] = None ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -307,12 +308,13 @@ def generate( else: inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) - if isinstance(guided_options, dict): - if len(guided_options) > 1: + if isinstance(guided_options_request, dict): + if len(guided_options_request) > 1: raise ValueError( "You can only use one guided decoding but multiple is " - f"specified: {guided_options}") - guided_options = GuidedDecodingFields(**guided_options) + f"specified: {guided_options_request}") + guided_options_request = GuidedDecodingRequest( + **guided_options_request) if sampling_params is None: # Use default sampling params. 
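
A minimal sketch of the resulting offline call after this rename. The model
name and sampling values are lifted from the tests in this series; treat the
prompt and regex as illustrative:

    from vllm import LLM, SamplingParams

    llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", max_model_len=1024)
    outputs = llm.generate(
        prompts="Give an example IPv4 address: ",
        sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
        guided_options_request=dict(guided_regex=r"\d+\.\d+\.\d+\.\d+"))
    print(outputs[0].outputs[0].text)

Passing a dict with more than one guided-decoding key raises ValueError
before the dict is converted to a GuidedDecodingRequest, as the hunk above
shows.
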
@@ -323,7 +325,7 @@ def generate( params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - guided_options=guided_options) + guided_options=guided_options_request) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -520,7 +522,7 @@ def _validate_and_add_requests( Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], - guided_options: Optional[GuidedDecodingFields] = None, + guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -572,7 +574,7 @@ def _add_request( def _add_guided_processor( self, params: SamplingParams, - guided_options: Optional[GuidedDecodingFields] = None): + guided_options: Optional[GuidedDecodingRequest] = None): if guided_options: if guided_options.guided_decoding_backend is None: decoding_config = self.llm_engine.get_decoding_config() diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 4c846bcfddc7c..4a2476dd6314d 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -4,7 +4,7 @@ ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, CompletionRequest) from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingFields) + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) @@ -32,7 +32,7 @@ async def get_guided_decoding_logits_processor( def get_local_guided_decoding_logits_processor( - guided_decoding_backend: str, guided_options: GuidedDecodingFields, + guided_decoding_backend: str, guided_options: GuidedDecodingRequest, tokenizer) -> Optional[LogitsProcessor]: # request = _adapt_request_for_tool_use(request) diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index b48ef24e76b3c..a484955023804 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -5,7 +5,7 @@ @dataclass -class GuidedDecodingFields: +class GuidedDecodingRequest: """One of the fields will be used to retrieve the logit processor.""" guided_json: Optional[Union[Dict, BaseModel, str]] = None guided_regex: Optional[str] = None @@ -18,11 +18,10 @@ class GuidedDecodingFields: def __post_init__(self): """Validate that some fields are mutually exclusive.""" guide_count = sum([ - self.guided_json is not None, - self.guided_regex is not None, - self.guided_choice is not None, - self.guided_grammar is not None, - self.guided_json_object is not None]) + self.guided_json is not None, self.guided_regex is not None, + self.guided_choice is not None, self.guided_grammar is not None, + self.guided_json_object is not None + ]) if guide_count > 1: raise ValueError( "You can only use one kind of guided decoding but multiple are " diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index 40906354f62c2..b2188c9cbc2bb 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol 
import (ChatCompletionRequest, CompletionRequest) from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingFields) + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) @@ -58,7 +58,7 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( def get_local_lm_format_enforcer_guided_decoding_logits_processor( - guided_options: GuidedDecodingFields, + guided_options: GuidedDecodingRequest, tokenizer) -> Optional[LogitsProcessor]: """ Given an OpenAI-compatible request, check for guided decoding parameters diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 487f205dae67f..bc62224dabecf 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingFields) + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) @@ -80,7 +80,7 @@ async def get_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor( - guided_options: GuidedDecodingFields, tokenizer: PreTrainedTokenizerBase + guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -99,7 +99,7 @@ def get_local_outlines_guided_decoding_logits_processor( def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest, - GuidedDecodingFields] + GuidedDecodingRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: if request.guided_json: @@ -123,7 +123,7 @@ def _get_guide_and_mode( return choices_regex, GuidedDecodingMode.CHOICE elif request.guided_grammar: return request.guided_grammar, GuidedDecodingMode.GRAMMAR - elif (not isinstance(request, GuidedDecodingFields) + elif (not isinstance(request, GuidedDecodingRequest) and request.response_format is not None and request.response_format.type == "json_object"): return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR From dad6ffb33c7becd398402b2016e1424fe10b7880 Mon Sep 17 00:00:00 2001 From: kevinbu233 Date: Wed, 31 Jul 2024 16:43:50 +0000 Subject: [PATCH 04/12] added TypedDict --- vllm/entrypoints/llm.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5930e81b94c8b..0560687732f98 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,7 +1,8 @@ from contextlib import contextmanager -from typing import (ClassVar, Dict, List, Optional, Sequence, Union, cast, - overload) +from typing import (ClassVar, Dict, List, Optional, Sequence, TypedDict, Union, + cast, overload) +from pydantic import BaseModel from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -24,6 +25,16 @@ logger = init_logger(__name__) +class LLMGuidedOptions(TypedDict, total=False): + guided_json: Union[Dict, BaseModel, str] + guided_regex: str + guided_choice: List[str] + guided_grammar: str + guided_decoding_backend: str + guided_whitespace_pattern: str + guided_json_object: bool + + class LLM: """An LLM for 
generating texts from given prompts and sampling parameters. @@ -265,7 +276,7 @@ def generate( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - guided_options_request: Optional[Union[Dict, + guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None ) -> List[RequestOutput]: """Generates the completions for the input prompts. From 14d3a4107d26aa27f1fe54be8367c55332bbb910 Mon Sep 17 00:00:00 2001 From: kevinbu233 Date: Thu, 1 Aug 2024 04:08:46 +0000 Subject: [PATCH 05/12] fixed typed dict location --- vllm/entrypoints/llm.py | 15 ++------------- .../guided_decoding/guided_fields.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0560687732f98..25a555401c78f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,8 +1,6 @@ from contextlib import contextmanager -from typing import (ClassVar, Dict, List, Optional, Sequence, TypedDict, Union, - cast, overload) +from typing import ClassVar, List, Optional, Sequence, Union, cast, overload -from pydantic import BaseModel from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -14,6 +12,7 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( GuidedDecodingRequest, get_local_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -25,16 +24,6 @@ logger = init_logger(__name__) -class LLMGuidedOptions(TypedDict, total=False): - guided_json: Union[Dict, BaseModel, str] - guided_regex: str - guided_choice: List[str] - guided_grammar: str - guided_decoding_backend: str - guided_whitespace_pattern: str - guided_json_object: bool - - class LLM: """An LLM for generating texts from given prompts and sampling parameters. 
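
The relocation below keeps LLMGuidedOptions next to GuidedDecodingRequest;
with it importable from guided_fields.py, the dict form accepted by
LLM.generate can be type-checked. A small sketch under that assumption, with
the key and value illustrative only:

    from vllm.model_executor.guided_decoding.guided_fields import (
        LLMGuidedOptions)

    # total=False: any subset of the declared keys is a valid value.
    opts: LLMGuidedOptions = {"guided_regex": r"[0-9]+"}
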
diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index a484955023804..3082ac1510ccc 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -1,9 +1,19 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, TypedDict, Union from pydantic import BaseModel +class LLMGuidedOptions(TypedDict, total=False): + guided_json: Union[Dict, BaseModel, str] + guided_regex: str + guided_choice: List[str] + guided_grammar: str + guided_decoding_backend: str + guided_whitespace_pattern: str + guided_json_object: bool + + @dataclass class GuidedDecodingRequest: """One of the fields will be used to retrieve the logit processor.""" From f5f357b311e4eeb265868e5f8ceed546cb27af76 Mon Sep 17 00:00:00 2001 From: kevinbu233 Date: Thu, 1 Aug 2024 22:05:59 +0000 Subject: [PATCH 06/12] fix styles --- tests/entrypoints/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index 738ebc056deda..dbabbd3ccd550 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -1,5 +1,6 @@ import pytest + @pytest.fixture def sample_prompts(): return [ From 7e0a8463bcc586a11ca10af89a5a3ecbdbe2b167 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 3 Aug 2024 13:56:46 +0800 Subject: [PATCH 07/12] yapf --- tests/entrypoints/conftest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index dbabbd3ccd550..61fa5328f2b81 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -13,10 +13,10 @@ def sample_prompts(): @pytest.fixture def sample_token_ids(): return [ - [0], - [0, 1], - [0, 2, 1], - [0, 3, 1, 2], + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], ] @@ -85,4 +85,4 @@ def sample_sql_statements(): table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") \ No newline at end of file +""") From 5f6256843305d7823d5041d3d34aba7d5f0c2b1b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 3 Aug 2024 13:57:11 +0800 Subject: [PATCH 08/12] yapf --- tests/entrypoints/conftest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index 61fa5328f2b81..84c41bc46da9e 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -13,10 +13,10 @@ def sample_prompts(): @pytest.fixture def sample_token_ids(): return [ - [0], - [0, 1], - [0, 2, 1], - [0, 3, 1, 2], + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], ] From 6e61c5fae0da44fef45e1173f02440589565f1fc Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 3 Aug 2024 14:04:22 +0800 Subject: [PATCH 09/12] yapf --- tests/entrypoints/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index 84c41bc46da9e..e7ef5637c8ccb 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -10,6 +10,7 @@ def sample_prompts(): "The future of AI is", ] + @pytest.fixture def sample_token_ids(): return [ From fdde57253ecfdb2294883d85a727707c7438f87f Mon Sep 17 00:00:00 2001 From: kevinbu233 Date: Sat, 3 Aug 2024 18:41:10 +0000 Subject: [PATCH 10/12] fix param bug --- vllm/entrypoints/llm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 25a555401c78f..262cba79e5712 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -541,7 +541,8 @@ def _validate_and_add_requests( if isinstance(params, list): params = [ self._add_guided_processor(param, guided_options) - for param in params if isinstance(param, SamplingParams) + if isinstance(param, SamplingParams) else param + for param in params ] elif isinstance(params, SamplingParams): params = self._add_guided_processor(params, guided_options) From 1bdeb0eae01cd3be38c7ea17b96523c38a0a4838 Mon Sep 17 00:00:00 2001 From: kevinbu233 Date: Sat, 3 Aug 2024 22:06:59 +0000 Subject: [PATCH 11/12] fixed guided test --- tests/entrypoints/llm/test_guided_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 977aca0441bb6..873e115421257 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -111,6 +111,7 @@ def test_guided_grammar(sample_sql_statements, llm): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, + max_tokens=1000, ) outputs = llm.generate( prompts=("Generate a sql state that select col_1 from " @@ -127,7 +128,6 @@ def test_guided_grammar(sample_sql_statements, llm): generated_text = output.outputs[0].text assert generated_text is not None - # use Lark to parse the output, and make sure it's a valid parse tree from lark import Lark parser = Lark(sample_sql_statements) From ea5b10ad31898b100ec13e42d0bdefcb6ba41114 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 4 Aug 2024 01:20:12 +0000 Subject: [PATCH 12/12] Fix docs error --- docs/source/conf.py | 1 + vllm/entrypoints/openai/protocol.py | 26 ++++++++++++++++++++------ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1093b30bca11d..f1eb8524d4e9c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -111,6 +111,7 @@ def setup(app): "tqdm", "tensorizer", "pynvml", + "outlines", ] for mock_target in autodoc_mock_imports: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3b35ae1ebd705..76318a1271229 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,6 +1,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time +from argparse import Namespace from typing import Any, Dict, List, Literal, Optional, Union import torch @@ -14,6 +15,23 @@ from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid +# torch is mocked during docs generation, +# so we have to provide the values as literals +_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if isinstance(torch, _MockModule): + _LONG_INFO = _MOCK_LONG_INFO + else: + _LONG_INFO = torch.iinfo(torch.long) +except ModuleNotFoundError: + _LONG_INFO = torch.iinfo(torch.long) + +assert _LONG_INFO.min == _MOCK_LONG_INFO.min +assert _LONG_INFO.max == _MOCK_LONG_INFO.max + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields @@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel): n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None - seed: Optional[int] = 
Field(None, - ge=torch.iinfo(torch.long).min, - le=torch.iinfo(torch.long).max) + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None @@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel): max_tokens: Optional[int] = 16 n: int = 1 presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = Field(None, - ge=torch.iinfo(torch.long).min, - le=torch.iinfo(torch.long).max) + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None
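
To tie the series together, a minimal end-to-end sketch of the offline guided
decoding API as it stands after these patches. The model name and the shape of
the call come from the tests above; the schema is a cut-down illustrative
example:

    import json

    from vllm import LLM, SamplingParams

    llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", max_model_len=1024)

    schema = {
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            },
            "age": {
                "type": "integer"
            }
        },
        "required": ["name", "age"]
    }

    outputs = llm.generate(
        prompts=f"Give an example JSON that fits this schema: {schema}",
        sampling_params=SamplingParams(temperature=1.0, max_tokens=1000),
        guided_options_request=dict(guided_json=schema))

    # The generated text should parse as JSON matching the schema.
    print(json.loads(outputs[0].outputs[0].text))
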