From 76fcc5d2f6beca7cbff6a1693b15f6ad799a1bb9 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Tue, 16 Jan 2024 08:48:09 -0800
Subject: [PATCH] fix: OpenAI response format (#184)

---
 README.md                   |  2 +-
 docs/guides/openai_api.md   |  2 +-
 docs/index.md               |  2 +-
 docs/reference/openapi.json | 40 +++++++++++++++++++-
 router/src/lib.rs           | 74 +++++++++++++++++++++++++++++++++++++
 router/src/server.rs        | 20 ++++++----
 6 files changed, 127 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index 0fe7da07c..5f4e8eee6 100644
--- a/README.md
+++ b/README.md
@@ -158,7 +158,7 @@ resp = client.chat.completions.create(
     ],
     max_tokens=100,
 )
-print("Response:", resp[0].choices[0].text)
+print("Response:", resp[0].choices[0].message.content)
 ```
 
 See [OpenAI Compatible API](https://predibase.github.io/lorax/guides/openai_api) for details.

diff --git a/docs/guides/openai_api.md b/docs/guides/openai_api.md
index 1cb9da054..efc20d320 100644
--- a/docs/guides/openai_api.md
+++ b/docs/guides/openai_api.md
@@ -26,7 +26,7 @@ resp = client.chat.completions.create(
     ],
     max_tokens=100,
 )
-print("Response:", resp[0].choices[0].text)
+print("Response:", resp[0].choices[0].message.content)
 ```
 
 ### Streaming

diff --git a/docs/index.md b/docs/index.md
index 2f504ae7a..7212f3823 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -142,7 +142,7 @@ resp = client.chat.completions.create(
     ],
     max_tokens=100,
 )
-print("Response:", resp[0].choices[0].text)
+print("Response:", resp[0].choices[0].message.content)
 ```
 
 See [OpenAI Compatible API](./guides/openai_api.md) for details.

diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json
index 10ba44b05..52bd15b10 100644
--- a/docs/reference/openapi.json
+++ b/docs/reference/openapi.json
@@ -299,7 +299,7 @@
             },
             "text/event-stream": {
               "schema": {
-                "$ref": "#/components/schemas/CompletionStreamResponse"
+                "$ref": "#/components/schemas/ChatCompletionStreamResponse"
               }
             }
           }
@@ -989,17 +989,53 @@
           }
         }
       },
+      "ChatCompletionStreamResponse": {
+        "type": "object",
+        "properties": {
+          "id": {
+            "type": "string"
+          },
+          "object": {
+            "type": "string"
+          },
+          "created": {
+            "type": "integer",
+            "format": "int64"
+          },
+          "model": {
+            "type": "string"
+          },
+          "choices": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/ChatCompletionResponseChoice"
+            }
+          },
+          "usage": {
+            "type": "object",
+            "nullable": true,
+            "$ref": "#/components/schemas/UsageInfo"
+          }
+        }
+      },
       "ChatCompletionResponseChoice": {
         "type": "object",
         "properties": {
           "index": { "type": "integer", "format": "int32" },
-          "message": { "type": "string" },
+          "message": { "type": "object",
+            "$ref": "#/components/schemas/ChatCompletionResponseMessage" },
           "finish_reason": { "type": "string", "nullable": true }
         }
       },
+      "ChatCompletionResponseMessage": {
+        "type": "object",
+        "properties": {
+          "role": { "type": "string" },
+          "content": { "type": "string" }
+        }
+      },
       "CompletionResponse": {
         "type": "object",
         "properties": {

diff --git a/router/src/lib.rs b/router/src/lib.rs
index e374406a5..b66583e18 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -430,6 +430,16 @@ struct ChatCompletionResponse {
     usage: UsageInfo,
 }
 
+#[derive(Serialize, ToSchema)]
+struct ChatCompletionStreamResponse {
+    id: String,
+    object: String,
+    created: i64,
+    model: String,
+    choices: Vec<ChatCompletionResponseChoice>,
+    usage: Option<UsageInfo>,
+}
+
 impl From<CompletionRequest> for CompatGenerateRequest {
     fn from(req: CompletionRequest) -> Self {
         CompatGenerateRequest {
@@ -556,6 +566,70 @@ impl From<StreamResponse> for CompletionStreamResponse {
     }
 }
 
+impl From<GenerateResponse> for ChatCompletionResponse {
+    fn from(resp: GenerateResponse) -> Self {
+        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
+        let completion_tokens = resp
+            .details
+            .as_ref()
+            .map(|x| x.generated_tokens)
+            .unwrap_or(0);
+        let total_tokens = prompt_tokens + completion_tokens;
+
+        ChatCompletionResponse {
+            id: "null".to_string(),
+            object: "text_completion".to_string(),
+            created: 0,
+            model: "null".to_string(),
+            choices: vec![ChatCompletionResponseChoice {
+                index: 0,
+                message: ChatMessage {
+                    role: "assistant".to_string(),
+                    content: resp.generated_text,
+                },
+                finish_reason: None,
+            }],
+            usage: UsageInfo {
+                prompt_tokens: prompt_tokens,
+                total_tokens: total_tokens,
+                completion_tokens: Some(completion_tokens),
+            },
+        }
+    }
+}
+
+impl From<StreamResponse> for ChatCompletionStreamResponse {
+    fn from(resp: StreamResponse) -> Self {
+        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
+        let completion_tokens = resp
+            .details
+            .as_ref()
+            .map(|x| x.generated_tokens)
+            .unwrap_or(0);
+        let total_tokens = prompt_tokens + completion_tokens;
+
+        ChatCompletionStreamResponse {
+            id: "null".to_string(),
+            object: "text_completion".to_string(),
+            created: 0,
+            model: "null".to_string(),
+            choices: vec![ChatCompletionResponseChoice {
+                index: 0,
+                message: ChatMessage {
+                    role: "assistant".to_string(),
+                    content: resp.generated_text.unwrap_or_default(),
+                },
+                finish_reason: None,
+            }],
+            usage: Some(UsageInfo {
+                prompt_tokens: prompt_tokens,
+                total_tokens: total_tokens,
+                completion_tokens: Some(completion_tokens),
+            }),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::io::Write;

diff --git a/router/src/server.rs b/router/src/server.rs
index fd153e899..57e656778 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -3,10 +3,10 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, ChatCompletionRequest, CompatGenerateRequest, CompletionRequest,
-    CompletionResponse, CompletionStreamResponse, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
-    StreamDetails, StreamResponse, Token, Validation,
+    BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
+    CompatGenerateRequest, CompletionRequest, CompletionResponse, CompletionStreamResponse,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -148,7 +148,7 @@
 responses(
 (status = 200, description = "Generated Text", content(
 ("application/json" = ChatCompletionResponse),
-("text/event-stream" = CompletionStreamResponse),
+("text/event-stream" = ChatCompletionStreamResponse),
 )),
 (status = 424, description = "Generation Error", body = ErrorResponse, example = json !
({"error": "Request failed during generation"})), @@ -178,10 +178,10 @@ async fn chat_completions_v1( if gen_req.stream { let callback = move |resp: StreamResponse| { Event::default() - .json_data(CompletionStreamResponse::from(resp)) + .json_data(ChatCompletionStreamResponse::from(resp)) .map_or_else( |err| { - tracing::error!("Failed to serialize CompletionStreamResponse: {err}"); + tracing::error!("Failed to serialize ChatCompletionStreamResponse: {err}"); Event::default() }, |data| data, @@ -194,7 +194,11 @@ async fn chat_completions_v1( } else { let (headers, generation) = generate(infer, Json(gen_req.into())).await?; // wrap generation inside a Vec to match api-inference - Ok((headers, Json(vec![CompletionResponse::from(generation.0)])).into_response()) + Ok(( + headers, + Json(vec![ChatCompletionResponse::from(generation.0)]), + ) + .into_response()) } }