Skip to content

Commit

Permalink
refactor file embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Apr 18, 2024
1 parent 8372f78 commit 8f855da
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 120 deletions.
1 change: 1 addition & 0 deletions rustfmt.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ max_width = 150

unstable_features = true
imports_granularity = "Item"
group_imports = "StdExternalCrate"
60 changes: 14 additions & 46 deletions src/bot.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::sync::Arc;

use serde::Serialize;
use tracing::info;
use tracing::warn;

use self::config::Config;
use crate::gcloud::vertex::Vertex;
use crate::openai::chatgpt::ChatGPT;
use crate::util::exception::Exception;
use crate::util::json;

use self::config::Config;

pub mod config;
pub mod function;

pub trait ChatHandler {
fn on_event(&self, event: ChatEvent);
Expand All @@ -22,46 +20,13 @@ pub trait ChatHandler {
pub enum ChatEvent {
Delta(String),
Error(String),
End,
}

// both openai and gemini shares same openai schema
#[derive(Debug, Serialize, Clone)]
pub struct Function {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}

pub type FunctionImplementation = dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync;

pub struct FunctionStore {
pub declarations: Vec<Function>,
pub implementations: HashMap<String, Arc<Box<FunctionImplementation>>>,
End(Usage),
}

impl FunctionStore {
pub fn new() -> Self {
FunctionStore {
declarations: vec![],
implementations: HashMap::new(),
}
}

pub fn add(&mut self, function: Function, implementation: Box<FunctionImplementation>) {
let name = function.name.to_string();
self.declarations.push(function);
self.implementations.insert(name, Arc::new(implementation));
}

pub fn get(&self, name: &str) -> Result<Arc<Box<FunctionImplementation>>, Exception> {
let function = Arc::clone(
self.implementations
.get(name)
.ok_or_else(|| Exception::new(format!("function not found, name={name}")))?,
);
Ok(function)
}
#[derive(Default)]
pub struct Usage {
pub request_tokens: i32,
pub response_tokens: i32,
}

pub enum Bot {
Expand All @@ -77,10 +42,13 @@ impl Bot {
}
}

pub async fn data(&mut self, path: &Path, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
pub fn file(&mut self, path: &Path) -> Result<(), Exception> {
match self {
Bot::ChatGPT(_bot) => todo!("not impl"),
Bot::Vertex(bot) => bot.data(path, message, handler).await,
Bot::ChatGPT(_bot) => {
warn!("ChatGPT does not support uploading file");
Ok(())
}
Bot::Vertex(bot) => bot.file(path),
}
}
}
Expand Down
14 changes: 6 additions & 8 deletions src/bot/config.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use std::collections::HashMap;

use crate::bot::Bot;
use crate::bot::Function;
use crate::gcloud::vertex::Vertex;
use crate::openai::chatgpt::ChatGPT;
use crate::util::exception::Exception;
use rand::Rng;
use serde::Deserialize;
use serde_json::json;
use tracing::info;

use super::FunctionStore;
use crate::bot::function::Function;
use crate::bot::function::FunctionStore;
use crate::bot::Bot;
use crate::gcloud::vertex::Vertex;
use crate::openai::chatgpt::ChatGPT;
use crate::util::exception::Exception;

#[derive(Deserialize, Debug)]
pub struct Config {
Expand Down Expand Up @@ -83,7 +82,6 @@ fn load_function_store(config: &BotConfig) -> FunctionStore {
}),
},
Box::new(|request| {
info!("call get_random_number, request={request}");
let max = request.get("max").unwrap().as_i64().unwrap();
let mut rng = rand::thread_rng();
let result = rng.gen_range(0..max);
Expand Down
69 changes: 69 additions & 0 deletions src/bot/function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::collections::HashMap;
use std::sync::Arc;

use futures::future::join_all;
use serde::Serialize;
use tokio::task::JoinHandle;
use tracing::info;

use crate::util::exception::Exception;

// both openai and gemini shares same openai schema
#[derive(Debug, Serialize, Clone)]
pub struct Function {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}

pub type FunctionImplementation = dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync;

pub struct FunctionStore {
pub declarations: Vec<Function>,
pub implementations: HashMap<String, Arc<Box<FunctionImplementation>>>,
}

impl FunctionStore {
pub fn new() -> Self {
FunctionStore {
declarations: vec![],
implementations: HashMap::new(),
}
}

pub fn add(&mut self, function: Function, implementation: Box<FunctionImplementation>) {
let name = function.name.to_string();
self.declarations.push(function);
self.implementations.insert(name, Arc::new(implementation));
}

pub async fn call_function(&self, name: &str, args: serde_json::Value) -> Result<serde_json::Value, Exception> {
info!("call function, name={name}, args={args}");
let function = self.get(name)?;
let response = tokio::spawn(async move { function(args) }).await?;
Ok(response)
}

pub async fn call_functions(&self, functions: Vec<(String, String, serde_json::Value)>) -> Result<Vec<(String, serde_json::Value)>, Exception> {
let handles: Result<Vec<JoinHandle<_>>, _> = functions
.into_iter()
.map(|(id, name, args)| {
info!("call function, id={id}, name={name}, args={args}");
let function = self.get(&name)?;
Ok::<JoinHandle<_>, Exception>(tokio::spawn(async move { (id, function(args)) }))
})
.collect();

let results = join_all(handles?).await.into_iter().collect::<Result<Vec<_>, _>>()?;
Ok(results)
}

fn get(&self, name: &str) -> Result<Arc<Box<FunctionImplementation>>, Exception> {
let function = Arc::clone(
self.implementations
.get(name)
.ok_or_else(|| Exception::new(format!("function not found, name={name}")))?,
);
Ok(function)
}
}
18 changes: 10 additions & 8 deletions src/command/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use std::io::Write;
use std::path::Path;

use clap::Args;
use tracing::info;

use crate::bot;
use crate::bot::ChatEvent;
use crate::bot::ChatHandler;

use crate::util::exception::Exception;

#[derive(Args)]
Expand All @@ -26,14 +26,18 @@ struct ConsoleHandler;
impl ChatHandler for ConsoleHandler {
fn on_event(&self, event: ChatEvent) {
match event {
ChatEvent::Delta(ref data) => {
print_flush(data).unwrap();
ChatEvent::Delta(data) => {
print_flush(&data).unwrap();
}
ChatEvent::Error(error) => {
println!("Error: {}", error);
}
ChatEvent::End => {
ChatEvent::End(usage) => {
println!();
info!(
"usage, request_tokens={}, response_tokens={}",
usage.request_tokens, usage.response_tokens
);
}
}
}
Expand All @@ -52,10 +56,8 @@ impl Chat {
if line == "/quit" {
break;
}
if line.starts_with("/data ") {
let index = line.find(',').unwrap();
bot.data(Path::new(line[6..index].trim()), line[(index + 1)..].to_string(), &handler)
.await?;
if line.starts_with("/file ") {
bot.file(Path::new(line.strip_prefix("/file ").unwrap()))?;
} else {
bot.chat(line, &handler).await?;
}
Expand Down
44 changes: 28 additions & 16 deletions src/gcloud/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::rc::Rc;
use serde::Deserialize;
use serde::Serialize;

use crate::bot::Function;
use crate::bot::function::Function;

#[derive(Debug, Serialize)]
pub struct StreamGenerateContent {
Expand Down Expand Up @@ -59,24 +59,26 @@ impl Content {
}
}

pub fn new_inline_data(mime_type: String, data: String, message: String) -> Self {
Self {
role: Role::User,
parts: vec![
Part {
pub fn new_text_with_inline_data(message: String, data: Vec<InlineData>) -> Self {
let mut parts: Vec<Part> = vec![];
parts.append(
&mut data
.into_iter()
.map(|d| Part {
text: None,
inline_data: Some(InlineData { mime_type, data }),
function_call: None,
function_response: None,
},
Part {
text: Some(message),
inline_data: None,
inline_data: Some(d),
function_call: None,
function_response: None,
},
],
}
})
.collect(),
);
parts.push(Part {
text: Some(message),
inline_data: None,
function_call: None,
function_response: None,
});
Self { role: Role::User, parts }
}
}

Expand Down Expand Up @@ -128,6 +130,8 @@ pub struct InlineData {
#[derive(Debug, Deserialize)]
pub struct GenerateContentResponse {
pub candidates: Vec<Candidate>,
#[serde(rename = "usageMetadata")]
pub usage_metadata: Option<UsageMetadata>,
}

#[derive(Debug, Deserialize)]
Expand All @@ -146,3 +150,11 @@ pub struct FunctionResponse {
pub name: String,
pub response: serde_json::Value,
}

#[derive(Debug, Deserialize)]
pub struct UsageMetadata {
#[serde(rename = "promptTokenCount")]
pub prompt_token_count: i32,
#[serde(rename = "candidatesTokenCount")]
pub candidates_token_count: i32,
}
Loading

0 comments on commit 8f855da

Please sign in to comment.