diff --git a/Cargo.toml b/Cargo.toml index cafc024..271fe9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" [dependencies] clap = { version = "4.4.16", features = ["derive"] } clap_complete = "4.4.6" -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" hyper = { version = "1.2.0", features = ["full"] } tokio = { version = "1.36.0", features = ["full"] } diff --git a/src/bot.rs b/src/bot.rs index 2ebb18b..b1add63 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -16,7 +16,7 @@ use self::config::Config; pub mod config; pub trait ChatHandler { - fn on_event(&self, event: &ChatEvent); + fn on_event(&self, event: ChatEvent); } pub enum ChatEvent { @@ -70,7 +70,7 @@ pub enum Bot { } impl Bot { - pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> { + pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> { match self { Bot::ChatGPT(bot) => bot.chat(message, handler).await, Bot::Vertex(bot) => bot.chat(message, handler).await, diff --git a/src/command/chat.rs b/src/command/chat.rs index d2246ef..c0eb281 100644 --- a/src/command/chat.rs +++ b/src/command/chat.rs @@ -24,9 +24,9 @@ pub struct Chat { struct ConsoleHandler; impl ChatHandler for ConsoleHandler { - fn on_event(&self, event: &ChatEvent) { + fn on_event(&self, event: ChatEvent) { match event { - ChatEvent::Delta(data) => { + ChatEvent::Delta(ref data) => { print_flush(data).unwrap(); } ChatEvent::Error(error) => { @@ -53,7 +53,7 @@ impl Chat { break; } - bot.chat(&line, &handler).await?; + bot.chat(line, &handler).await?; } Ok(()) } diff --git a/src/gcloud/api.rs b/src/gcloud/api.rs index d3c1419..d3e9439 100644 --- a/src/gcloud/api.rs +++ b/src/gcloud/api.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use std::rc::Rc; use serde::Deserialize; use serde::Serialize; @@ -6,23 +6,22 @@ use serde::Serialize; use crate::bot::Function; #[derive(Debug, Serialize)] -pub struct StreamGenerateContent<'a> { - #[serde(borrow)] - pub contents: Cow<'a, [Content]>, +pub struct StreamGenerateContent { + pub contents: Rc>, #[serde(rename = "generationConfig")] pub generation_config: GenerationConfig, #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, + pub tools: Option>>, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize)] pub struct Content { pub role: Role, pub parts: Vec, } impl Content { - pub fn new_text(role: Role, message: &str) -> Self { + pub fn new_text(role: Role, message: String) -> Self { Self { role, parts: vec![Part { @@ -33,16 +32,13 @@ impl Content { } } - pub fn new_function_response(name: &str, response: serde_json::Value) -> Self { + pub fn new_function_response(name: String, response: serde_json::Value) -> Self { Self { role: Role::User, parts: vec![Part { text: None, function_call: None, - function_response: Some(FunctionResponse { - name: name.to_string(), - response, - }), + function_response: Some(FunctionResponse { name, response }), }], } } @@ -59,13 +55,13 @@ impl Content { } } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize)] pub struct Tool { #[serde(rename = "functionDeclarations")] pub function_declarations: Vec, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize)] pub enum Role { #[serde(rename = "user")] User, @@ -73,7 +69,7 @@ pub enum Role { Model, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize)] pub struct Part { #[serde(skip_serializing_if = "Option::is_none")] pub text: Option, @@ -110,7 +106,7 @@ pub struct FunctionCall { pub args: serde_json::Value, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize)] pub struct FunctionResponse { pub name: String, pub response: serde_json::Value, diff --git a/src/gcloud/vertex.rs b/src/gcloud/vertex.rs index fceb7ea..1722678 100644 --- a/src/gcloud/vertex.rs +++ b/src/gcloud/vertex.rs @@ -3,9 +3,8 @@ use reqwest::Response; use tokio::sync::mpsc::channel; use tokio::sync::mpsc::Sender; -use std::borrow::Cow; - use std::env; +use std::rc::Rc; use crate::bot::ChatEvent; use crate::bot::ChatHandler; @@ -28,8 +27,8 @@ pub struct Vertex { project: String, location: String, model: String, - messages: Vec, - tools: Vec, + messages: Rc>, + tools: Rc>, function_store: FunctionStore, } @@ -40,19 +39,21 @@ impl Vertex { project, location, model, - messages: vec![], - tools: function_store - .declarations - .iter() - .map(|f| Tool { - function_declarations: vec![f.clone()], - }) - .collect(), + messages: Rc::new(vec![]), + tools: Rc::new( + function_store + .declarations + .iter() + .map(|f| Tool { + function_declarations: vec![f.clone()], + }) + .collect(), + ), function_store, } } - pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> { + pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> { let mut result = self.process(Content::new_text(Role::User, message), handler).await?; while let Some(function_call) = result { @@ -60,14 +61,14 @@ impl Vertex { let function_response = tokio::spawn(async move { function(function_call.args) }).await?; - let content = Content::new_function_response(&function_call.name, function_response); + let content = Content::new_function_response(function_call.name, function_response); result = self.process(content, handler).await?; } Ok(()) } async fn process(&mut self, content: Content, handler: &dyn ChatHandler) -> Result, Exception> { - self.messages.push(content); + self.add_message(content); let response = self.call_api().await?; @@ -81,14 +82,14 @@ impl Vertex { while let Some(response) = rx.recv().await { match response { Ok(response) => { - let part = response.candidates.first().unwrap().content.parts.first().unwrap(); - - if let Some(function) = part.function_call.as_ref() { - self.messages.push(Content::new_function_call(function.clone())); - return Ok(Some(function.clone())); - } else if let Some(text) = part.text.as_ref() { - handler.on_event(&ChatEvent::Delta(text.to_string())); - model_message.push_str(text); + let part = response.candidates.into_iter().next().unwrap().content.parts.into_iter().next().unwrap(); + + if let Some(function) = part.function_call { + self.add_message(Content::new_function_call(function.clone())); + return Ok(Some(function)); + } else if let Some(text) = part.text { + handler.on_event(ChatEvent::Delta(text.clone())); + model_message.push_str(&text); } } Err(err) => { @@ -97,12 +98,16 @@ impl Vertex { } } if !model_message.is_empty() { - self.messages.push(Content::new_text(Role::Model, &model_message)); + self.add_message(Content::new_text(Role::Model, model_message)); } - handler.on_event(&ChatEvent::End); + handler.on_event(ChatEvent::End); Ok(None) } + fn add_message(&mut self, content: Content) { + Rc::get_mut(&mut self.messages).unwrap().push(content); + } + async fn call_api(&self) -> Result { let has_function = !self.tools.is_empty(); @@ -113,19 +118,19 @@ impl Vertex { let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent"); let request = StreamGenerateContent { - contents: Cow::from(&self.messages), + contents: Rc::clone(&self.messages), generation_config: GenerationConfig { temperature: 1.0, top_p: 0.95, max_output_tokens: 2048, }, - tools: has_function.then(|| Cow::from(&self.tools)), + tools: has_function.then(|| Rc::clone(&self.tools)), }; let response = self.post(&url, &request).await?; Ok(response) } - async fn post(&self, url: &str, request: &StreamGenerateContent<'_>) -> Result { + async fn post(&self, url: &str, request: &StreamGenerateContent) -> Result { let body = json::to_json(request)?; let response = http_client::http_client() .post(url) diff --git a/src/openai/api.rs b/src/openai/api.rs index bb35c25..7fd4b79 100644 --- a/src/openai/api.rs +++ b/src/openai/api.rs @@ -1,5 +1,5 @@ -use std::borrow::Cow; use std::collections::HashMap; +use std::rc::Rc; use serde::Deserialize; use serde::Serialize; @@ -7,9 +7,8 @@ use serde::Serialize; use crate::bot::Function; #[derive(Debug, Serialize)] -pub struct ChatRequest<'a> { - #[serde(borrow)] - pub messages: Cow<'a, [ChatRequestMessage]>, +pub struct ChatRequest { + pub messages: Rc>, pub temperature: f32, pub top_p: f32, pub stream: bool, @@ -21,10 +20,10 @@ pub struct ChatRequest<'a> { #[serde(skip_serializing_if = "Option::is_none")] pub tool_choice: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, + pub tools: Option>>, } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize)] pub struct ChatRequestMessage { pub role: Role, pub content: Option, @@ -35,10 +34,10 @@ pub struct ChatRequestMessage { } impl ChatRequestMessage { - pub fn new_message(role: Role, message: &str) -> Self { + pub fn new_message(role: Role, message: String) -> Self { ChatRequestMessage { role, - content: Some(message.to_string()), + content: Some(message), tool_call_id: None, tool_calls: None, } @@ -76,13 +75,13 @@ impl ChatRequestMessage { } } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize)] pub struct Tool { pub r#type: String, pub function: Function, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize)] pub enum Role { #[serde(rename = "user")] User, @@ -117,7 +116,7 @@ pub struct ChatResponseMessage { pub tool_calls: Option>, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize)] pub struct ToolCall { pub index: i64, pub id: Option, @@ -125,7 +124,7 @@ pub struct ToolCall { pub function: FunctionCall, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize)] pub struct FunctionCall { pub name: Option, pub arguments: String, diff --git a/src/openai/chatgpt.rs b/src/openai/chatgpt.rs index 994d8e6..9dc37a4 100644 --- a/src/openai/chatgpt.rs +++ b/src/openai/chatgpt.rs @@ -1,7 +1,7 @@ -use std::borrow::Cow; use std::collections::HashMap; use std::fmt; +use std::rc::Rc; use futures::future::join_all; use futures::stream::StreamExt; @@ -30,8 +30,8 @@ pub struct ChatGPT { endpoint: String, api_key: String, model: String, - messages: Vec, - tools: Vec, + messages: Rc>, + tools: Rc>, function_store: FunctionStore, } @@ -48,25 +48,27 @@ impl ChatGPT { endpoint, api_key, model, - messages: vec![], - tools: function_store - .declarations - .iter() - .map(|f| Tool { - r#type: "function".to_string(), - function: f.clone(), - }) - .collect(), + messages: Rc::new(vec![]), + tools: Rc::new( + function_store + .declarations + .iter() + .map(|f| Tool { + r#type: "function".to_string(), + function: f.clone(), + }) + .collect(), + ), function_store, }; if let Some(message) = system_message { - chatgpt.messages.push(ChatRequestMessage::new_message(Role::System, &message)); + chatgpt.add_message(ChatRequestMessage::new_message(Role::System, message)); } chatgpt } - pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> { - self.messages.push(ChatRequestMessage::new_message(Role::User, message)); + pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> { + self.add_message(ChatRequestMessage::new_message(Role::User, message)); let result = self.process(handler).await; if let Ok(Some(InternalEvent::FunctionCall(calls))) = result { let handles: Result>, _> = calls @@ -79,8 +81,8 @@ impl ChatGPT { let results: Result, _> = join_all(handles?).await.into_iter().collect(); for result in results? { - let function_message = ChatRequestMessage::new_function_response(result.0.to_string(), json::to_json(&result.1)?); - self.messages.push(function_message); + let function_message = ChatRequestMessage::new_function_response(result.0, json::to_json(&result.1)?); + self.add_message(function_message); } self.process(handler).await?; @@ -88,6 +90,10 @@ impl ChatGPT { Ok(()) } + fn add_message(&mut self, message: ChatRequestMessage) { + Rc::get_mut(&mut self.messages).unwrap().push(message); + } + async fn process(&mut self, handler: &dyn ChatHandler) -> Result, Exception> { let source = self.call_api().await?; @@ -100,20 +106,20 @@ impl ChatGPT { while let Some(event) = rx.recv().await { match event { InternalEvent::Event(event) => { - handler.on_event(&event); - if let ChatEvent::Delta(data) = event { - assistant_message.push_str(&data); + if let ChatEvent::Delta(ref data) = event { + assistant_message.push_str(data); } + handler.on_event(event); } InternalEvent::FunctionCall(calls) => { - self.messages.push(ChatRequestMessage::new_function_call(&calls)); + self.add_message(ChatRequestMessage::new_function_call(&calls)); return Ok(Some(InternalEvent::FunctionCall(calls))); } } } if !assistant_message.is_empty() { - self.messages.push(ChatRequestMessage::new_message(Role::Assistant, &assistant_message)); + self.add_message(ChatRequestMessage::new_message(Role::Assistant, assistant_message)); } Ok(None) @@ -123,7 +129,7 @@ impl ChatGPT { let has_function = !self.tools.is_empty(); let request = ChatRequest { - messages: Cow::from(&self.messages), + messages: Rc::clone(&self.messages), temperature: 0.8, top_p: 0.8, stream: true, @@ -132,7 +138,7 @@ impl ChatGPT { presence_penalty: 0.0, frequency_penalty: 0.0, tool_choice: has_function.then(|| "auto".to_string()), - tools: has_function.then(|| Cow::from(&self.tools)), + tools: has_function.then(|| Rc::clone(&self.tools)), }; let source = self.post_sse(&request).await?; Ok(source)