fix: OpenAI response format (#184)
tgaddair authored Jan 16, 2024
1 parent 5008c1b commit 76fcc5d
Showing 6 changed files with 127 additions and 13 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -158,7 +158,7 @@ resp = client.chat.completions.create(
     ],
     max_tokens=100,
 )
-print("Response:", resp[0].choices[0].text)
+print("Response:", resp[0].choices[0].message.content)
 ```
 
 See [OpenAI Compatible API](https://predibase.github.io/lorax/guides/openai_api) for details.
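For reference, a complete version of the corrected call might look like the sketch below. The base URL, API key, and adapter id are illustrative assumptions, not part of this diff; the commit only fixes the final line's `message.content` access.

```python
from openai import OpenAI

# Assumed deployment details; adjust host/port to your LoRAX server.
client = OpenAI(
    api_key="EMPTY",  # placeholder; LoRAX deployments often ignore the key (assumption)
    base_url="http://localhost:8080/v1",
)

resp = client.chat.completions.create(
    model="my-adapter",  # illustrative adapter id
    messages=[
        {"role": "user", "content": "What is deep learning?"},
    ],
    max_tokens=100,
)

# Chat responses carry a message object; the old docs read `.text`,
# which belongs to the legacy completions shape.
print("Response:", resp[0].choices[0].message.content)
```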
docs/guides/openai_api.md (2 changes: 1 addition & 1 deletion)

@@ -26,7 +26,7 @@ resp = client.chat.completions.create(
     ],
     max_tokens=100,
 )
-print("Response:", resp[0].choices[0].text)
+print("Response:", resp[0].choices[0].message.content)
 ```
 
 ### Streaming
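The streaming example itself is truncated in this view. Given the `ChatCompletionStreamResponse` schema added below, each SSE chunk carries a full `message` object rather than an OpenAI-style `delta`, so one way to sketch a consumer is with raw HTTP; host, port, and payload fields here are assumptions:

```python
import json

import requests

# Assumed endpoint for a local LoRAX deployment.
url = "http://localhost:8080/v1/chat/completions"
payload = {
    "messages": [{"role": "user", "content": "What is deep learning?"}],
    "max_tokens": 100,
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as r:
    for line in r.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue
        chunk = json.loads(line[len(b"data:"):])
        # Per the new schema, each choice holds a message object
        # (role + content), not a delta.
        print(chunk["choices"][0]["message"]["content"], end="", flush=True)
```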
docs/index.md (2 changes: 1 addition & 1 deletion)

@@ -142,7 +142,7 @@ resp = client.chat.completions.create(
     ],
     max_tokens=100,
 )
-print("Response:", resp[0].choices[0].text)
+print("Response:", resp[0].choices[0].message.content)
 ```
 
 See [OpenAI Compatible API](./guides/openai_api.md) for details.
docs/reference/openapi.json (40 changes: 38 additions & 2 deletions)

@@ -299,7 +299,7 @@
           },
           "text/event-stream": {
             "schema": {
-              "$ref": "#/components/schemas/CompletionStreamResponse"
+              "$ref": "#/components/schemas/ChatCompletionStreamResponse"
             }
           }
         }
@@ -989,17 +989,53 @@
         }
       }
     },
+    "ChatCompletionStreamResponse": {
+      "type": "object",
+      "properties": {
+        "id": {
+          "type": "string"
+        },
+        "object": {
+          "type": "string"
+        },
+        "created": {
+          "type": "integer",
+          "format": "int64"
+        },
+        "model": {
+          "type": "string"
+        },
+        "choices": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/ChatCompletionResponseChoice"
+          }
+        },
+        "usage": {
+          "type": "object",
+          "nullable": true,
+          "$ref": "#/components/schemas/UsageInfo"
+        }
+      }
+    },
     "ChatCompletionResponseChoice": {
       "type": "object",
       "properties": {
         "index": { "type": "integer", "format": "int32" },
-        "message": { "type": "string" },
+        "message": { "type": "object", "$ref": "#/components/schemas/ChatCompletionResponseMessage" },
         "finish_reason": {
           "type": "string",
           "nullable": true
         }
       }
     },
+    "ChatCompletionResponseMessage": {
+      "type": "object",
+      "properties": {
+        "role": { "type": "string" },
+        "content": { "type": "string" }
+      }
+    },
     "CompletionResponse": {
       "type": "object",
       "properties": {
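For orientation, an SSE chunk matching the new schemas would be shaped as follows; the token counts and content are made-up placeholders, while the `"null"`, `"text_completion"`, and `0` defaults mirror what the Rust conversion below hard-codes:

```python
# Illustrative chunk under ChatCompletionStreamResponse (values are placeholders).
chunk = {
    "id": "null",
    "object": "text_completion",
    "created": 0,
    "model": "null",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Deep learning is ..."},
            "finish_reason": None,
        }
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 20, "total_tokens": 25},
}
```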
router/src/lib.rs (74 changes: 74 additions & 0 deletions)

@@ -430,6 +430,16 @@ struct ChatCompletionResponse {
     usage: UsageInfo,
 }
 
+#[derive(Serialize, ToSchema)]
+struct ChatCompletionStreamResponse {
+    id: String,
+    object: String,
+    created: i64,
+    model: String,
+    choices: Vec<ChatCompletionResponseChoice>,
+    usage: Option<UsageInfo>,
+}
+
 impl From<CompletionRequest> for CompatGenerateRequest {
     fn from(req: CompletionRequest) -> Self {
         CompatGenerateRequest {
@@ -556,6 +566,70 @@ impl From<StreamResponse> for CompletionStreamResponse {
     }
 }
 
+impl From<GenerateResponse> for ChatCompletionResponse {
+    fn from(resp: GenerateResponse) -> Self {
+        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
+        let completion_tokens = resp
+            .details
+            .as_ref()
+            .map(|x| x.generated_tokens)
+            .unwrap_or(0);
+        let total_tokens = prompt_tokens + completion_tokens;
+
+        ChatCompletionResponse {
+            id: "null".to_string(),
+            object: "text_completion".to_string(),
+            created: 0,
+            model: "null".to_string(),
+            choices: vec![ChatCompletionResponseChoice {
+                index: 0,
+                message: ChatMessage {
+                    role: "assistant".to_string(),
+                    content: resp.generated_text,
+                },
+                finish_reason: None,
+            }],
+            usage: UsageInfo {
+                prompt_tokens: prompt_tokens,
+                total_tokens: total_tokens,
+                completion_tokens: Some(completion_tokens),
+            },
+        }
+    }
+}
+
+impl From<StreamResponse> for ChatCompletionStreamResponse {
+    fn from(resp: StreamResponse) -> Self {
+        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
+        let completion_tokens = resp
+            .details
+            .as_ref()
+            .map(|x| x.generated_tokens)
+            .unwrap_or(0);
+        let total_tokens = prompt_tokens + completion_tokens;
+
+        ChatCompletionStreamResponse {
+            id: "null".to_string(),
+            object: "text_completion".to_string(),
+            created: 0,
+            model: "null".to_string(),
+            choices: vec![ChatCompletionResponseChoice {
+                index: 0,
+                message: ChatMessage {
+                    role: "assistant".to_string(),
+                    content: resp.generated_text.unwrap_or_default(),
+                },
+                finish_reason: None,
+            }],
+            usage: Some(UsageInfo {
+                prompt_tokens: prompt_tokens,
+                total_tokens: total_tokens,
+                completion_tokens: Some(completion_tokens),
+            }),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::io::Write;
router/src/server.rs (20 changes: 12 additions & 8 deletions)

@@ -3,10 +3,10 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, ChatCompletionRequest, CompatGenerateRequest, CompletionRequest,
-    CompletionResponse, CompletionStreamResponse, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
-    StreamDetails, StreamResponse, Token, Validation,
+    BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
+    CompatGenerateRequest, CompletionRequest, CompletionResponse, CompletionStreamResponse,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -148,7 +148,7 @@ responses(
 (status = 200, description = "Generated Text",
 content(
 ("application/json" = ChatCompletionResponse),
-("text/event-stream" = CompletionStreamResponse),
+("text/event-stream" = ChatCompletionStreamResponse),
 )),
 (status = 424, description = "Generation Error", body = ErrorResponse,
 example = json ! ({"error": "Request failed during generation"})),
@@ -178,10 +178,10 @@ async fn chat_completions_v1(
     if gen_req.stream {
         let callback = move |resp: StreamResponse| {
             Event::default()
-                .json_data(CompletionStreamResponse::from(resp))
+                .json_data(ChatCompletionStreamResponse::from(resp))
                 .map_or_else(
                     |err| {
-                        tracing::error!("Failed to serialize CompletionStreamResponse: {err}");
+                        tracing::error!("Failed to serialize ChatCompletionStreamResponse: {err}");
                         Event::default()
                     },
                     |data| data,
@@ -194,7 +194,11 @@
     } else {
         let (headers, generation) = generate(infer, Json(gen_req.into())).await?;
         // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![CompletionResponse::from(generation.0)])).into_response())
+        Ok((
+            headers,
+            Json(vec![ChatCompletionResponse::from(generation.0)]),
+        )
+        .into_response())
     }
 }
 
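Because the non-streaming handler wraps the response in a `Vec` to match api-inference, the JSON body is a one-element array, which is why the docs examples index with `resp[0]`. A raw-HTTP sketch, with host, port, and payload assumed:

```python
import requests

# Assumed endpoint for a local LoRAX deployment.
url = "http://localhost:8080/v1/chat/completions"
payload = {
    "messages": [{"role": "user", "content": "What is deep learning?"}],
    "max_tokens": 100,
}

body = requests.post(url, json=payload).json()
# Top level is a single-element array: Json(vec![ChatCompletionResponse ...]).
print(body[0]["choices"][0]["message"]["content"])
```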
