From 1d5629368d94b861dab409a95a9123ad3a0d596a Mon Sep 17 00:00:00 2001 From: neo <1100909+neowu@users.noreply.github.com> Date: Wed, 17 Apr 2024 11:42:18 +0800 Subject: [PATCH] refactor function store --- src/bot.rs | 46 ++++++++++++++++++++++++++----------- src/bot/config.rs | 16 +++++++++---- src/command/chat.rs | 2 +- src/gcloud/vertex.rs | 53 +++++++++++++++++-------------------------- src/openai/chatgpt.rs | 48 +++++++++++++++------------------------ 5 files changed, 85 insertions(+), 80 deletions(-) diff --git a/src/bot.rs b/src/bot.rs index b760ed3..2ebb18b 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -1,5 +1,7 @@ +use std::collections::HashMap; use std::fs; use std::path::Path; +use std::sync::Arc; use serde::Serialize; use tracing::info; @@ -33,27 +35,45 @@ pub struct Function { pub type FunctionImplementation = dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync; +pub struct FunctionStore { + pub declarations: Vec, + pub implementations: HashMap>>, +} + +impl FunctionStore { + pub fn new() -> Self { + FunctionStore { + declarations: vec![], + implementations: HashMap::new(), + } + } + + pub fn add(&mut self, function: Function, implementation: Box) { + let name = function.name.to_string(); + self.declarations.push(function); + self.implementations.insert(name, Arc::new(implementation)); + } + + pub fn get(&self, name: &str) -> Result>, Exception> { + let function = Arc::clone( + self.implementations + .get(name) + .ok_or_else(|| Exception::new(&format!("function not found, name={name}")))?, + ); + Ok(function) + } +} + pub enum Bot { ChatGPT(ChatGPT), Vertex(Vertex), } impl Bot { - pub fn register_function(&mut self, function: Function, implementation: Box) { - match self { - Bot::ChatGPT(chat_gpt) => { - chat_gpt.register_function(function, implementation); - } - Bot::Vertex(vertex) => { - vertex.register_function(function, implementation); - } - } - } - pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> { match self { - Bot::ChatGPT(chat_gpt) => chat_gpt.chat(message, handler).await, - Bot::Vertex(vertex) => vertex.chat(message, handler).await, + Bot::ChatGPT(bot) => bot.chat(message, handler).await, + Bot::Vertex(bot) => bot.chat(message, handler).await, } } } diff --git a/src/bot/config.rs b/src/bot/config.rs index dde8936..2609357 100644 --- a/src/bot/config.rs +++ b/src/bot/config.rs @@ -9,6 +9,8 @@ use rand::Rng; use serde::Deserialize; use serde_json::json; +use super::FunctionStore; + #[derive(Deserialize, Debug)] pub struct Config { pub bots: HashMap, @@ -21,21 +23,25 @@ impl Config { .get(name) .ok_or_else(|| Exception::new(&format!("can not find bot, name={name}")))?; - let mut bot = match config.r#type { + let function_store = load_function_store(config); + + let bot = match config.r#type { BotType::Azure => Bot::ChatGPT(ChatGPT::new( config.endpoint.to_string(), config.params.get("api_key").unwrap().to_string(), config.params.get("model").unwrap().to_string(), Option::None, + function_store, )), BotType::GCloud => Bot::Vertex(Vertex::new( config.endpoint.to_string(), config.params.get("project").unwrap().to_string(), config.params.get("location").unwrap().to_string(), config.params.get("model").unwrap().to_string(), + function_store, )), }; - register_function(config, &mut bot); + Ok(bot) } } @@ -54,10 +60,11 @@ pub enum BotType { GCloud, } -fn register_function(config: &BotConfig, bot: &mut Bot) { +fn load_function_store(config: &BotConfig) -> FunctionStore { + let mut function_store = FunctionStore::new(); for function in &config.functions { if let "get_random_number" = function.as_str() { - bot.register_function( + function_store.add( Function { name: "get_random_number".to_string(), description: "generate random number".to_string(), @@ -84,4 +91,5 @@ fn register_function(config: &BotConfig, bot: &mut Bot) { ); } } + function_store } diff --git a/src/command/chat.rs b/src/command/chat.rs index 69efedd..d2246ef 100644 --- a/src/command/chat.rs +++ b/src/command/chat.rs @@ -42,8 +42,8 @@ impl ChatHandler for ConsoleHandler { impl Chat { pub async fn execute(&self) -> Result<(), Exception> { let config = bot::load(Path::new(&self.conf))?; - let mut bot = config.create(&self.name)?; + let handler = ConsoleHandler {}; loop { print_flush("> ")?; diff --git a/src/gcloud/vertex.rs b/src/gcloud/vertex.rs index 998cdda..fceb7ea 100644 --- a/src/gcloud/vertex.rs +++ b/src/gcloud/vertex.rs @@ -4,15 +4,13 @@ use tokio::sync::mpsc::channel; use tokio::sync::mpsc::Sender; use std::borrow::Cow; -use std::collections::HashMap; -use std::env; -use std::sync::Arc; +use std::env; use crate::bot::ChatEvent; use crate::bot::ChatHandler; -use crate::bot::Function; -use crate::bot::FunctionImplementation; + +use crate::bot::FunctionStore; use crate::gcloud::api::GenerateContentResponse; use crate::util::exception::Exception; use crate::util::http_client; @@ -26,48 +24,39 @@ use super::api::StreamGenerateContent; use super::api::Tool; pub struct Vertex { - pub endpoint: String, - pub project: String, - pub location: String, - pub model: String, + endpoint: String, + project: String, + location: String, + model: String, messages: Vec, tools: Vec, - function_implementations: HashMap>>, + function_store: FunctionStore, } impl Vertex { - pub fn new(endpoint: String, project: String, location: String, model: String) -> Self { + pub fn new(endpoint: String, project: String, location: String, model: String, function_store: FunctionStore) -> Self { Vertex { endpoint, project, location, model, messages: vec![], - tools: vec![], - function_implementations: HashMap::new(), + tools: function_store + .declarations + .iter() + .map(|f| Tool { + function_declarations: vec![f.clone()], + }) + .collect(), + function_store, } } - pub fn register_function(&mut self, function: Function, implementation: Box) { - let name = function.name.to_string(); - self.tools.push(Tool { - function_declarations: vec![function], - }); - self.function_implementations.insert(name, Arc::new(implementation)); - } - pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> { - let mut result = self - .process(Content::new_text(Role::User, message), handler) - .await - .map_err(Exception::from)?; + let mut result = self.process(Content::new_text(Role::User, message), handler).await?; while let Some(function_call) = result { - let function = Arc::clone( - self.function_implementations - .get(&function_call.name) - .ok_or_else(|| Exception::new(&format!("function not found, name={}", function_call.name)))?, - ); + let function = self.function_store.get(&function_call.name)?; let function_response = tokio::spawn(async move { function(function_call.args) }).await?; @@ -114,8 +103,8 @@ impl Vertex { Ok(None) } - async fn call_api(&mut self) -> Result { - let has_function = !self.function_implementations.is_empty(); + async fn call_api(&self) -> Result { + let has_function = !self.tools.is_empty(); let endpoint = &self.endpoint; let project = &self.project; diff --git a/src/openai/chatgpt.rs b/src/openai/chatgpt.rs index 96b775f..994d8e6 100644 --- a/src/openai/chatgpt.rs +++ b/src/openai/chatgpt.rs @@ -2,7 +2,6 @@ use std::borrow::Cow; use std::collections::HashMap; use std::fmt; -use std::sync::Arc; use futures::future::join_all; use futures::stream::StreamExt; @@ -16,8 +15,8 @@ use tokio::task::JoinHandle; use crate::bot::ChatEvent; use crate::bot::ChatHandler; -use crate::bot::Function; -use crate::bot::FunctionImplementation; + +use crate::bot::FunctionStore; use crate::openai::api::ChatRequest; use crate::openai::api::ChatRequestMessage; use crate::openai::api::ChatResponse; @@ -28,12 +27,12 @@ use crate::util::http_client; use crate::util::json; pub struct ChatGPT { - pub endpoint: String, - pub api_key: String, - pub model: String, + endpoint: String, + api_key: String, + model: String, messages: Vec, tools: Vec, - function_implementations: HashMap>>, + function_store: FunctionStore, } enum InternalEvent { @@ -44,14 +43,21 @@ enum InternalEvent { type FunctionCall = HashMap; impl ChatGPT { - pub fn new(endpoint: String, api_key: String, model: String, system_message: Option) -> Self { + pub fn new(endpoint: String, api_key: String, model: String, system_message: Option, function_store: FunctionStore) -> Self { let mut chatgpt = ChatGPT { endpoint, api_key, model, messages: vec![], - tools: vec![], - function_implementations: HashMap::new(), + tools: function_store + .declarations + .iter() + .map(|f| Tool { + r#type: "function".to_string(), + function: f.clone(), + }) + .collect(), + function_store, }; if let Some(message) = system_message { chatgpt.messages.push(ChatRequestMessage::new_message(Role::System, &message)); @@ -59,15 +65,6 @@ impl ChatGPT { chatgpt } - pub fn register_function(&mut self, function: Function, implementation: Box) { - let name = function.name.to_string(); - self.tools.push(Tool { - r#type: "function".to_string(), - function, - }); - self.function_implementations.insert(name, Arc::new(implementation)); - } - pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> { self.messages.push(ChatRequestMessage::new_message(Role::User, message)); let result = self.process(handler).await; @@ -75,7 +72,7 @@ impl ChatGPT { let handles: Result>, _> = calls .into_iter() .map(|(_, (id, name, args))| { - let function = self.get_function(&name)?; + let function = self.function_store.get(&name)?; Ok::, Exception>(tokio::spawn(async move { (id, function(json::from_json(&args).unwrap())) })) }) .collect(); @@ -91,15 +88,6 @@ impl ChatGPT { Ok(()) } - fn get_function(&mut self, name: &str) -> Result>, Exception> { - let function = Arc::clone( - self.function_implementations - .get(name) - .ok_or_else(|| Exception::new(&format!("function not found, name={name}")))?, - ); - Ok(function) - } - async fn process(&mut self, handler: &dyn ChatHandler) -> Result, Exception> { let source = self.call_api().await?; @@ -132,7 +120,7 @@ impl ChatGPT { } async fn call_api(&mut self) -> Result { - let has_function = !self.function_implementations.is_empty(); + let has_function = !self.tools.is_empty(); let request = ChatRequest { messages: Cow::from(&self.messages),