From 4fb4d698ea3955db28f60efb2ba00a9d856ab255 Mon Sep 17 00:00:00 2001 From: Seongjin Lee Date: Thu, 17 Oct 2024 02:39:42 +0900 Subject: [PATCH] Enhance Structured Output Interface (#644) --- clients/python/lorax/types.py | 2 +- docs/guides/structured_output.md | 133 ++++++++++++++++++++++++++++--- router/src/lib.rs | 40 +++++++++- router/src/server.rs | 51 ++++++++++-- router/src/validation.rs | 7 +- 5 files changed, 213 insertions(+), 20 deletions(-) diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index 8a5be4fe3..f70984e1c 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -64,7 +64,7 @@ class ResponseFormat(BaseModel): model_config = ConfigDict(use_enum_values=True) type: ResponseFormatType - schema_spec: Union[Dict[str, Any], OrderedDict] = Field(alias="schema") + schema_spec: Optional[Union[Dict[str, Any], OrderedDict]] = Field(None, alias="schema") class Parameters(BaseModel): diff --git a/docs/guides/structured_output.md b/docs/guides/structured_output.md index 26ae5dac0..1071a1283 100644 --- a/docs/guides/structured_output.md +++ b/docs/guides/structured_output.md @@ -35,7 +35,7 @@ valid next tokens using this FSM and sets the likelihood of invalid tokens to `- This example follows the [JSON-structured generation example](https://outlines-dev.github.io/outlines/quickstart/#json-structured-generation) in the Outlines quickstart. We assume that you have already deployed LoRAX using a suitable base model and installed the [LoRAX Python Client](../reference/python_client.md). -Alternatively, see [below](structured_output.md#openai-compatible-api) for an example of structured generation using an +Alternatively, see [below](structured_output.md#example-openai-compatible-api) for an example of structured generation using an OpenAI client. 
```python
@@ -60,14 +60,36 @@ class Character(BaseModel):
 client = Client("http://127.0.0.1:8080")
 
-prompt = "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength. "
-response = client.generate(prompt, response_format={
+# Example 1: Using a schema
+prompt_with_schema = "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength."
+response_with_schema = client.generate(prompt_with_schema, response_format={
     "type": "json_object",
     "schema": Character.model_json_schema(),
 })
 
-my_character = json.loads(response.generated_text)
-print(my_character)
+my_character_with_schema = json.loads(response_with_schema.generated_text)
+print(my_character_with_schema)
+# {
+#     "name": "Thorin",
+#     "age": 45,
+#     "armor": "plate",
+#     "strength": 90
+# }
+
+# Example 2: Without a schema (arbitrary JSON)
+prompt_without_schema = "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength."
+response_without_schema = client.generate(prompt_without_schema, response_format={
+    "type": "json_object",  # No schema provided
+})
+
+my_character_without_schema = json.loads(response_without_schema.generated_text)
+print(my_character_without_schema)
+# {
+#     "characterName": "Aragon",
+#     "age": 38,
+#     "armorType": "chainmail",
+#     "power": 78
+# }
 ```
 
 You can also specify the JSON schema directly rather than using Pydantic:
@@ -99,7 +121,88 @@ Structured generation of JSON following a schema is supported via the `response_
 
 !!! note
 
-    Currently a schema is **required**. This differs from the existing OpenAI JSON mode, in which no schema is supported.
+    Currently, `response_format` in the OpenAI interface differs slightly from the LoRAX request interface.
+    When calling the OpenAI-compatible API, you should format the request exactly as specified in the official documentation. 
+ For more details, refer to the OpenAI documentation here: https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format. + +#### Type 1: `text` (default) + +- This is the standard mode where the model generates plain text output. +- In this example, the model simply returns plain text output. + +```python +from openai import OpenAI + +client = OpenAI( + api_key="EMPTY", + base_url="http://127.0.0.1:8080/v1", +) + +resp = client.chat.completions.create( + model="", # optional: specify an adapter ID here + messages=[ + { + "role": "user", + "content": "Describe a medieval fantasy character.", + }, + ], + max_tokens=100, + response_format={ + "type": "text", # Default response type, plain text output + }, +) + +print(resp.choices[0].message.content) + +''' +Sir Alaric is a noble knight of the realm. At the age of 35, he dons a suit of shining plate armor, protecting his strong, muscular frame. His strength is unparalleled in the kingdom, allowing him to wield his massive greatsword with ease. +''' +``` + +#### Type 2: `json_object` + +- This mode outputs arbitrary JSON objects, making it ideal for generating data in a flexible JSON format without enforcing any schema. It's similar to OpenAI’s JSON mode. +- In this example, the model returns an arbitrary JSON object without enforcing a predefined schema. 
+ +```python +from openai import OpenAI + +client = OpenAI( + api_key="EMPTY", + base_url="http://127.0.0.1:8080/v1", +) + +resp = client.chat.completions.create( + model="", # optional: specify an adapter ID here + messages=[ + { + "role": "user", + "content": "Generate a new character for my game: name, age, armor type, and strength.", + }, + ], + max_tokens=100, + response_format={ + "type": "json_object", # Generate arbitrary JSON without a schema + }, +) + +my_character = json.loads(resp.choices[0].message.content) +print(my_character) + +''' +{ + "name": "Eldrin", + "age": 27, + "armor": "Dragonscale Armor", + "strength": "Fire Resistance" +} +''' +``` + +#### Type 3: `json_schema` + +- The model returns a structured JSON object that adheres to the predefined schema. This ensures that the JSON follows the format of the `Character` model provided earlier. +- In this example, the model generates structured JSON output that adheres to a predefined schema. ```python import json @@ -131,18 +234,30 @@ resp = client.chat.completions.create( messages=[ { "role": "user", - "content": "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength. 
", + "content": "Generate a new character for my game: name, age (between 1 and 99), armor, and strength.", }, ], max_tokens=100, response_format={ - "type": "json_object", - "schema": Character.model_json_schema(), + "type": "json_schema", # Generate structured JSON output based on a schema + "json_schema": { + "name": "Character", # Name of the schema + "schema": Character.model_json_schema(), # The JSON schema generated by Pydantic + }, }, ) my_character = json.loads(resp.choices[0].message.content) print(my_character) + +''' +{ + "name": "Thorin", + "age": 45, + "armor": "plate", + "strength": 90 +} +''' ``` diff --git a/router/src/lib.rs b/router/src/lib.rs index b15612237..f8bdb8f3f 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -478,15 +478,51 @@ struct UsageInfo { #[derive(Clone, Debug, Deserialize, ToSchema)] enum ResponseFormatType { + #[serde(alias = "text")] + Text, #[serde(alias = "json_object")] JsonObject, + #[serde(alias = "json_schema")] + JsonSchema, } #[derive(Clone, Debug, Deserialize, ToSchema)] struct ResponseFormat { #[allow(dead_code)] // For now allow this field even though it is unused r#type: ResponseFormatType, - schema: serde_json::Value, // TODO: make this optional once arbitrary JSON object is supported in Outlines + + #[serde(default = "default_json_schema")] + schema: Option, +} + +// Default schema to be used when no value is provided +fn default_json_schema() -> Option { + Some(serde_json::json!({ + "additionalProperties": { + "type": ["object", "string", "integer", "number", "boolean", "null"] + }, + "title": "ArbitraryJsonModel", + "type": "object" + })) +} + +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +struct JsonSchema { + #[allow(dead_code)] // For now allow this field even though it is unused + description: Option, + #[allow(dead_code)] // For now allow this field even though it is unused + name: String, + schema: Option, + #[allow(dead_code)] // For now allow this field even though it is unused + 
strict: Option<bool>,
+}
+
+// TODO check if json_schema field is required if type is json_schema
+#[derive(Clone, Debug, Deserialize, ToSchema)]
+struct OpenAiResponseFormat {
+    #[serde(rename(deserialize = "type"))]
+    response_format_type: ResponseFormatType,
+    json_schema: Option<JsonSchema>,
 }
 
 #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@@ -582,9 +618,9 @@ struct ChatCompletionRequest {
     #[allow(dead_code)] // For now allow this field even though it is unused
     user: Option<String>,
     seed: Option<u64>,
+    response_format: Option<OpenAiResponseFormat>,
     // Additional parameters
     // TODO(travis): add other LoRAX params here
-    response_format: Option<ResponseFormat>,
     repetition_penalty: Option<f32>,
     top_k: Option<i32>,
     ignore_eos_token: Option<bool>,
diff --git a/router/src/server.rs b/router/src/server.rs
index bb2b7a2d2..7a2480c3d 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -4,18 +4,18 @@ use crate::config::Config;
 use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
-use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
 use crate::{
-    AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence,
+    default_json_schema, AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence,
     ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
     ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, ChatMessage, ClassifyRequest,
     CompatGenerateRequest, CompletionFinishReason, CompletionRequest, CompletionResponse,
     CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, Details,
     EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs, PrefillToken,
-    ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse, Token,
-    TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
+    GenerateRequest, 
GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs,
+    OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken,
+    StreamDetails, StreamResponse, Token, TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
 };
+use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -263,6 +263,43 @@ async fn chat_completions_v1(
         adapter_id = None;
     }
 
+    // Modify input values to ResponseFormat to be OpenAI API compatible
+    let response_format: Option<ResponseFormat> = match req.response_format {
+        None => None,
+        Some(openai_format) => {
+            let response_format_type = openai_format.response_format_type.clone();
+            match response_format_type {
+                // Ignore when type is text
+                ResponseFormatType::Text => None,
+
+                // For json_object, use the fixed schema
+                ResponseFormatType::JsonObject => Some(ResponseFormat {
+                    r#type: response_format_type.clone(),
+                    schema: default_json_schema(),
+                }),
+
+                // For json_schema, use schema_value if available, otherwise fallback to the fixed schema
+                ResponseFormatType::JsonSchema => openai_format
+                    .json_schema
+                    .and_then(|schema| schema.schema)
+                    .map_or_else(
+                        || {
+                            Some(ResponseFormat {
+                                r#type: response_format_type.clone(),
+                                schema: default_json_schema(),
+                            })
+                        },
+                        |schema_value: serde_json::Value| {
+                            Some(ResponseFormat {
+                                r#type: response_format_type.clone(),
+                                schema: Some(schema_value),
+                            })
+                        },
+                    ),
+            }
+        }
+    };
+
     let mut gen_req = CompatGenerateRequest {
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
@@ -288,7 +325,7 @@ async fn chat_completions_v1(
             return_k_alternatives: None,
             apply_chat_template: false,
             seed: req.seed,
-            response_format: req.response_format,
+            response_format: response_format,
         },
         stream: req.stream.unwrap_or(false),
     };
@@ -1115,6 +1152,8 @@ pub async fn run(
             UsageInfo,
             ResponseFormat,
             ResponseFormatType,
+            
OpenAiResponseFormat,
+            JsonSchema,
             CompatGenerateRequest,
             GenerateRequest,
             GenerateParameters,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index be4ed553e..5867cc014 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -331,8 +331,11 @@ impl Validation {
 
         let mut schema: Option<String> = None;
         if response_format.is_some() {
-            let response_format_val = response_format.unwrap();
-            schema = Some(response_format_val.schema.to_string())
+            if let Some(response_format_val) = response_format {
+                if let Some(schema_value) = response_format_val.schema {
+                    schema = Some(schema_value.to_string());
+                }
+            }
         }
 
         let parameters = NextTokenChooserParameters {