From 8906253f87111a651e37797939b61d2bdfe591b8 Mon Sep 17 00:00:00 2001
From: Travis Addair <tgaddair@gmail.com>
Date: Thu, 18 Jan 2024 22:11:38 -0800
Subject: [PATCH] Fixed OpenAI stream response data (#193)

---
 docs/reference/openapi.json | 18 ++++++---
 router/src/lib.rs           | 78 ++++++++++++++++++++++++----------
 router/src/server.rs        |  2 +-
 3 files changed, 64 insertions(+), 34 deletions(-)

diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json
index 52bd15b10..b0ec58adb 100644
--- a/docs/reference/openapi.json
+++ b/docs/reference/openapi.json
@@ -1008,13 +1008,19 @@
         "choices": {
           "type": "array",
           "items": {
-            "$ref": "#/components/schemas/ChatCompletionResponseChoice"
+            "$ref": "#/components/schemas/ChatCompletionStreamResponseChoice"
           }
-        },
-        "usage": {
-          "type": "object",
-          "nullable": true,
-          "$ref": "#/components/schemas/UsageInfo"
+        }
+      }
+    },
+    "ChatCompletionStreamResponseChoice": {
+      "type": "object",
+      "properties": {
+        "index": { "type": "integer", "format": "int32" },
+        "delta": { "type": "object", "$ref": "#/components/schemas/ChatCompletionResponseMessage" },
+        "finish_reason": {
+          "type": "string",
+          "nullable": true
         }
       }
     },
diff --git a/router/src/lib.rs b/router/src/lib.rs
index ec75aa66b..7aefd84e0 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -380,7 +380,7 @@ struct CompletionResponseChoice {
     index: i32,
     text: String,
     logprobs: Option<LogProbs>,
-    finish_reason: Option<String>,
+    finish_reason: Option<CompletionFinishReason>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -398,7 +398,7 @@ struct CompletionResponseStreamChoice {
     index: i32,
     text: String,
     logprobs: Option<LogProbs>,
-    finish_reason: Option<String>,
+    finish_reason: Option<CompletionFinishReason>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -421,7 +421,7 @@ struct ChatMessage {
 struct ChatCompletionResponseChoice {
     index: i32,
     message: ChatMessage,
-    finish_reason: Option<String>,
+    finish_reason: Option<CompletionFinishReason>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -434,14 +434,33 @@ struct ChatCompletionResponse {
     usage: UsageInfo,
 }
 
+#[derive(Serialize, ToSchema)]
+struct ChatCompletionStreamResponseChoice {
+    index: i32,
+    delta: ChatMessage,
+    finish_reason: Option<CompletionFinishReason>,
+}
+
 #[derive(Serialize, ToSchema)]
 struct ChatCompletionStreamResponse {
     id: String,
     object: String,
     created: i64,
     model: String,
-    choices: Vec<ChatCompletionResponseChoice>,
-    usage: Option<UsageInfo>,
+    choices: Vec<ChatCompletionStreamResponseChoice>,
+}
+
+#[derive(Serialize, ToSchema)]
+#[serde(rename_all(serialize = "snake_case"))]
+pub(crate) enum CompletionFinishReason {
+    #[schema(rename = "stop")]
+    Stop,
+    #[schema(rename = "length")]
+    Length,
+    #[schema(rename = "content_filter")]
+    ContentFilter,
+    #[schema(rename = "tool_calls")]
+    ToolCalls,
 }
 
 impl From<CompletionRequest> for CompatGenerateRequest {
@@ -529,7 +548,9 @@ impl From<GenerateResponse> for CompletionResponse {
                 index: 0,
                 text: resp.generated_text,
                 logprobs: None,
-                finish_reason: None,
+                finish_reason: resp
+                    .details
+                    .map(|x| CompletionFinishReason::from(x.finish_reason)),
             }],
             usage: UsageInfo {
                 prompt_tokens: prompt_tokens,
@@ -557,9 +578,11 @@ impl From<StreamResponse> for CompletionStreamResponse {
             model: "null".to_string(),
             choices: vec![CompletionResponseStreamChoice {
                 index: 0,
-                text: resp.generated_text.unwrap_or_default(),
+                text: resp.token.text,
                 logprobs: None,
-                finish_reason: None,
+                finish_reason: resp
+                    .details
+                    .map(|x| CompletionFinishReason::from(x.finish_reason)),
             }],
             usage: Some(UsageInfo {
                 prompt_tokens: prompt_tokens,
@@ -591,7 +614,9 @@ impl From<GenerateResponse> for ChatCompletionResponse {
                     role: "assistant".to_string(),
                     content: resp.generated_text,
                 },
-                finish_reason: None,
+                finish_reason: resp
+                    .details
+                    .map(|x| CompletionFinishReason::from(x.finish_reason)),
             }],
             usage: UsageInfo {
                 prompt_tokens: prompt_tokens,
@@ -604,32 +629,31 @@
 
 impl From<StreamResponse> for ChatCompletionStreamResponse {
     fn from(resp: StreamResponse) -> Self {
-        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
-        let completion_tokens = resp
-            .details
-            .as_ref()
-            .map(|x| x.generated_tokens)
-            .unwrap_or(0);
-        let total_tokens = prompt_tokens + completion_tokens;
-
         ChatCompletionStreamResponse {
             id: "null".to_string(),
-            object: "text_completion".to_string(),
+            object: "chat.completion.chunk".to_string(),
             created: 0,
             model: "null".to_string(),
-            choices: vec![ChatCompletionResponseChoice {
+            choices: vec![ChatCompletionStreamResponseChoice {
                 index: 0,
-                message: ChatMessage {
+                delta: ChatMessage {
                     role: "assistant".to_string(),
-                    content: resp.generated_text.unwrap_or_default(),
+                    content: resp.token.text,
                 },
-                finish_reason: None,
+                finish_reason: resp
+                    .details
+                    .map(|x| CompletionFinishReason::from(x.finish_reason)),
             }],
-            usage: Some(UsageInfo {
-                prompt_tokens: prompt_tokens,
-                total_tokens: total_tokens,
-                completion_tokens: Some(completion_tokens),
-            }),
+        }
+    }
+}
+
+impl From<FinishReason> for CompletionFinishReason {
+    fn from(reason: FinishReason) -> Self {
+        match reason {
+            FinishReason::Length => CompletionFinishReason::Length,
+            FinishReason::EndOfSequenceToken => CompletionFinishReason::Stop,
+            FinishReason::StopSequence => CompletionFinishReason::ContentFilter,
         }
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 57e656778..f79821269 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -553,7 +553,7 @@
                         details: None,
                     };
 
-                    yield Ok(Event::default().json_data(stream_token).unwrap())
+                    yield Ok(callback(stream_token))
                 }
                 // Yield event for last token and compute timings
                 InferStreamResponse::End {