Fixed OpenAI stream response data (#193)
tgaddair authored Jan 19, 2024
1 parent ccf4477 commit 8906253
Showing 3 changed files with 64 additions and 34 deletions.
docs/reference/openapi.json: 18 changes (12 additions, 6 deletions)

@@ -1008,13 +1008,19 @@
         "choices": {
           "type": "array",
           "items": {
-            "$ref": "#/components/schemas/ChatCompletionResponseChoice"
+            "$ref": "#/components/schemas/ChatCompletionStreamResponseChoice"
           }
-        },
-        "usage": {
-          "type": "object",
-          "nullable": true,
-          "$ref": "#/components/schemas/UsageInfo"
         }
       }
     },
+    "ChatCompletionStreamResponseChoice": {
+      "type": "object",
+      "properties": {
+        "index": { "type": "integer", "format": "int32" },
+        "delta": { "type": "object", "$ref": "#/components/schemas/ChatCompletionResponseMessage" },
+        "finish_reason": {
+          "type": "string",
+          "nullable": true
+        }
+      }
+    },
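For illustration, a chunk conforming to the updated schema (with the placeholder id/model/created values emitted by the router's From impl in the lib.rs diff below) would look like this. This is a hypothetical sketch, not part of the commit; it assumes only the serde_json crate:

    // Hypothetical example of a stream chunk under the updated schema;
    // finish_reason stays null until the final chunk of the stream.
    use serde_json::json;

    fn main() {
        let chunk = json!({
            "id": "null",
            "object": "chat.completion.chunk",
            "created": 0,
            "model": "null",
            "choices": [{
                "index": 0,
                "delta": { "role": "assistant", "content": "Hello" },
                "finish_reason": null
            }]
        });
        println!("{}", serde_json::to_string_pretty(&chunk).unwrap());
    }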
router/src/lib.rs: 78 changes (51 additions, 27 deletions)

@@ -380,7 +380,7 @@ struct CompletionResponseChoice
     index: i32,
     text: String,
     logprobs: Option<LogProbs>,
-    finish_reason: Option<String>,
+    finish_reason: Option<CompletionFinishReason>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -398,7 +398,7 @@ struct CompletionResponseStreamChoice
     index: i32,
     text: String,
     logprobs: Option<LogProbs>,
-    finish_reason: Option<String>,
+    finish_reason: Option<CompletionFinishReason>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -421,7 +421,7 @@ struct ChatMessage
 struct ChatCompletionResponseChoice {
     index: i32,
     message: ChatMessage,
-    finish_reason: Option<String>,
+    finish_reason: Option<CompletionFinishReason>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -434,14 +434,33 @@ struct ChatCompletionResponse
     usage: UsageInfo,
 }
 
+#[derive(Serialize, ToSchema)]
+struct ChatCompletionStreamResponseChoice {
+    index: i32,
+    delta: ChatMessage,
+    finish_reason: Option<CompletionFinishReason>,
+}
+
 #[derive(Serialize, ToSchema)]
 struct ChatCompletionStreamResponse {
     id: String,
     object: String,
     created: i64,
     model: String,
-    choices: Vec<ChatCompletionResponseChoice>,
-    usage: Option<UsageInfo>,
+    choices: Vec<ChatCompletionStreamResponseChoice>,
 }
 
+#[derive(Serialize, ToSchema)]
+#[serde(rename_all(serialize = "snake_case"))]
+pub(crate) enum CompletionFinishReason {
+    #[schema(rename = "stop")]
+    Stop,
+    #[schema(rename = "length")]
+    Length,
+    #[schema(rename = "content_filter")]
+    ContentFilter,
+    #[schema(rename = "tool_calls")]
+    ToolCalls,
+}
+
 impl From<CompletionRequest> for CompatGenerateRequest {
@@ -529,7 +548,9 @@ impl From<GenerateResponse> for CompletionResponse
                 index: 0,
                 text: resp.generated_text,
                 logprobs: None,
-                finish_reason: None,
+                finish_reason: resp
+                    .details
+                    .map(|x| CompletionFinishReason::from(x.finish_reason)),
             }],
             usage: UsageInfo {
                 prompt_tokens: prompt_tokens,
@@ -557,9 +578,11 @@ impl From<StreamResponse> for CompletionStreamResponse
             model: "null".to_string(),
             choices: vec![CompletionResponseStreamChoice {
                 index: 0,
-                text: resp.generated_text.unwrap_or_default(),
+                text: resp.token.text,
                 logprobs: None,
-                finish_reason: None,
+                finish_reason: resp
+                    .details
+                    .map(|x| CompletionFinishReason::from(x.finish_reason)),
             }],
             usage: Some(UsageInfo {
                 prompt_tokens: prompt_tokens,
@@ -591,7 +614,9 @@ impl From<GenerateResponse> for ChatCompletionResponse
                     role: "assistant".to_string(),
                     content: resp.generated_text,
                 },
-                finish_reason: None,
+                finish_reason: resp
+                    .details
+                    .map(|x| CompletionFinishReason::from(x.finish_reason)),
            }],
            usage: UsageInfo {
                prompt_tokens: prompt_tokens,
@@ -604,32 +629,31 @@
 
 impl From<StreamResponse> for ChatCompletionStreamResponse {
     fn from(resp: StreamResponse) -> Self {
-        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
-        let completion_tokens = resp
-            .details
-            .as_ref()
-            .map(|x| x.generated_tokens)
-            .unwrap_or(0);
-        let total_tokens = prompt_tokens + completion_tokens;
-
         ChatCompletionStreamResponse {
             id: "null".to_string(),
-            object: "text_completion".to_string(),
+            object: "chat.completion.chunk".to_string(),
             created: 0,
             model: "null".to_string(),
-            choices: vec![ChatCompletionResponseChoice {
+            choices: vec![ChatCompletionStreamResponseChoice {
                 index: 0,
-                message: ChatMessage {
+                delta: ChatMessage {
                     role: "assistant".to_string(),
-                    content: resp.generated_text.unwrap_or_default(),
+                    content: resp.token.text,
                 },
-                finish_reason: None,
+                finish_reason: resp
+                    .details
+                    .map(|x| CompletionFinishReason::from(x.finish_reason)),
             }],
-            usage: Some(UsageInfo {
-                prompt_tokens: prompt_tokens,
-                total_tokens: total_tokens,
-                completion_tokens: Some(completion_tokens),
-            }),
         }
     }
 }
 
+impl From<FinishReason> for CompletionFinishReason {
+    fn from(reason: FinishReason) -> Self {
+        match reason {
+            FinishReason::Length => CompletionFinishReason::Length,
+            FinishReason::EndOfSequenceToken => CompletionFinishReason::Stop,
+            FinishReason::StopSequence => CompletionFinishReason::ContentFilter,
+        }
+    }
+}
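On the wire, the values of the new CompletionFinishReason enum come from serde's rename_all(serialize = "snake_case") attribute, while the #[schema(rename = ...)] attributes keep the generated OpenAPI doc in sync. A minimal sketch of the resulting serialization, assuming serde (with the derive feature) and serde_json as dependencies; the ToSchema derive is omitted here to avoid pulling in utoipa:

    use serde::Serialize;

    // Mirrors the enum added above, minus the utoipa attributes.
    #[derive(Serialize)]
    #[serde(rename_all(serialize = "snake_case"))]
    #[allow(dead_code)]
    enum CompletionFinishReason {
        Stop,
        Length,
        ContentFilter,
        ToolCalls,
    }

    #[derive(Serialize)]
    struct Choice {
        index: i32,
        finish_reason: Option<CompletionFinishReason>,
    }

    fn main() {
        let done = Choice {
            index: 0,
            finish_reason: Some(CompletionFinishReason::Stop),
        };
        // Prints: {"index":0,"finish_reason":"stop"}
        println!("{}", serde_json::to_string(&done).unwrap());
    }

One mapping choice worth noting: the From<FinishReason> impl sends FinishReason::StopSequence to ContentFilter, whereas OpenAI clients generally see user-supplied stop sequences reported as "stop".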
router/src/server.rs: 2 changes (1 addition, 1 deletion)

@@ -553,7 +553,7 @@ async fn generate_stream_with_callback(
                     details: None,
                 };
 
-                yield Ok(Event::default().json_data(stream_token).unwrap())
+                yield Ok(callback(stream_token))
             }
             // Yield event for last token and compute timings
             InferStreamResponse::End {
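This one-line server change is what activates the new types: instead of always JSON-encoding the raw StreamResponse, the SSE loop now defers to a caller-supplied callback. The callback's definition is outside this diff; a plausible shape for the chat endpoint, reusing the From impl from lib.rs above (hypothetical, for illustration only):

    // Hypothetical: a closure the chat completions handler could pass into
    // generate_stream_with_callback, converting each raw StreamResponse into
    // an OpenAI-style chunk before it is wrapped in an SSE event.
    let callback = move |stream_token: StreamResponse| {
        Event::default()
            .json_data(ChatCompletionStreamResponse::from(stream_token))
            .unwrap()
    };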
