diff --git a/Cargo.lock b/Cargo.lock index 8e6a996..ee59072 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,6 +276,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -283,6 +298,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -291,6 +307,34 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -309,10 +353,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -771,6 +821,7 @@ dependencies = [ "bytes", "clap", "clap_complete", + "futures", "glob", "log", "rand", @@ -894,10 +945,12 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "winreg", ] @@ -1436,6 +1489,19 @@ version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +[[package]] +name = "wasm-streams" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.69" diff --git a/Cargo.toml b/Cargo.toml index 4629db2..d188a47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,10 +12,11 @@ serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" log = "0" tokio = { version = "1", features = ["full"] } -reqwest = "0" +reqwest = { version = "0", features = ["stream"] } bytes = "1" rand = "0" base64 = "0" uuid = { version = "1", features = ["v4"] } regex = "1" glob = "0" +futures = "0" diff --git a/src/azure/chatgpt.rs b/src/azure/chatgpt.rs index f0f8fab..ccf53e8 100644 --- a/src/azure/chatgpt.rs +++ b/src/azure/chatgpt.rs @@ -1,3 +1,5 @@ +use std::io; +use std::io::ErrorKind; use std::ops::Not; use std::path::Path; use std::rc::Rc; @@ -6,6 +8,9 @@ use std::str; use base64::prelude::BASE64_STANDARD; use base64::Engine; use bytes::Bytes; +use futures::AsyncBufReadExt; +use futures::StreamExt; +use futures::TryStreamExt; use log::info; use reqwest::Response; use tokio::fs; @@ -178,7 +183,7 @@ impl ChatGPT { } } -async fn read_sse_response(mut http_response: Response) -> Result { +async fn read_sse_response(http_response: Response) -> Result { let mut response = ChatResponse { choices: vec![ChatCompletionChoice { index: 0, @@ -193,62 +198,60 @@ async fn read_sse_response(mut http_response: Response) -> Result, function_implementations: FunctionImplementations, ) -> Self { - let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent"); + let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent?alt=sse&key=AIzaSyBK8zxfxrzAkg4DBXjpOQfkvfEXDXPikuQ"); Gemini { url, contents: Rc::new(vec![]), @@ -88,7 +92,7 @@ impl Gemini { async fn process(&mut self) -> Result<(), Exception> { loop { let http_response = self.call_api().await?; - let response = read_stream_response(http_response).await?; + let response = read_sse_response(http_response).await?; info!( "usage, prompt_tokens={}, candidates_tokens={}", response.usage_metadata.prompt_token_count, response.usage_metadata.candidates_token_count @@ -138,7 +142,7 @@ impl Gemini { let body = Bytes::from(body); let response = http_client::http_client() .post(&self.url) - .bearer_auth(token()) + // .bearer_auth(token()) .header("Content-Type", "application/json") .header("Accept", "application/json") .body(body.clone()) @@ -158,7 +162,7 @@ impl Gemini { } } -async fn read_stream_response(mut http_response: Response) -> Result { +async fn read_sse_response(http_response: Response) -> Result { let mut response = GenerateContentResponse { candidates: vec![Candidate { content: Content { @@ -171,41 +175,39 @@ async fn read_stream_response(mut http_response: Response) -> Result Result bool { - let result: serde_json::Result = serde_json::from_str(content); - result.is_ok() -} - async fn inline_datas(files: &[&Path]) -> Result, Exception> { let mut data = Vec::with_capacity(files.len()); for file in files {