Skip to content

Commit

Permalink
use stream to read sse
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jul 23, 2024
1 parent 6197307 commit a97c426
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 86 deletions.
66 changes: 66 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ serde = { version = "1", features = ["derive", "rc"] }
serde_json = "1"
log = "0"
tokio = { version = "1", features = ["full"] }
reqwest = "0"
reqwest = { version = "0", features = ["stream"] }
bytes = "1"
rand = "0"
base64 = "0"
uuid = { version = "1", features = ["v4"] }
regex = "1"
glob = "0"
futures = "0"
93 changes: 48 additions & 45 deletions src/azure/chatgpt.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::io;
use std::io::ErrorKind;
use std::ops::Not;
use std::path::Path;
use std::rc::Rc;
Expand All @@ -6,6 +8,9 @@ use std::str;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::Bytes;
use futures::AsyncBufReadExt;
use futures::StreamExt;
use futures::TryStreamExt;
use log::info;
use reqwest::Response;
use tokio::fs;
Expand Down Expand Up @@ -178,7 +183,7 @@ impl ChatGPT {
}
}

async fn read_sse_response(mut http_response: Response) -> Result<ChatResponse, Exception> {
async fn read_sse_response(http_response: Response) -> Result<ChatResponse, Exception> {
let mut response = ChatResponse {
choices: vec![ChatCompletionChoice {
index: 0,
Expand All @@ -193,62 +198,60 @@ async fn read_sse_response(mut http_response: Response) -> Result<ChatResponse,
// only support one choice, n=1
let choice = response.choices.first_mut().unwrap();

let mut buffer = String::with_capacity(1024);
while let Some(chunk) = http_response.chunk().await? {
buffer.push_str(str::from_utf8(&chunk)?);
let reader = http_response
.bytes_stream()
.map_err(|e| io::Error::new(ErrorKind::Other, e))
.into_async_read();

while let Some(index) = buffer.find("\n\n") {
if buffer.starts_with("data:") {
let data = &buffer[6..index];
let mut lines = reader.lines();
while let Some(line) = lines.next().await {
let line = line?;

if data == "[DONE]" {
break;
}

let stream_response: ChatStreamResponse = json::from_json(data)?;
// chatgpt always adds space after data:
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
break;
}

if let Some(stream_choice) = stream_response.choices.into_iter().next() {
choice.index = stream_choice.index;
let stream_response: ChatStreamResponse = json::from_json(data)?;

if let Some(stream_calls) = stream_choice.delta.tool_calls {
if choice.message.tool_calls.is_none() {
choice.message.tool_calls = Some(vec![]);
}
if let Some(stream_choice) = stream_response.choices.into_iter().next() {
choice.index = stream_choice.index;

// stream tool call only return single element
let stream_call = stream_calls.into_iter().next().unwrap();
if let Some(name) = stream_call.function.name {
choice.message.tool_calls.as_mut().unwrap().push(ToolCall {
id: stream_call.id.unwrap(),
r#type: "function".to_string(),
function: super::chatgpt_api::FunctionCall {
name,
arguments: String::new(),
},
});
}
let tool_call = choice.message.tool_calls.as_mut().unwrap().get_mut(stream_call.index as usize).unwrap();
tool_call.function.arguments.push_str(&stream_call.function.arguments);
} else if let Some(content) = stream_choice.delta.content {
choice.append_content(&content);
console::print(&content).await?;
if let Some(stream_calls) = stream_choice.delta.tool_calls {
if choice.message.tool_calls.is_none() {
choice.message.tool_calls = Some(vec![]);
}

if let Some(finish_reason) = stream_choice.finish_reason {
choice.finish_reason = finish_reason;
if choice.finish_reason == "stop" {
console::print("\n").await?;
}
// stream tool call only return single element
let stream_call = stream_calls.into_iter().next().unwrap();
if let Some(name) = stream_call.function.name {
choice.message.tool_calls.as_mut().unwrap().push(ToolCall {
id: stream_call.id.unwrap(),
r#type: "function".to_string(),
function: super::chatgpt_api::FunctionCall {
name,
arguments: String::new(),
},
});
}
let tool_call = choice.message.tool_calls.as_mut().unwrap().get_mut(stream_call.index as usize).unwrap();
tool_call.function.arguments.push_str(&stream_call.function.arguments);
} else if let Some(content) = stream_choice.delta.content {
choice.append_content(&content);
console::print(&content).await?;
}

if let Some(usage) = stream_response.usage {
response.usage = usage;
if let Some(finish_reason) = stream_choice.finish_reason {
choice.finish_reason = finish_reason;
if choice.finish_reason == "stop" {
console::print("\n").await?;
}
}
}

buffer.replace_range(0..index + 2, "");
} else {
return Err(Exception::unexpected(format!("unexpected sse message, buffer={}", buffer)));
if let Some(usage) = stream_response.usage {
response.usage = usage;
}
}
}
Expand Down
Loading

0 comments on commit a97c426

Please sign in to comment.