From cf0f7f76901cfd8fd42b548efba3a0ef20018841 Mon Sep 17 00:00:00 2001 From: neo <1100909+neowu@users.noreply.github.com> Date: Mon, 8 Jul 2024 16:31:26 +0800 Subject: [PATCH] replace sse impl --- Cargo.lock | 136 --------------------------------------- Cargo.toml | 4 +- src/azure/chatgpt.rs | 82 ++++++++++++++--------- src/azure/chatgpt_api.rs | 16 +++++ src/gcloud/gemini.rs | 33 ++++------ src/llm/function.rs | 13 ++-- 6 files changed, 87 insertions(+), 197 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e576595..d68aa62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,17 +242,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "eventsource-stream" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" -dependencies = [ - "futures-core", - "nom", - "pin-project-lite", -] - [[package]] name = "fastrand" version = "2.1.0" @@ -289,21 +278,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "futures" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - [[package]] name = "futures-channel" version = "0.3.30" @@ -311,7 +285,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", - "futures-sink", ] [[package]] @@ -320,34 +293,6 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" -[[package]] -name = "futures-executor" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-io" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" - -[[package]] -name = "futures-macro" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "futures-sink" version = "0.3.30" @@ -360,28 +305,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" -[[package]] -name = "futures-timer" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" - [[package]] name = "futures-util" version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ - "futures-channel", "futures-core", - "futures-io", - "futures-macro", - "futures-sink", "futures-task", - "memchr", "pin-project-lite", "pin-utils", - "slab", ] [[package]] @@ -753,12 +686,6 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = "miniz_oxide" version = "0.7.3" @@ -796,16 +723,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "nom" -version = "7.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" -dependencies = [ - "memchr", - "minimal-lexical", -] - [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -981,10 +898,8 @@ dependencies = [ "bytes", "clap", "clap_complete", - "futures", "rand", "reqwest", - "reqwest-eventsource", "serde", "serde_json", "tokio", @@ -1075,32 +990,14 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", - "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", "web-sys", "winreg", ] -[[package]] -name = "reqwest-eventsource" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" -dependencies = [ - "eventsource-stream", - "futures-core", - "futures-timer", - "mime", - "nom", - "pin-project-lite", - "reqwest", - "thiserror", -] - [[package]] name = "rustc-demangle" version = "0.1.24" @@ -1339,26 +1236,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "thiserror" -version = "1.0.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "thread_local" version = "1.1.8" @@ -1659,19 +1536,6 @@ version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" -[[package]] -name = "wasm-streams" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" -dependencies = [ - "futures-util", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "web-sys" version = "0.3.69" diff --git a/Cargo.toml b/Cargo.toml index 9c41844..8a4c9a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,9 +13,7 @@ clap_complete = "4" serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" tokio = { version = "1", features = ["full"] } -reqwest = { version = "0", features = ["stream"] } -reqwest-eventsource = "0" -futures = "0" +reqwest = "0" rand = "0" base64 = "0" uuid = { version = "1", features = ["v4"] } diff --git a/src/azure/chatgpt.rs b/src/azure/chatgpt.rs index a739b02..6100bf5 100644 --- a/src/azure/chatgpt.rs +++ b/src/azure/chatgpt.rs @@ -1,14 +1,14 @@ use std::collections::HashMap; use std::ops::Not; use std::rc::Rc; +use std::str; -use futures::stream::StreamExt; -use reqwest_eventsource::CannotCloneRequestError; -use reqwest_eventsource::Event; -use reqwest_eventsource::EventSource; +use bytes::Bytes; +use reqwest::Response; use tokio::sync::mpsc::channel; use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Sender; +use tracing::info; use crate::azure::chatgpt_api::ChatRequest; use crate::azure::chatgpt_api::ChatRequestMessage; @@ -40,7 +40,7 @@ type FunctionCall = HashMap; impl ChatGPT { pub fn new(endpoint: String, model: String, api_key: String, system_message: Option, function_store: FunctionStore) -> Self { - let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-02-01"); + let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-06-01"); let tools: Option> = function_store.declarations.is_empty().not().then_some( function_store .declarations @@ -90,11 +90,10 @@ impl ChatGPT { } async fn process(&mut self, handler: &impl ChatHandler) -> Result, Exception> { - let source = self.call_api().await?; - let (tx, rx) = channel(64); - let handle = tokio::spawn(read_event_source(source, tx)); + let response = self.call_api().await?; + let handle = tokio::spawn(read_sse(response, tx)); let function_call = self.process_event(rx, handler).await; handle.await??; @@ -122,12 +121,14 @@ impl ChatGPT { None } - async fn call_api(&mut self) -> Result { + async fn call_api(&mut self) -> Result { let request = ChatRequest { messages: Rc::clone(&self.messages), temperature: 0.7, top_p: 0.95, stream: true, + // stream_options: Some(StreamOptions { include_usage: true }), + stream_options: None, stop: None, max_tokens: 800, presence_penalty: 0.0, @@ -137,36 +138,46 @@ impl ChatGPT { }; let body = json::to_json(&request)?; + let body = Bytes::from(body); let request = http_client::http_client() .post(&self.url) .header("Content-Type", "application/json") .header("api-key", &self.api_key) - .body(body); - - Ok(EventSource::new(request)?) - } -} + .body(body.clone()); + + let response = request.send().await?; + let status = response.status(); + if status != 200 { + let body = str::from_utf8(&body).unwrap(); + info!("body={}", body); + let response_text = response.text().await?; + return Err(Exception::ExternalError(format!( + "failed to call azure api, status={status}, response={response_text}" + ))); + } -impl From for Exception { - fn from(err: CannotCloneRequestError) -> Self { - Exception::unexpected(err) + Ok(response) } } -async fn read_event_source(mut source: EventSource, tx: Sender) -> Result<(), Exception> { +async fn read_sse(response: Response, tx: Sender) -> Result<(), Exception> { let mut function_calls: FunctionCall = HashMap::new(); - while let Some(event) = source.next().await { - match event { - Ok(Event::Open) => {} - Ok(Event::Message(message)) => { - let data = message.data; + let mut usage = Usage::default(); + + let mut buffer = String::with_capacity(1024); + let mut response = response; + 'outer: while let Some(chunk) = response.chunk().await? { + buffer.push_str(str::from_utf8(&chunk).unwrap()); + + while let Some(index) = buffer.find("\n\n") { + if buffer.starts_with("data:") { + let data = &buffer[6..index]; if data == "[DONE]" { - source.close(); - break; + break 'outer; } - let response: ChatResponse = json::from_json(&data)?; + let response: ChatResponse = json::from_json(data)?; if let Some(choice) = response.choices.into_iter().next() { let delta = choice.delta.unwrap(); @@ -181,18 +192,25 @@ async fn read_event_source(mut source: EventSource, tx: Sender) - tx.send(InternalEvent::Event(ChatEvent::Delta(value))).await?; } } - } - Err(err) => { - source.close(); - return Err(Exception::unexpected(err)); + + if let Some(value) = response.usage { + usage = Usage { + request_tokens: value.prompt_tokens, + response_tokens: value.completion_tokens, + }; + } + + buffer.replace_range(0..index + 2, ""); + } else { + return Err(Exception::ValidationError(format!("unexpected sse message, buffer={}", buffer))); } } } + if !function_calls.is_empty() { tx.send(InternalEvent::FunctionCall(function_calls)).await?; } else { - // chatgpt doesn't support token usage with stream mode - tx.send(InternalEvent::Event(ChatEvent::End(Usage::default()))).await?; + tx.send(InternalEvent::Event(ChatEvent::End(usage))).await?; } Ok(()) diff --git a/src/azure/chatgpt_api.rs b/src/azure/chatgpt_api.rs index 80fa388..5146af9 100644 --- a/src/azure/chatgpt_api.rs +++ b/src/azure/chatgpt_api.rs @@ -13,6 +13,8 @@ pub struct ChatRequest { pub top_p: f32, pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub stop: Option>, pub max_tokens: i32, pub presence_penalty: f32, @@ -33,6 +35,11 @@ pub struct ChatRequestMessage { pub tool_calls: Option>, } +#[derive(Debug, Serialize)] +pub struct StreamOptions { + pub include_usage: bool, +} + impl ChatRequestMessage { pub fn new_message(role: Role, message: String) -> Self { ChatRequestMessage { @@ -101,6 +108,7 @@ pub struct ChatResponse { pub created: i64, pub model: String, pub choices: Vec, + pub usage: Option, } #[allow(dead_code)] @@ -131,3 +139,11 @@ pub struct FunctionCall { pub name: Option, pub arguments: String, } + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +pub struct Usage { + pub completion_tokens: i32, + pub prompt_tokens: i32, + pub total_tokens: i32, +} diff --git a/src/gcloud/gemini.rs b/src/gcloud/gemini.rs index fa9aa7d..6868f89 100644 --- a/src/gcloud/gemini.rs +++ b/src/gcloud/gemini.rs @@ -8,7 +8,6 @@ use std::str; use base64::prelude::BASE64_STANDARD; use base64::Engine; use bytes::Bytes; -use futures::StreamExt; use reqwest::Response; use tokio::sync::mpsc::channel; use tokio::sync::mpsc::Receiver; @@ -209,27 +208,19 @@ impl Gemini { } async fn read_response_stream(response: Response, tx: Sender) -> Result<(), Exception> { - let stream = &mut response.bytes_stream(); - - let mut buffer = String::new(); - while let Some(result) = stream.next().await { - match result { - Ok(chunk) => { - buffer.push_str(str::from_utf8(&chunk).unwrap()); - - // first char is '[' or ',' - if !is_valid_json(&buffer[1..]) { - continue; - } - - let content: GenerateContentResponse = json::from_json(&buffer[1..])?; - tx.send(content).await?; - buffer.clear(); - } - Err(err) => { - return Err(Exception::unexpected(err)); - } + let mut response = response; + let mut buffer = String::with_capacity(1024); + while let Some(chunk) = response.chunk().await? { + buffer.push_str(str::from_utf8(&chunk).unwrap()); + + // first char is '[' or ',' + if !is_valid_json(&buffer[1..]) { + continue; } + + let content: GenerateContentResponse = json::from_json(&buffer[1..])?; + tx.send(content).await?; + buffer.clear(); } Ok(()) } diff --git a/src/llm/function.rs b/src/llm/function.rs index f0f7292..9ab605e 100644 --- a/src/llm/function.rs +++ b/src/llm/function.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use std::sync::Arc; -use futures::future::try_join_all; use serde::Serialize; +use tokio::task::JoinSet; use tracing::info; use crate::util::exception::Exception; @@ -48,15 +48,18 @@ impl FunctionStore { } pub async fn call_functions(&self, functions: Vec<(String, String, serde_json::Value)>) -> Result, Exception> { - let mut handles = Vec::with_capacity(functions.len()); + let mut handles = JoinSet::new(); for (id, name, args) in functions { let function = self.get(&name)?; - handles.push(tokio::spawn(async move { + handles.spawn(async move { info!("call function, id={id}, name={name}, args={args}"); (id, function(args)) - })); + }); + } + let mut results = vec![]; + while let Some(result) = handles.join_next().await { + results.push(result?) } - let results = try_join_all(handles).await?; Ok(results) }