Enhance Structured Output Interface (#644)
GirinMan authored Oct 16, 2024
1 parent 8ac729b commit 4fb4d69
Showing 5 changed files with 213 additions and 20 deletions.
2 changes: 1 addition & 1 deletion clients/python/lorax/types.py
@@ -64,7 +64,7 @@ class ResponseFormat(BaseModel):
model_config = ConfigDict(use_enum_values=True)

type: ResponseFormatType
-    schema_spec: Union[Dict[str, Any], OrderedDict] = Field(alias="schema")
+    schema_spec: Optional[Union[Dict[str, Any], OrderedDict]] = Field(None, alias="schema")


class Parameters(BaseModel):
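The practical effect of this client change: `schema` is no longer required, so a plain `json_object` request can omit it. A minimal sketch of the new behavior (assuming a LoRAX server running at `http://127.0.0.1:8080`):

```python
# Minimal sketch of the client-side effect of making "schema" optional.
# Assumes a LoRAX server running at http://127.0.0.1:8080.
from lorax import Client

client = Client("http://127.0.0.1:8080")

# Before this change, omitting "schema" failed client-side validation;
# now the server falls back to unconstrained (arbitrary) JSON generation.
response = client.generate(
    "Return a JSON object describing a game character.",
    response_format={"type": "json_object"},
)
print(response.generated_text)
```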
133 changes: 124 additions & 9 deletions docs/guides/structured_output.md
@@ -35,7 +35,7 @@ valid next tokens using this FSM and sets the likelihood of invalid tokens to `-
This example follows the [JSON-structured generation example](https://outlines-dev.github.io/outlines/quickstart/#json-structured-generation) in the Outlines quickstart.

We assume that you have already deployed LoRAX using a suitable base model and installed the [LoRAX Python Client](../reference/python_client.md).
-Alternatively, see [below](structured_output.md#openai-compatible-api) for an example of structured generation using an
+Alternatively, see [below](structured_output.md#example-openai-compatible-api) for an example of structured generation using an
OpenAI client.

```python
@@ -60,14 +60,36 @@ class Character(BaseModel):

client = Client("http://127.0.0.1:8080")

-prompt = "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength. "
-response = client.generate(prompt, response_format={
+# Example 1: Using a schema
+prompt_with_schema = "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength."
+response_with_schema = client.generate(prompt_with_schema, response_format={
    "type": "json_object",
    "schema": Character.model_json_schema(),
})

-my_character = json.loads(response.generated_text)
-print(my_character)
+my_character_with_schema = json.loads(response_with_schema.generated_text)
+print(my_character_with_schema)
+# {
+#     "name": "Thorin",
+#     "age": 45,
+#     "armor": "plate",
+#     "strength": 90
+# }
+
+# Example 2: Without a schema (arbitrary JSON)
+prompt_without_schema = "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength."
+response_without_schema = client.generate(prompt_without_schema, response_format={
+    "type": "json_object",  # No schema provided
+})
+
+my_character_without_schema = json.loads(response_without_schema.generated_text)
+print(my_character_without_schema)
+# {
+#     "characterName": "Aragon",
+#     "age": 38,
+#     "armorType": "chainmail",
+#     "power": 78
+# }
```

You can also specify the JSON schema directly rather than using Pydantic:
@@ -99,7 +121,88 @@ Structured generation of JSON following a schema is supported via the `response_

!!! note

-    Currently a schema is **required**. This differs from the existing OpenAI JSON mode, in which no schema is supported.
+    Currently, `response_format` in the OpenAI interface differs slightly from the LoRAX request interface.
+    When calling the OpenAI-compatible API, format the request exactly as specified in the official documentation.
+    For details, see the OpenAI documentation: https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format.

#### Type 1: `text` (default)

- This is the standard mode: the model generates plain, unconstrained text output, as in the example below.

```python
from openai import OpenAI

client = OpenAI(
    api_key="EMPTY",
    base_url="http://127.0.0.1:8080/v1",
)

resp = client.chat.completions.create(
    model="",  # optional: specify an adapter ID here
    messages=[
        {
            "role": "user",
            "content": "Describe a medieval fantasy character.",
        },
    ],
    max_tokens=100,
    response_format={
        "type": "text",  # Default response type, plain text output
    },
)

print(resp.choices[0].message.content)

'''
Sir Alaric is a noble knight of the realm. At the age of 35, he dons a suit of shining plate armor, protecting his strong, muscular frame. His strength is unparalleled in the kingdom, allowing him to wield his massive greatsword with ease.
'''
```

#### Type 2: `json_object`

- This mode outputs arbitrary JSON objects without enforcing any schema, making it ideal for flexible JSON generation. It's similar to OpenAI's JSON mode; in the example below, the model returns a free-form JSON object.

```python
import json

from openai import OpenAI

client = OpenAI(
    api_key="EMPTY",
    base_url="http://127.0.0.1:8080/v1",
)

resp = client.chat.completions.create(
    model="",  # optional: specify an adapter ID here
    messages=[
        {
            "role": "user",
            "content": "Generate a new character for my game: name, age, armor type, and strength.",
        },
    ],
    max_tokens=100,
    response_format={
        "type": "json_object",  # Generate arbitrary JSON without a schema
    },
)

my_character = json.loads(resp.choices[0].message.content)
print(my_character)

'''
{
"name": "Eldrin",
"age": 27,
"armor": "Dragonscale Armor",
"strength": "Fire Resistance"
}
'''
```

#### Type 3: `json_schema`

- The model returns a structured JSON object that adheres to the predefined schema, ensuring the output follows the format of the `Character` model defined earlier. The example below demonstrates this.

```python
import json
@@ -131,18 +234,30 @@ resp = client.chat.completions.create(
    messages=[
        {
            "role": "user",
-            "content": "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength. ",
+            "content": "Generate a new character for my game: name, age (between 1 and 99), armor, and strength.",
        },
    ],
    max_tokens=100,
    response_format={
-        "type": "json_object",
-        "schema": Character.model_json_schema(),
+        "type": "json_schema",  # Generate structured JSON output based on a schema
+        "json_schema": {
+            "name": "Character",  # Name of the schema
+            "schema": Character.model_json_schema(),  # The JSON schema generated by Pydantic
+        },
    },
)

my_character = json.loads(resp.choices[0].message.content)
print(my_character)

'''
{
"name": "Thorin",
"age": 45,
"armor": "plate",
"strength": 90
}
'''
```


40 changes: 38 additions & 2 deletions router/src/lib.rs
@@ -478,15 +478,51 @@ struct UsageInfo {

#[derive(Clone, Debug, Deserialize, ToSchema)]
enum ResponseFormatType {
    #[serde(alias = "text")]
    Text,
    #[serde(alias = "json_object")]
    JsonObject,
    #[serde(alias = "json_schema")]
    JsonSchema,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct ResponseFormat {
    #[allow(dead_code)] // For now allow this field even though it is unused
    r#type: ResponseFormatType,
-    schema: serde_json::Value, // TODO: make this optional once arbitrary JSON object is supported in Outlines
+    #[serde(default = "default_json_schema")]
+    schema: Option<serde_json::Value>,
}

// Default schema to be used when no value is provided
fn default_json_schema() -> Option<serde_json::Value> {
    Some(serde_json::json!({
        "additionalProperties": {
            "type": ["object", "string", "integer", "number", "boolean", "null"]
        },
        "title": "ArbitraryJsonModel",
        "type": "object"
    }))
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
struct JsonSchema {
    #[allow(dead_code)] // For now allow this field even though it is unused
    description: Option<String>,
    #[allow(dead_code)] // For now allow this field even though it is unused
    name: String,
    schema: Option<serde_json::Value>,
    #[allow(dead_code)] // For now allow this field even though it is unused
    strict: Option<bool>,
}

// TODO check if json_schema field is required if type is json_schema
#[derive(Clone, Debug, Deserialize, ToSchema)]
struct OpenAiResponseFormat {
    #[serde(rename(deserialize = "type"))]
    response_format_type: ResponseFormatType,
    json_schema: Option<JsonSchema>,
}

#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@@ -582,9 +618,9 @@ struct ChatCompletionRequest {
    #[allow(dead_code)] // For now allow this field even though it is unused
    user: Option<String>,
    seed: Option<u64>,
+    response_format: Option<OpenAiResponseFormat>,
    // Additional parameters
    // TODO(travis): add other LoRAX params here
-    response_format: Option<ResponseFormat>,
    repetition_penalty: Option<f32>,
    top_k: Option<i32>,
    ignore_eos_token: Option<bool>,
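For intuition, the fallback schema introduced in `default_json_schema` above accepts any top-level JSON object whose property values are plain JSON types (note that `array` is absent from the permitted list). A quick check of that behavior, sketched with the third-party `jsonschema` package:

```python
# Sketch of what the router's fallback schema accepts; the literal below
# mirrors default_json_schema() above. Assumes the third-party `jsonschema`
# package is installed (pip install jsonschema).
import jsonschema

FALLBACK_SCHEMA = {
    "additionalProperties": {
        "type": ["object", "string", "integer", "number", "boolean", "null"]
    },
    "title": "ArbitraryJsonModel",
    "type": "object",
}

# A flat object with plain JSON values validates.
jsonschema.validate({"name": "Eldrin", "age": 27, "strength": 78.5}, FALLBACK_SCHEMA)

# A non-object at the top level is rejected.
try:
    jsonschema.validate(["not", "an", "object"], FALLBACK_SCHEMA)
except jsonschema.ValidationError as err:
    print(err.message)  # ['not', 'an', 'object'] is not of type 'object'
```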
51 changes: 45 additions & 6 deletions router/src/server.rs
@@ -4,18 +4,18 @@ use crate::config::Config;
use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
-use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
use crate::{
-    AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence,
+    default_json_schema, AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence,
    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
    ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, ChatMessage, ClassifyRequest,
    CompatGenerateRequest, CompletionFinishReason, CompletionRequest, CompletionResponse,
    CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, Details,
    EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs, PrefillToken,
-    ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse, Token,
-    TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
+    GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs,
+    OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken,
+    StreamDetails, StreamResponse, Token, TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
};
+use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
@@ -263,6 +263,43 @@ async fn chat_completions_v1(
adapter_id = None;
}

    // Convert the OpenAI-style response_format into LoRAX's internal ResponseFormat
    let response_format: Option<ResponseFormat> = match req.response_format {
        None => None,
        Some(openai_format) => {
            let response_format_type = openai_format.response_format_type.clone();
            match response_format_type {
                // Ignore when type is text
                ResponseFormatType::Text => None,

                // For json_object, use the fixed schema
                ResponseFormatType::JsonObject => Some(ResponseFormat {
                    r#type: response_format_type.clone(),
                    schema: default_json_schema(),
                }),

                // For json_schema, use schema_value if available, otherwise fall back to the fixed schema
                ResponseFormatType::JsonSchema => openai_format
                    .json_schema
                    .and_then(|schema| schema.schema)
                    .map_or_else(
                        || {
                            Some(ResponseFormat {
                                r#type: response_format_type.clone(),
                                schema: default_json_schema(),
                            })
                        },
                        |schema_value: serde_json::Value| {
                            Some(ResponseFormat {
                                r#type: response_format_type.clone(),
                                schema: Some(schema_value),
                            })
                        },
                    ),
            }
        }
    };

    let mut gen_req = CompatGenerateRequest {
        inputs: inputs.to_string(),
        parameters: GenerateParameters {
@@ -288,7 +325,7 @@
            return_k_alternatives: None,
            apply_chat_template: false,
            seed: req.seed,
-            response_format: req.response_format,
+            response_format: response_format,
        },
        stream: req.stream.unwrap_or(false),
    };
@@ -1115,6 +1152,8 @@ pub async fn run(
UsageInfo,
ResponseFormat,
ResponseFormatType,
OpenAiResponseFormat,
JsonSchema,
CompatGenerateRequest,
GenerateRequest,
GenerateParameters,
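To make the `response_format` conversion in `chat_completions_v1` above easier to scan, here is a rough Python rendering of the same three-way branch. It is illustrative only; `to_lorax_response_format` is a hypothetical name, and `DEFAULT_JSON_SCHEMA` stands in for `default_json_schema()`:

```python
# Illustrative mirror of the Rust conversion above; not part of the codebase.
from typing import Any, Dict, Optional

# Stand-in for default_json_schema() in router/src/lib.rs.
DEFAULT_JSON_SCHEMA: Dict[str, Any] = {
    "additionalProperties": {
        "type": ["object", "string", "integer", "number", "boolean", "null"]
    },
    "title": "ArbitraryJsonModel",
    "type": "object",
}

def to_lorax_response_format(openai_format: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    if openai_format is None:
        return None
    fmt_type = openai_format["type"]
    if fmt_type == "text":
        # Plain text: no decoding constraint is applied.
        return None
    if fmt_type == "json_object":
        # Arbitrary JSON: constrain decoding with the permissive fallback schema.
        return {"type": fmt_type, "schema": DEFAULT_JSON_SCHEMA}
    if fmt_type == "json_schema":
        # Use the caller's schema when present, otherwise the fallback.
        schema = (openai_format.get("json_schema") or {}).get("schema")
        return {"type": fmt_type, "schema": schema or DEFAULT_JSON_SCHEMA}
    raise ValueError(f"unknown response_format type: {fmt_type}")
```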
7 changes: 5 additions & 2 deletions router/src/validation.rs
@@ -331,8 +331,11 @@ impl Validation {

        let mut schema: Option<String> = None;
        if response_format.is_some() {
-            let response_format_val = response_format.unwrap();
-            schema = Some(response_format_val.schema.to_string())
+            if let Some(response_format_val) = response_format {
+                if let Some(schema_value) = response_format_val.schema {
+                    schema = Some(schema_value.to_string());
+                }
+            }
        }

let parameters = NextTokenChooserParameters {
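Downstream of this validation step, the stringified schema is what drives constrained decoding: as the guide above describes, it is compiled into an FSM that masks invalid next tokens. A standalone sketch of the same idea using Outlines, the library the guide references (the model name and API calls follow the Outlines quickstart and are assumptions, not part of this commit):

```python
# Standalone sketch of schema-constrained decoding with Outlines.
# The model name and API surface follow the Outlines quickstart and may
# differ across library versions; treat them as assumptions.
import outlines

model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.2")

schema_str = """{
    "title": "Character",
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"]
}"""

# Outlines compiles the schema into an FSM and masks tokens that would
# break it, so the output is guaranteed to parse as a matching object.
generator = outlines.generate.json(model, schema_str)
character = generator("Generate a new character for my awesome game: ")
print(character)  # e.g. {'name': 'Thorin', 'age': 45}
```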
