From 8f855dad7f7801533fdd3a65128e099cbc840575 Mon Sep 17 00:00:00 2001
From: neo <1100909+neowu@users.noreply.github.com>
Date: Thu, 18 Apr 2024 15:21:04 +0800
Subject: [PATCH] refactor file embedding

---
 rustfmt.toml          |  1 +
 src/bot.rs            | 60 +++++++++----------------------
 src/bot/config.rs     | 14 ++++-----
 src/bot/function.rs   | 69 +++++++++++++++++++++++++++++++++++++++++++
 src/command/chat.rs   | 18 ++++++-----
 src/gcloud/api.rs     | 44 +++++++++++++++++----------
 src/gcloud/vertex.rs  | 62 +++++++++++++++++++++++---------------
 src/main.rs           |  1 -
 src/openai/api.rs     |  2 +-
 src/openai/chatgpt.rs | 26 ++++++----------
 10 files changed, 177 insertions(+), 120 deletions(-)
 create mode 100644 src/bot/function.rs

diff --git a/rustfmt.toml b/rustfmt.toml
index c1c6243..ec60976 100644
--- a/rustfmt.toml
+++ b/rustfmt.toml
@@ -2,3 +2,4 @@
 max_width = 150
 unstable_features = true
 imports_granularity = "Item"
+group_imports = "StdExternalCrate"
diff --git a/src/bot.rs b/src/bot.rs
index c21153f..5916042 100644
--- a/src/bot.rs
+++ b/src/bot.rs
@@ -1,19 +1,17 @@
-use std::collections::HashMap;
 use std::fs;
 use std::path::Path;
-use std::sync::Arc;
 
-use serde::Serialize;
 use tracing::info;
+use tracing::warn;
 
+use self::config::Config;
 use crate::gcloud::vertex::Vertex;
 use crate::openai::chatgpt::ChatGPT;
 use crate::util::exception::Exception;
 use crate::util::json;
 
-use self::config::Config;
-
 pub mod config;
+pub mod function;
 
 pub trait ChatHandler {
     fn on_event(&self, event: ChatEvent);
@@ -22,46 +20,13 @@
 pub enum ChatEvent {
     Delta(String),
     Error(String),
-    End,
-}
-
-// both openai and gemini shares same openai schema
-#[derive(Debug, Serialize, Clone)]
-pub struct Function {
-    pub name: String,
-    pub description: String,
-    pub parameters: serde_json::Value,
-}
-
-pub type FunctionImplementation = dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync;
-
-pub struct FunctionStore {
-    pub declarations: Vec<Function>,
-    pub implementations: HashMap<String, Arc<Box<FunctionImplementation>>>,
+    End(Usage),
 }
 
-impl FunctionStore {
-    pub fn new() -> Self {
-        FunctionStore {
-            declarations: vec![],
-            implementations: HashMap::new(),
-        }
-    }
-
-    pub fn add(&mut self, function: Function, implementation: Box<FunctionImplementation>) {
-        let name = function.name.to_string();
-        self.declarations.push(function);
-        self.implementations.insert(name, Arc::new(implementation));
-    }
-
-    pub fn get(&self, name: &str) -> Result<Arc<Box<FunctionImplementation>>, Exception> {
-        let function = Arc::clone(
-            self.implementations
-                .get(name)
-                .ok_or_else(|| Exception::new(format!("function not found, name={name}")))?,
-        );
-        Ok(function)
-    }
+#[derive(Default)]
+pub struct Usage {
+    pub request_tokens: i32,
+    pub response_tokens: i32,
 }
 
 pub enum Bot {
@@ -77,10 +42,13 @@ impl Bot {
         }
     }
 
-    pub async fn data(&mut self, path: &Path, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
+    pub fn file(&mut self, path: &Path) -> Result<(), Exception> {
         match self {
-            Bot::ChatGPT(_bot) => todo!("not impl"),
-            Bot::Vertex(bot) => bot.data(path, message, handler).await,
+            Bot::ChatGPT(_bot) => {
+                warn!("ChatGPT does not support uploading file");
+                Ok(())
+            }
+            Bot::Vertex(bot) => bot.file(path),
         }
     }
 }
diff --git a/src/bot/config.rs b/src/bot/config.rs
index 680c880..a29abe0 100644
--- a/src/bot/config.rs
+++ b/src/bot/config.rs
@@ -1,16 +1,15 @@
 use std::collections::HashMap;
 
-use crate::bot::Bot;
-use crate::bot::Function;
-use crate::gcloud::vertex::Vertex;
-use crate::openai::chatgpt::ChatGPT;
-use crate::util::exception::Exception;
 use rand::Rng;
 use serde::Deserialize;
 use serde_json::json;
-use tracing::info;
 
-use super::FunctionStore;
+use crate::bot::function::Function;
+use crate::bot::function::FunctionStore;
+use crate::bot::Bot;
+use crate::gcloud::vertex::Vertex;
+use crate::openai::chatgpt::ChatGPT;
+use crate::util::exception::Exception;
 
 #[derive(Deserialize, Debug)]
 pub struct Config {
@@ -83,7 +82,6 @@ fn load_function_store(config: &BotConfig) -> FunctionStore {
                 }),
             },
             Box::new(|request| {
-                info!("call get_random_number, request={request}");
                 let max = request.get("max").unwrap().as_i64().unwrap();
                 let mut rng = rand::thread_rng();
                 let result = rng.gen_range(0..max);
diff --git a/src/bot/function.rs b/src/bot/function.rs
new file mode 100644
index 0000000..01bdaab
--- /dev/null
+++ b/src/bot/function.rs
@@ -0,0 +1,69 @@
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use futures::future::join_all;
+use serde::Serialize;
+use tokio::task::JoinHandle;
+use tracing::info;
+
+use crate::util::exception::Exception;
+
+// both openai and gemini share the same openai schema
+#[derive(Debug, Serialize, Clone)]
+pub struct Function {
+    pub name: String,
+    pub description: String,
+    pub parameters: serde_json::Value,
+}
+
+pub type FunctionImplementation = dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync;
+
+pub struct FunctionStore {
+    pub declarations: Vec<Function>,
+    pub implementations: HashMap<String, Arc<Box<FunctionImplementation>>>,
+}
+
+impl FunctionStore {
+    pub fn new() -> Self {
+        FunctionStore {
+            declarations: vec![],
+            implementations: HashMap::new(),
+        }
+    }
+
+    pub fn add(&mut self, function: Function, implementation: Box<FunctionImplementation>) {
+        let name = function.name.to_string();
+        self.declarations.push(function);
+        self.implementations.insert(name, Arc::new(implementation));
+    }
+
+    pub async fn call_function(&self, name: &str, args: serde_json::Value) -> Result<serde_json::Value, Exception> {
+        info!("call function, name={name}, args={args}");
+        let function = self.get(name)?;
+        let response = tokio::spawn(async move { function(args) }).await?;
+        Ok(response)
+    }
+
+    pub async fn call_functions(&self, functions: Vec<(String, String, serde_json::Value)>) -> Result<Vec<(String, serde_json::Value)>, Exception> {
+        let handles: Result<Vec<JoinHandle<(String, serde_json::Value)>>, _> = functions
+            .into_iter()
+            .map(|(id, name, args)| {
+                info!("call function, id={id}, name={name}, args={args}");
+                let function = self.get(&name)?;
+                Ok::<JoinHandle<(String, serde_json::Value)>, Exception>(tokio::spawn(async move { (id, function(args)) }))
+            })
+            .collect();
+
+        let results = join_all(handles?).await.into_iter().collect::<Result<Vec<_>, _>>()?;
+        Ok(results)
+    }
+
+    fn get(&self, name: &str) -> Result<Arc<Box<FunctionImplementation>>, Exception> {
+        let function = Arc::clone(
+            self.implementations
+                .get(name)
+                .ok_or_else(|| Exception::new(format!("function not found, name={name}")))?,
+        );
+        Ok(function)
+    }
+}
diff --git a/src/command/chat.rs b/src/command/chat.rs
index 22fc2fe..9d760f6 100644
--- a/src/command/chat.rs
+++ b/src/command/chat.rs
@@ -4,11 +4,11 @@ use std::io::Write;
 use std::path::Path;
 
 use clap::Args;
+use tracing::info;
 
 use crate::bot;
 use crate::bot::ChatEvent;
 use crate::bot::ChatHandler;
-
 use crate::util::exception::Exception;
 
 #[derive(Args)]
@@ -26,14 +26,18 @@ struct ConsoleHandler;
 impl ChatHandler for ConsoleHandler {
     fn on_event(&self, event: ChatEvent) {
         match event {
-            ChatEvent::Delta(ref data) => {
-                print_flush(data).unwrap();
+            ChatEvent::Delta(data) => {
+                print_flush(&data).unwrap();
             }
             ChatEvent::Error(error) => {
                 println!("Error: {}", error);
             }
-            ChatEvent::End => {
+            ChatEvent::End(usage) => {
                 println!();
+                info!(
+                    "usage, request_tokens={}, response_tokens={}",
+                    usage.request_tokens, usage.response_tokens
+                );
             }
         }
     }
@@ -52,10 +56,8 @@ impl Chat {
             if line == "/quit" {
                 break;
             }
-            if line.starts_with("/data ") {
-                let index = line.find(',').unwrap();
-                bot.data(Path::new(line[6..index].trim()), line[(index + 1)..].to_string(), &handler)
-                    .await?;
+            if line.starts_with("/file ") {
+                bot.file(Path::new(line.strip_prefix("/file ").unwrap()))?;
             } else {
                 bot.chat(line, &handler).await?;
             }
diff --git a/src/gcloud/api.rs b/src/gcloud/api.rs
index 6b8b907..a397dfd 100644
--- a/src/gcloud/api.rs
+++ b/src/gcloud/api.rs
@@ -3,7 +3,7 @@ use std::rc::Rc;
 use serde::Deserialize;
 use serde::Serialize;
 
-use crate::bot::Function;
+use crate::bot::function::Function;
 
 #[derive(Debug, Serialize)]
 pub struct StreamGenerateContent {
@@ -59,24 +59,26 @@ impl Content {
         }
     }
 
-    pub fn new_inline_data(mime_type: String, data: String, message: String) -> Self {
-        Self {
-            role: Role::User,
-            parts: vec![
-                Part {
+    pub fn new_text_with_inline_data(message: String, data: Vec<InlineData>) -> Self {
+        let mut parts: Vec<Part> = vec![];
+        parts.append(
+            &mut data
+                .into_iter()
+                .map(|d| Part {
                     text: None,
-                    inline_data: Some(InlineData { mime_type, data }),
-                    function_call: None,
-                    function_response: None,
-                },
-                Part {
-                    text: Some(message),
-                    inline_data: None,
+                    inline_data: Some(d),
                     function_call: None,
                     function_response: None,
-                },
-            ],
-        }
+                })
+                .collect(),
+        );
+        parts.push(Part {
+            text: Some(message),
+            inline_data: None,
+            function_call: None,
+            function_response: None,
+        });
+        Self { role: Role::User, parts }
     }
 }
 
@@ -128,6 +130,8 @@ pub struct InlineData {
 #[derive(Debug, Deserialize)]
 pub struct GenerateContentResponse {
     pub candidates: Vec<Candidate>,
+    #[serde(rename = "usageMetadata")]
+    pub usage_metadata: Option<UsageMetadata>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -146,3 +150,11 @@ pub struct FunctionResponse {
     pub name: String,
     pub response: serde_json::Value,
 }
+
+#[derive(Debug, Deserialize)]
+pub struct UsageMetadata {
+    #[serde(rename = "promptTokenCount")]
+    pub prompt_token_count: i32,
+    #[serde(rename = "candidatesTokenCount")]
+    pub candidates_token_count: i32,
+}
diff --git a/src/gcloud/vertex.rs b/src/gcloud/vertex.rs
index dc2ec66..4ca7023 100644
--- a/src/gcloud/vertex.rs
+++ b/src/gcloud/vertex.rs
@@ -1,30 +1,32 @@
+use std::env;
+use std::fs;
+use std::mem;
+use std::path::Path;
+use std::rc::Rc;
+
 use base64::prelude::BASE64_STANDARD;
 use base64::Engine;
 use futures::StreamExt;
 use reqwest::Response;
 use tokio::sync::mpsc::channel;
 use tokio::sync::mpsc::Sender;
-
-use std::env;
-use std::fs;
-use std::path::Path;
-use std::rc::Rc;
-
-use crate::bot::ChatEvent;
-use crate::bot::ChatHandler;
-
-use crate::bot::FunctionStore;
-use crate::gcloud::api::GenerateContentResponse;
-use crate::util::exception::Exception;
-use crate::util::http_client;
-use crate::util::json;
+use tracing::info;
 
 use super::api::Content;
 use super::api::FunctionCall;
 use super::api::GenerationConfig;
+use super::api::InlineData;
 use super::api::Role;
 use super::api::StreamGenerateContent;
 use super::api::Tool;
+use crate::bot::function::FunctionStore;
+use crate::bot::ChatEvent;
+use crate::bot::ChatHandler;
+use crate::bot::Usage;
+use crate::gcloud::api::GenerateContentResponse;
+use crate::util::exception::Exception;
+use crate::util::http_client;
+use crate::util::json;
 
 pub struct Vertex {
     url: String,
@@ -32,6 +34,8 @@ pub struct Vertex {
     system_message: Rc<Option<Content>>,
     tools: Rc<Vec<Tool>>,
     function_store: FunctionStore,
+    data: Vec<InlineData>,
+    usage: Usage,
 }
 
 impl Vertex {
@@ -58,24 +62,24 @@ impl Vertex {
                 .collect(),
             ),
             function_store,
+            data: vec![],
+            usage: Usage::default(),
         }
     }
 
     pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
-        let mut result = self.process(Content::new_text(Role::User, message), handler).await?;
+        let data = mem::take(&mut self.data);
+        let mut result = self.process(Content::new_text_with_inline_data(message, data), handler).await?;
 
         while let Some(function_call) = result {
-            let function = self.function_store.get(&function_call.name)?;
-
-            let function_response = tokio::spawn(async move { function(function_call.args) }).await?;
-
+            let function_response = self.function_store.call_function(&function_call.name, function_call.args).await?;
             let content = Content::new_function_response(function_call.name, function_response);
             result = self.process(content, handler).await?;
         }
         Ok(())
     }
 
-    pub async fn data(&mut self, path: &Path, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
+    pub fn file(&mut self, path: &Path) -> Result<(), Exception> {
         let extension = path
             .extension()
             .ok_or_else(|| Exception::new(format!("file must have extension, path={}", path.to_string_lossy())))?
@@ -88,8 +92,14 @@ impl Vertex {
             "pdf" => Ok("application/pdf".to_string()),
             _ => Err(Exception::new(format!("not supported extension, path={}", path.to_string_lossy()))),
         }?;
-        self.process(Content::new_inline_data(mime_type, BASE64_STANDARD.encode(content), message), handler)
-            .await?;
+        info!(
+            "file added, will submit with next message, mime_type={mime_type}, path={}",
+            path.to_string_lossy()
+        );
+        self.data.push(InlineData {
+            mime_type,
+            data: BASE64_STANDARD.encode(content),
+        });
         Ok(())
     }
 
@@ -108,6 +118,11 @@ impl Vertex {
         while let Some(response) = rx.recv().await {
             let response = response?;
 
+            if let Some(usage) = response.usage_metadata {
+                self.usage.request_tokens += usage.prompt_token_count;
+                self.usage.response_tokens += usage.candidates_token_count;
+            }
+
             let part = response.candidates.into_iter().next().unwrap().content.parts.into_iter().next().unwrap();
 
             if let Some(function) = part.function_call {
@@ -121,7 +136,8 @@ impl Vertex {
         if !model_message.is_empty() {
             self.add_message(Content::new_text(Role::Model, model_message));
         }
-        handler.on_event(ChatEvent::End);
+        let usage = mem::take(&mut self.usage);
+        handler.on_event(ChatEvent::End(usage));
 
         Ok(None)
     }
diff --git a/src/main.rs b/src/main.rs
index 14e2f03..581618c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -3,7 +3,6 @@ use clap::Subcommand;
 use command::chat::Chat;
 use command::generate_zsh_completion::GenerateZshCompletion;
 use command::server::Server;
-
 use util::exception::Exception;
 
 mod bot;
diff --git a/src/openai/api.rs b/src/openai/api.rs
index 7fd4b79..30df0cd 100644
--- a/src/openai/api.rs
+++ b/src/openai/api.rs
@@ -4,7 +4,7 @@ use std::rc::Rc;
 use serde::Deserialize;
 use serde::Serialize;
 
-use crate::bot::Function;
+use crate::bot::function::Function;
 
 #[derive(Debug, Serialize)]
 pub struct ChatRequest {
diff --git a/src/openai/chatgpt.rs b/src/openai/chatgpt.rs
index 68fdd96..6b0a542 100644
--- a/src/openai/chatgpt.rs
+++ b/src/openai/chatgpt.rs
@@ -1,9 +1,7 @@
 use std::collections::HashMap;
-
 use std::fmt;
 use std::rc::Rc;
 
-use futures::future::join_all;
 use futures::stream::StreamExt;
 use reqwest_eventsource::CannotCloneRequestError;
 use reqwest_eventsource::Event;
@@ -11,12 +9,11 @@ use reqwest_eventsource::EventSource;
 use serde::Serialize;
 use tokio::sync::mpsc::channel;
 use tokio::sync::mpsc::Sender;
-use tokio::task::JoinHandle;
 
+use crate::bot::function::FunctionStore;
 use crate::bot::ChatEvent;
 use crate::bot::ChatHandler;
-
-use crate::bot::FunctionStore;
+use crate::bot::Usage;
 use crate::openai::api::ChatRequest;
 use crate::openai::api::ChatRequestMessage;
 use crate::openai::api::ChatResponse;
@@ -70,20 +67,15 @@ impl ChatGPT {
         self.add_message(ChatRequestMessage::new_message(Role::User, message));
         let result = self.process(handler).await;
         if let Ok(Some(InternalEvent::FunctionCall(calls))) = result {
-            let handles: Result<Vec<JoinHandle<(String, serde_json::Value)>>, _> = calls
+            let functions = calls
                 .into_iter()
-                .map(|(_, (id, name, args))| {
-                    let function = self.function_store.get(&name)?;
-                    Ok::<JoinHandle<(String, serde_json::Value)>, Exception>(tokio::spawn(async move { (id, function(json::from_json(&args).unwrap())) }))
-                })
+                .map(|(_, (id, name, args))| (id, name, json::from_json(&args).unwrap()))
                 .collect();
-
-            let results: Result<Vec<(String, serde_json::Value)>, _> = join_all(handles?).await.into_iter().collect();
-            for result in results? {
-                let function_message = ChatRequestMessage::new_function_response(result.0, json::to_json(&result.1)?);
-                self.add_message(function_message);
+            let results = self.function_store.call_functions(functions).await?;
+            for result in results {
+                let function_response = ChatRequestMessage::new_function_response(result.0, json::to_json(&result.1)?);
+                self.add_message(function_response);
             }
-            self.process(handler).await?;
         }
         Ok(())
     }
@@ -205,6 +197,6 @@ async fn process_event_source(mut source: EventSource, tx: Sender<InternalEvent>) {
     if !function_calls.is_empty() {
         tx.send(InternalEvent::FunctionCall(function_calls)).await.unwrap();
     } else {
-        tx.send(InternalEvent::Event(ChatEvent::End)).await.unwrap();
+        tx.send(InternalEvent::Event(ChatEvent::End(Usage::default()))).await.unwrap();
    }
 }
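
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the FunctionStore API added in src/bot/function.rs above is meant to be driven from inside this crate. The demo function, the stub closure body, and the literal arguments are hypothetical; Function, FunctionStore, call_function, and call_functions are the items introduced by this commit, and the call shapes mirror src/bot/config.rs (registration), src/gcloud/vertex.rs (single call), and src/openai/chatgpt.rs (batched calls).

    use serde_json::json;

    use crate::bot::function::Function;
    use crate::bot::function::FunctionStore;
    use crate::util::exception::Exception;

    // hypothetical driver (not part of the patch); assumes a tokio runtime is running
    async fn demo() -> Result<(), Exception> {
        let mut store = FunctionStore::new();
        store.add(
            Function {
                name: "get_random_number".to_string(),
                description: "generate a random number".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "max": {"type": "number"}
                    },
                    "required": ["max"]
                }),
            },
            // stub implementation; the real closure lives in src/bot/config.rs
            Box::new(|request| json!({"success": true, "max": request["max"]})),
        );

        // vertex/gemini path: one function call at a time (see Vertex::chat)
        let single = store.call_function("get_random_number", json!({"max": 100})).await?;

        // openai path: a batch of (id, name, args) tool calls (see ChatGPT::chat)
        let batch = store
            .call_functions(vec![("call_1".to_string(), "get_random_number".to_string(), json!({"max": 10}))])
            .await?;

        println!("single={single}, batch={batch:?}");
        Ok(())
    }

The design split is visible in the two entry points: call_functions spawns each implementation with tokio::spawn and joins them via join_all, so a batch of OpenAI tool calls runs concurrently, while the Gemini path keeps the simpler one-at-a-time call_function.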