Skip to content

Commit

Permalink
fix: Skip returning EOS token on finish_reason 'stop' (#289)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jeffreyftang authored Feb 29, 2024
1 parent 36a24ba commit 6d8d711
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions router/src/lib.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -526,8 +526,8 @@ struct CompletionStreamResponse {

#[derive(Serialize, ToSchema)]
struct ChatMessage {
role: String,
content: String,
role: Option<String>,
content: Option<String>,
}

#[derive(Serialize, ToSchema)]
Expand Down Expand Up @@ -563,7 +563,7 @@ struct ChatCompletionStreamResponse {
choices: Vec<ChatCompletionStreamResponseChoice>,
}

#[derive(Serialize, ToSchema)]
#[derive(Serialize, ToSchema, PartialEq)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum CompletionFinishReason {
#[schema(rename = "stop")]
Expand Down Expand Up @@ -688,18 +688,28 @@ impl From<StreamResponse> for CompletionStreamResponse {
.unwrap_or(0);
let total_tokens = prompt_tokens + completion_tokens;

let finish_reason = resp
.details
.map(|x| CompletionFinishReason::from(x.finish_reason));

let is_stop = finish_reason
.as_ref()
.is_some_and(|x| x == &CompletionFinishReason::Stop);

CompletionStreamResponse {
id: "null".to_string(),
object: "text_completion".to_string(),
created: 0,
model: "null".to_string(),
choices: vec![CompletionResponseStreamChoice {
index: 0,
text: resp.token.text,
text: if is_stop {
"".to_string()
} else {
resp.token.text
},
logprobs: None,
finish_reason: resp
.details
.map(|x| CompletionFinishReason::from(x.finish_reason)),
finish_reason: finish_reason,
}],
usage: Some(UsageInfo {
prompt_tokens: prompt_tokens,
Expand Down Expand Up @@ -728,8 +738,8 @@ impl From<GenerateResponse> for ChatCompletionResponse {
choices: vec![ChatCompletionResponseChoice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: resp.generated_text,
role: Some("assistant".to_string()),
content: Some(resp.generated_text),
},
finish_reason: resp
.details
Expand All @@ -746,6 +756,14 @@ impl From<GenerateResponse> for ChatCompletionResponse {

impl From<StreamResponse> for ChatCompletionStreamResponse {
fn from(resp: StreamResponse) -> Self {
let finish_reason = resp
.details
.map(|x| CompletionFinishReason::from(x.finish_reason));

let is_stop = finish_reason
.as_ref()
.is_some_and(|x| x == &CompletionFinishReason::Stop);

ChatCompletionStreamResponse {
id: "null".to_string(),
object: "chat.completion.chunk".to_string(),
Expand All @@ -754,12 +772,14 @@ impl From<StreamResponse> for ChatCompletionStreamResponse {
choices: vec![ChatCompletionStreamResponseChoice {
index: 0,
delta: ChatMessage {
role: "assistant".to_string(),
content: resp.token.text,
role: if is_stop {
None
} else {
Some("assistant".to_string())
},
content: if is_stop { None } else { Some(resp.token.text) },
},
finish_reason: resp
.details
.map(|x| CompletionFinishReason::from(x.finish_reason)),
finish_reason: finish_reason,
}],
}
}
Expand Down

0 comments on commit 6d8d711

Please sign in to comment.