diff --git a/src/bot.rs b/src/bot.rs index 5916042..76fb20b 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -35,7 +35,7 @@ pub enum Bot { } impl Bot { - pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> { + pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> { match self { Bot::ChatGPT(bot) => bot.chat(message, handler).await, Bot::Vertex(bot) => bot.chat(message, handler).await, diff --git a/src/gcloud/api.rs b/src/gcloud/api.rs index a397dfd..4ec23a6 100644 --- a/src/gcloud/api.rs +++ b/src/gcloud/api.rs @@ -9,11 +9,11 @@ use crate::bot::function::Function; pub struct StreamGenerateContent { pub contents: Rc>, #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")] - pub system_instruction: Rc>, + pub system_instruction: Option>, #[serde(rename = "generationConfig")] pub generation_config: GenerationConfig, #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>>, + pub tools: Option>, } #[derive(Debug, Serialize, Deserialize)] diff --git a/src/gcloud/vertex.rs b/src/gcloud/vertex.rs index 9932e97..fdb1102 100644 --- a/src/gcloud/vertex.rs +++ b/src/gcloud/vertex.rs @@ -1,6 +1,7 @@ use std::env; use std::fs; use std::mem; +use std::ops::Not; use std::path::Path; use std::rc::Rc; @@ -31,8 +32,8 @@ use crate::util::json; pub struct Vertex { url: String, messages: Rc>, - system_message: Rc>, - tools: Rc>, + system_message: Option>, + tools: Option>, function_store: FunctionStore, data: Vec, usage: Usage, @@ -51,17 +52,17 @@ impl Vertex { Vertex { url, messages: Rc::new(vec![]), - system_message: Rc::new(system_message.map(|message| Content::new_text(Role::Model, message))), - tools: Rc::new(vec![Tool { + system_message: system_message.map(|message| Rc::new(Content::new_text(Role::Model, message))), + tools: function_store.declarations.is_empty().not().then_some(Rc::from(vec![Tool { function_declarations: function_store.declarations.to_vec(), - }]), + }])), function_store, data: vec![], usage: Usage::default(), } } - pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> { + pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> { let data = mem::take(&mut self.data); let mut result = self.process(Content::new_text_with_inline_data(message, data), handler).await?; @@ -97,7 +98,7 @@ impl Vertex { Ok(()) } - async fn process(&mut self, content: Content, handler: &dyn ChatHandler) -> Result, Exception> { + async fn process(&mut self, content: Content, handler: &impl ChatHandler) -> Result, Exception> { self.add_message(content); let response = self.call_api().await?; @@ -138,23 +139,17 @@ impl Vertex { } async fn call_api(&self) -> Result { - let has_function = !self.tools.is_empty(); - let request = StreamGenerateContent { contents: Rc::clone(&self.messages), - system_instruction: Rc::clone(&self.system_message), + system_instruction: self.system_message.clone(), generation_config: GenerationConfig { temperature: 1.0, top_p: 0.95, max_output_tokens: 2048, }, - tools: has_function.then(|| Rc::clone(&self.tools)), + tools: self.tools.clone(), }; - let response = self.post(request).await?; - Ok(response) - } - async fn post(&self, request: StreamGenerateContent) -> Result { let body = json::to_json(&request)?; // info!("body={body}"); let response = http_client::http_client() @@ -174,6 +169,7 @@ impl Vertex { response.text().await? ))); } + Ok(response) } } diff --git a/src/openai/api.rs b/src/openai/api.rs index 30df0cd..4074928 100644 --- a/src/openai/api.rs +++ b/src/openai/api.rs @@ -20,7 +20,7 @@ pub struct ChatRequest { #[serde(skip_serializing_if = "Option::is_none")] pub tool_choice: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>>, + pub tools: Option>, } #[derive(Debug, Serialize)] diff --git a/src/openai/chatgpt.rs b/src/openai/chatgpt.rs index 6b792cf..f2a4d0b 100644 --- a/src/openai/chatgpt.rs +++ b/src/openai/chatgpt.rs @@ -1,12 +1,11 @@ use std::collections::HashMap; -use std::fmt; +use std::ops::Not; use std::rc::Rc; use futures::stream::StreamExt; use reqwest_eventsource::CannotCloneRequestError; use reqwest_eventsource::Event; use reqwest_eventsource::EventSource; -use serde::Serialize; use tokio::sync::mpsc::channel; use tokio::sync::mpsc::Sender; @@ -27,7 +26,7 @@ pub struct ChatGPT { url: String, api_key: String, messages: Rc>, - tools: Rc>, + tools: Option>, function_store: FunctionStore, } @@ -41,20 +40,21 @@ type FunctionCall = HashMap; impl ChatGPT { pub fn new(endpoint: String, model: String, api_key: String, system_message: Option, function_store: FunctionStore) -> Self { let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-02-01"); + let tools: Option> = function_store.declarations.is_empty().not().then_some( + function_store + .declarations + .iter() + .map(|f| Tool { + r#type: "function".to_string(), + function: f.clone(), + }) + .collect(), + ); let mut chatgpt = ChatGPT { url, api_key, messages: Rc::new(vec![]), - tools: Rc::new( - function_store - .declarations - .iter() - .map(|f| Tool { - r#type: "function".to_string(), - function: f.clone(), - }) - .collect(), - ), + tools, function_store, }; if let Some(message) = system_message { @@ -63,14 +63,14 @@ impl ChatGPT { chatgpt } - pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> { + pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> { self.add_message(ChatRequestMessage::new_message(Role::User, message)); let result = self.process(handler).await; if let Ok(Some(InternalEvent::FunctionCall(calls))) = result { - let functions = calls - .into_iter() - .map(|(_, (id, name, args))| (id, name, json::from_json(&args).unwrap())) - .collect(); + let mut functions = Vec::with_capacity(calls.len()); + for (_, (id, name, args)) in calls { + functions.push((id, name, json::from_json::(&args)?)) + } let results = self.function_store.call_functions(functions).await?; for result in results { let function_response = ChatRequestMessage::new_function_response(result.0, json::to_json(&result.1)?); @@ -85,7 +85,7 @@ impl ChatGPT { Rc::get_mut(&mut self.messages).unwrap().push(message); } - async fn process(&mut self, handler: &dyn ChatHandler) -> Result, Exception> { + async fn process(&mut self, handler: &impl ChatHandler) -> Result, Exception> { let source = self.call_api().await?; let (tx, mut rx) = channel(64); @@ -117,8 +117,6 @@ impl ChatGPT { } async fn call_api(&mut self) -> Result { - let has_function = !self.tools.is_empty(); - let request = ChatRequest { messages: Rc::clone(&self.messages), temperature: 0.8, @@ -128,19 +126,11 @@ impl ChatGPT { max_tokens: 800, presence_penalty: 0.0, frequency_penalty: 0.0, - tool_choice: has_function.then(|| "auto".to_string()), - tools: has_function.then(|| Rc::clone(&self.tools)), + tool_choice: self.tools.is_some().then_some("auto".to_string()), + tools: self.tools.clone(), }; - let source = self.post_sse(&request).await?; - Ok(source) - } - async fn post_sse(&self, request: &Request) -> Result - where - Request: Serialize + fmt::Debug, - { let body = json::to_json(&request)?; - let request = http_client::http_client() .post(&self.url) .header("Content-Type", "application/json")