diff --git a/router/src/lib.rs b/router/src/lib.rs index e049db318..d99c17951 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -526,8 +526,8 @@ struct CompletionStreamResponse { #[derive(Serialize, ToSchema)] struct ChatMessage { - role: String, - content: String, + role: Option, + content: Option, } #[derive(Serialize, ToSchema)] @@ -563,7 +563,7 @@ struct ChatCompletionStreamResponse { choices: Vec, } -#[derive(Serialize, ToSchema)] +#[derive(Serialize, ToSchema, PartialEq)] #[serde(rename_all(serialize = "snake_case"))] pub(crate) enum CompletionFinishReason { #[schema(rename = "stop")] @@ -688,6 +688,14 @@ impl From for CompletionStreamResponse { .unwrap_or(0); let total_tokens = prompt_tokens + completion_tokens; + let finish_reason = resp + .details + .map(|x| CompletionFinishReason::from(x.finish_reason)); + + let is_stop = finish_reason + .as_ref() + .is_some_and(|x| x == &CompletionFinishReason::Stop); + CompletionStreamResponse { id: "null".to_string(), object: "text_completion".to_string(), @@ -695,11 +703,13 @@ impl From for CompletionStreamResponse { model: "null".to_string(), choices: vec![CompletionResponseStreamChoice { index: 0, - text: resp.token.text, + text: if is_stop { + "".to_string() + } else { + resp.token.text + }, logprobs: None, - finish_reason: resp - .details - .map(|x| CompletionFinishReason::from(x.finish_reason)), + finish_reason: finish_reason, }], usage: Some(UsageInfo { prompt_tokens: prompt_tokens, @@ -728,8 +738,8 @@ impl From for ChatCompletionResponse { choices: vec![ChatCompletionResponseChoice { index: 0, message: ChatMessage { - role: "assistant".to_string(), - content: resp.generated_text, + role: Some("assistant".to_string()), + content: Some(resp.generated_text), }, finish_reason: resp .details @@ -746,6 +756,14 @@ impl From for ChatCompletionResponse { impl From for ChatCompletionStreamResponse { fn from(resp: StreamResponse) -> Self { + let finish_reason = resp + .details + .map(|x| CompletionFinishReason::from(x.finish_reason)); + + let is_stop = finish_reason + .as_ref() + .is_some_and(|x| x == &CompletionFinishReason::Stop); + ChatCompletionStreamResponse { id: "null".to_string(), object: "chat.completion.chunk".to_string(), @@ -754,12 +772,14 @@ impl From for ChatCompletionStreamResponse { choices: vec![ChatCompletionStreamResponseChoice { index: 0, delta: ChatMessage { - role: "assistant".to_string(), - content: resp.token.text, + role: if is_stop { + None + } else { + Some("assistant".to_string()) + }, + content: if is_stop { None } else { Some(resp.token.text) }, }, - finish_reason: resp - .details - .map(|x| CompletionFinishReason::from(x.finish_reason)), + finish_reason: finish_reason, }], } }