
Commit

replace sse impl
neowu committed Jul 8, 2024
1 parent 1bcc2f2 commit cf0f7f7
Showing 6 changed files with 87 additions and 197 deletions.
136 changes: 0 additions & 136 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 1 addition & 3 deletions Cargo.toml
@@ -13,9 +13,7 @@ clap_complete = "4"
 serde = { version = "1", features = ["derive", "rc"] }
 serde_json = "1"
 tokio = { version = "1", features = ["full"] }
-reqwest = { version = "0", features = ["stream"] }
-reqwest-eventsource = "0"
-futures = "0"
+reqwest = "0"
 rand = "0"
 base64 = "0"
 uuid = { version = "1", features = ["v4"] }
82 changes: 50 additions & 32 deletions src/azure/chatgpt.rs
@@ -1,14 +1,14 @@
 use std::collections::HashMap;
 use std::ops::Not;
 use std::rc::Rc;
+use std::str;
 
-use futures::stream::StreamExt;
-use reqwest_eventsource::CannotCloneRequestError;
-use reqwest_eventsource::Event;
-use reqwest_eventsource::EventSource;
+use bytes::Bytes;
+use reqwest::Response;
 use tokio::sync::mpsc::channel;
 use tokio::sync::mpsc::Receiver;
 use tokio::sync::mpsc::Sender;
+use tracing::info;
 
 use crate::azure::chatgpt_api::ChatRequest;
 use crate::azure::chatgpt_api::ChatRequestMessage;
@@ -40,7 +40,7 @@ type FunctionCall = HashMap<i64, (String, String, String)>;
 
 impl ChatGPT {
     pub fn new(endpoint: String, model: String, api_key: String, system_message: Option<String>, function_store: FunctionStore) -> Self {
-        let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-02-01");
+        let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-06-01");
         let tools: Option<Rc<[Tool]>> = function_store.declarations.is_empty().not().then_some(
             function_store
                 .declarations
@@ -90,11 +90,10 @@ impl ChatGPT {
     }
 
     async fn process(&mut self, handler: &impl ChatHandler) -> Result<Option<FunctionCall>, Exception> {
-        let source = self.call_api().await?;
-
         let (tx, rx) = channel(64);
-        let handle = tokio::spawn(read_event_source(source, tx));
 
+        let response = self.call_api().await?;
+        let handle = tokio::spawn(read_sse(response, tx));
         let function_call = self.process_event(rx, handler).await;
         handle.await??;
 
@@ -122,12 +121,14 @@ impl ChatGPT {
         None
     }
 
-    async fn call_api(&mut self) -> Result<EventSource, Exception> {
+    async fn call_api(&mut self) -> Result<Response, Exception> {
         let request = ChatRequest {
             messages: Rc::clone(&self.messages),
             temperature: 0.7,
             top_p: 0.95,
             stream: true,
+            // stream_options: Some(StreamOptions { include_usage: true }),
+            stream_options: None,
             stop: None,
             max_tokens: 800,
             presence_penalty: 0.0,
@@ -137,36 +138,46 @@
         };
 
         let body = json::to_json(&request)?;
+        let body = Bytes::from(body);
         let request = http_client::http_client()
             .post(&self.url)
             .header("Content-Type", "application/json")
             .header("api-key", &self.api_key)
-            .body(body);
-
-        Ok(EventSource::new(request)?)
-    }
-}
+            .body(body.clone());
+
+        let response = request.send().await?;
+        let status = response.status();
+        if status != 200 {
+            let body = str::from_utf8(&body).unwrap();
+            info!("body={}", body);
+            let response_text = response.text().await?;
+            return Err(Exception::ExternalError(format!(
+                "failed to call azure api, status={status}, response={response_text}"
+            )));
+        }
 
-impl From<CannotCloneRequestError> for Exception {
-    fn from(err: CannotCloneRequestError) -> Self {
-        Exception::unexpected(err)
+        Ok(response)
     }
 }
 
-async fn read_event_source(mut source: EventSource, tx: Sender<InternalEvent>) -> Result<(), Exception> {
+async fn read_sse(response: Response, tx: Sender<InternalEvent>) -> Result<(), Exception> {
     let mut function_calls: FunctionCall = HashMap::new();
-    while let Some(event) = source.next().await {
-        match event {
-            Ok(Event::Open) => {}
-            Ok(Event::Message(message)) => {
-                let data = message.data;
+    let mut usage = Usage::default();
+
+    let mut buffer = String::with_capacity(1024);
+    let mut response = response;
+    'outer: while let Some(chunk) = response.chunk().await? {
+        buffer.push_str(str::from_utf8(&chunk).unwrap());
+
+        while let Some(index) = buffer.find("\n\n") {
+            if buffer.starts_with("data:") {
+                let data = &buffer[6..index];
 
                 if data == "[DONE]" {
-                    source.close();
-                    break;
+                    break 'outer;
                 }
 
-                let response: ChatResponse = json::from_json(&data)?;
+                let response: ChatResponse = json::from_json(data)?;
 
                 if let Some(choice) = response.choices.into_iter().next() {
                     let delta = choice.delta.unwrap();
@@ -181,18 +192,25 @@ async fn read_event_source(mut source: EventSource, tx: Sender<InternalEvent>) -
                         tx.send(InternalEvent::Event(ChatEvent::Delta(value))).await?;
                     }
                 }
-            }
-            Err(err) => {
-                source.close();
-                return Err(Exception::unexpected(err));
+
+                if let Some(value) = response.usage {
+                    usage = Usage {
+                        request_tokens: value.prompt_tokens,
+                        response_tokens: value.completion_tokens,
+                    };
+                }
+
+                buffer.replace_range(0..index + 2, "");
+            } else {
+                return Err(Exception::ValidationError(format!("unexpected sse message, buffer={}", buffer)));
             }
         }
     }
 
     if !function_calls.is_empty() {
         tx.send(InternalEvent::FunctionCall(function_calls)).await?;
     } else {
-        // chatgpt doesn't support token usage with stream mode
-        tx.send(InternalEvent::Event(ChatEvent::End(Usage::default()))).await?;
+        tx.send(InternalEvent::Event(ChatEvent::End(usage))).await?;
     }
+
     Ok(())
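Note (editorial, not part of the commit): the sketch below isolates the buffered SSE parsing idea that read_sse relies on (accumulate chunks in a String, split complete events on "\n\n", strip the "data:" prefix, stop at "[DONE]"), driven here from an in-memory string instead of reqwest chunks. parse_sse_events is a hypothetical helper name, not code from this repository.

// Editorial sketch only: the same event-splitting logic as read_sse,
// fed from an in-memory string instead of reqwest response chunks.
fn parse_sse_events(buffer: &mut String) -> Vec<String> {
    let mut events = Vec::new();
    // SSE separates events with a blank line ("\n\n")
    while let Some(index) = buffer.find("\n\n") {
        if buffer.starts_with("data:") {
            // strip the "data:" prefix and surrounding whitespace
            let data = buffer[5..index].trim().to_string();
            if data == "[DONE]" {
                buffer.clear();
                break;
            }
            events.push(data);
        }
        // drop the consumed event plus its "\n\n" separator
        buffer.replace_range(0..index + 2, "");
    }
    events
}

fn main() {
    let mut buffer = String::from("data: {\"a\":1}\n\ndata: {\"a\":2}\n\ndata: [DONE]\n\n");
    assert_eq!(parse_sse_events(&mut buffer), vec!["{\"a\":1}", "{\"a\":2}"]);
}

Unlike read_sse, which returns a ValidationError when the buffer does not start with "data:", this sketch simply discards such fields.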
