Skip to content

Commit

Permalink
refactor function store
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Apr 17, 2024
1 parent 796e877 commit 1d56293
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 80 deletions.
46 changes: 33 additions & 13 deletions src/bot.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::sync::Arc;

use serde::Serialize;
use tracing::info;
Expand Down Expand Up @@ -33,27 +35,45 @@ pub struct Function {

pub type FunctionImplementation = dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync;

pub struct FunctionStore {
pub declarations: Vec<Function>,
pub implementations: HashMap<String, Arc<Box<FunctionImplementation>>>,
}

impl FunctionStore {
pub fn new() -> Self {
FunctionStore {
declarations: vec![],
implementations: HashMap::new(),
}
}

pub fn add(&mut self, function: Function, implementation: Box<FunctionImplementation>) {
let name = function.name.to_string();
self.declarations.push(function);
self.implementations.insert(name, Arc::new(implementation));
}

pub fn get(&self, name: &str) -> Result<Arc<Box<FunctionImplementation>>, Exception> {
let function = Arc::clone(
self.implementations
.get(name)
.ok_or_else(|| Exception::new(&format!("function not found, name={name}")))?,
);
Ok(function)
}
}

pub enum Bot {
ChatGPT(ChatGPT),
Vertex(Vertex),
}

impl Bot {
pub fn register_function(&mut self, function: Function, implementation: Box<FunctionImplementation>) {
match self {
Bot::ChatGPT(chat_gpt) => {
chat_gpt.register_function(function, implementation);
}
Bot::Vertex(vertex) => {
vertex.register_function(function, implementation);
}
}
}

pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> {
match self {
Bot::ChatGPT(chat_gpt) => chat_gpt.chat(message, handler).await,
Bot::Vertex(vertex) => vertex.chat(message, handler).await,
Bot::ChatGPT(bot) => bot.chat(message, handler).await,
Bot::Vertex(bot) => bot.chat(message, handler).await,
}
}
}
Expand Down
16 changes: 12 additions & 4 deletions src/bot/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use rand::Rng;
use serde::Deserialize;
use serde_json::json;

use super::FunctionStore;

#[derive(Deserialize, Debug)]
pub struct Config {
pub bots: HashMap<String, BotConfig>,
Expand All @@ -21,21 +23,25 @@ impl Config {
.get(name)
.ok_or_else(|| Exception::new(&format!("can not find bot, name={name}")))?;

let mut bot = match config.r#type {
let function_store = load_function_store(config);

let bot = match config.r#type {
BotType::Azure => Bot::ChatGPT(ChatGPT::new(
config.endpoint.to_string(),
config.params.get("api_key").unwrap().to_string(),
config.params.get("model").unwrap().to_string(),
Option::None,
function_store,
)),
BotType::GCloud => Bot::Vertex(Vertex::new(
config.endpoint.to_string(),
config.params.get("project").unwrap().to_string(),
config.params.get("location").unwrap().to_string(),
config.params.get("model").unwrap().to_string(),
function_store,
)),
};
register_function(config, &mut bot);

Ok(bot)
}
}
Expand All @@ -54,10 +60,11 @@ pub enum BotType {
GCloud,
}

fn register_function(config: &BotConfig, bot: &mut Bot) {
fn load_function_store(config: &BotConfig) -> FunctionStore {
let mut function_store = FunctionStore::new();
for function in &config.functions {
if let "get_random_number" = function.as_str() {
bot.register_function(
function_store.add(
Function {
name: "get_random_number".to_string(),
description: "generate random number".to_string(),
Expand All @@ -84,4 +91,5 @@ fn register_function(config: &BotConfig, bot: &mut Bot) {
);
}
}
function_store
}
2 changes: 1 addition & 1 deletion src/command/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ impl ChatHandler for ConsoleHandler {
impl Chat {
pub async fn execute(&self) -> Result<(), Exception> {
let config = bot::load(Path::new(&self.conf))?;

let mut bot = config.create(&self.name)?;

let handler = ConsoleHandler {};
loop {
print_flush("> ")?;
Expand Down
53 changes: 21 additions & 32 deletions src/gcloud/vertex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@ use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Sender;

use std::borrow::Cow;
use std::collections::HashMap;
use std::env;

use std::sync::Arc;
use std::env;

use crate::bot::ChatEvent;
use crate::bot::ChatHandler;
use crate::bot::Function;
use crate::bot::FunctionImplementation;

use crate::bot::FunctionStore;
use crate::gcloud::api::GenerateContentResponse;
use crate::util::exception::Exception;
use crate::util::http_client;
Expand All @@ -26,48 +24,39 @@ use super::api::StreamGenerateContent;
use super::api::Tool;

pub struct Vertex {
pub endpoint: String,
pub project: String,
pub location: String,
pub model: String,
endpoint: String,
project: String,
location: String,
model: String,
messages: Vec<Content>,
tools: Vec<Tool>,
function_implementations: HashMap<String, Arc<Box<FunctionImplementation>>>,
function_store: FunctionStore,
}

impl Vertex {
pub fn new(endpoint: String, project: String, location: String, model: String) -> Self {
pub fn new(endpoint: String, project: String, location: String, model: String, function_store: FunctionStore) -> Self {
Vertex {
endpoint,
project,
location,
model,
messages: vec![],
tools: vec![],
function_implementations: HashMap::new(),
tools: function_store
.declarations
.iter()
.map(|f| Tool {
function_declarations: vec![f.clone()],
})
.collect(),
function_store,
}
}

pub fn register_function(&mut self, function: Function, implementation: Box<FunctionImplementation>) {
let name = function.name.to_string();
self.tools.push(Tool {
function_declarations: vec![function],
});
self.function_implementations.insert(name, Arc::new(implementation));
}

pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> {
let mut result = self
.process(Content::new_text(Role::User, message), handler)
.await
.map_err(Exception::from)?;
let mut result = self.process(Content::new_text(Role::User, message), handler).await?;

while let Some(function_call) = result {
let function = Arc::clone(
self.function_implementations
.get(&function_call.name)
.ok_or_else(|| Exception::new(&format!("function not found, name={}", function_call.name)))?,
);
let function = self.function_store.get(&function_call.name)?;

let function_response = tokio::spawn(async move { function(function_call.args) }).await?;

Expand Down Expand Up @@ -114,8 +103,8 @@ impl Vertex {
Ok(None)
}

async fn call_api(&mut self) -> Result<Response, Exception> {
let has_function = !self.function_implementations.is_empty();
async fn call_api(&self) -> Result<Response, Exception> {
let has_function = !self.tools.is_empty();

let endpoint = &self.endpoint;
let project = &self.project;
Expand Down
48 changes: 18 additions & 30 deletions src/openai/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::borrow::Cow;
use std::collections::HashMap;

use std::fmt;
use std::sync::Arc;

use futures::future::join_all;
use futures::stream::StreamExt;
Expand All @@ -16,8 +15,8 @@ use tokio::task::JoinHandle;

use crate::bot::ChatEvent;
use crate::bot::ChatHandler;
use crate::bot::Function;
use crate::bot::FunctionImplementation;

use crate::bot::FunctionStore;
use crate::openai::api::ChatRequest;
use crate::openai::api::ChatRequestMessage;
use crate::openai::api::ChatResponse;
Expand All @@ -28,12 +27,12 @@ use crate::util::http_client;
use crate::util::json;

pub struct ChatGPT {
pub endpoint: String,
pub api_key: String,
pub model: String,
endpoint: String,
api_key: String,
model: String,
messages: Vec<ChatRequestMessage>,
tools: Vec<Tool>,
function_implementations: HashMap<String, Arc<Box<FunctionImplementation>>>,
function_store: FunctionStore,
}

enum InternalEvent {
Expand All @@ -44,38 +43,36 @@ enum InternalEvent {
type FunctionCall = HashMap<i64, (String, String, String)>;

impl ChatGPT {
pub fn new(endpoint: String, api_key: String, model: String, system_message: Option<String>) -> Self {
pub fn new(endpoint: String, api_key: String, model: String, system_message: Option<String>, function_store: FunctionStore) -> Self {
let mut chatgpt = ChatGPT {
endpoint,
api_key,
model,
messages: vec![],
tools: vec![],
function_implementations: HashMap::new(),
tools: function_store
.declarations
.iter()
.map(|f| Tool {
r#type: "function".to_string(),
function: f.clone(),
})
.collect(),
function_store,
};
if let Some(message) = system_message {
chatgpt.messages.push(ChatRequestMessage::new_message(Role::System, &message));
}
chatgpt
}

pub fn register_function(&mut self, function: Function, implementation: Box<FunctionImplementation>) {
let name = function.name.to_string();
self.tools.push(Tool {
r#type: "function".to_string(),
function,
});
self.function_implementations.insert(name, Arc::new(implementation));
}

pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> {
self.messages.push(ChatRequestMessage::new_message(Role::User, message));
let result = self.process(handler).await;
if let Ok(Some(InternalEvent::FunctionCall(calls))) = result {
let handles: Result<Vec<JoinHandle<_>>, _> = calls
.into_iter()
.map(|(_, (id, name, args))| {
let function = self.get_function(&name)?;
let function = self.function_store.get(&name)?;
Ok::<JoinHandle<_>, Exception>(tokio::spawn(async move { (id, function(json::from_json(&args).unwrap())) }))
})
.collect();
Expand All @@ -91,15 +88,6 @@ impl ChatGPT {
Ok(())
}

fn get_function(&mut self, name: &str) -> Result<Arc<Box<FunctionImplementation>>, Exception> {
let function = Arc::clone(
self.function_implementations
.get(name)
.ok_or_else(|| Exception::new(&format!("function not found, name={name}")))?,
);
Ok(function)
}

async fn process(&mut self, handler: &dyn ChatHandler) -> Result<Option<InternalEvent>, Exception> {
let source = self.call_api().await?;

Expand Down Expand Up @@ -132,7 +120,7 @@ impl ChatGPT {
}

async fn call_api(&mut self) -> Result<EventSource, Exception> {
let has_function = !self.function_implementations.is_empty();
let has_function = !self.tools.is_empty();

let request = ChatRequest {
messages: Cow::from(&self.messages),
Expand Down

0 comments on commit 1d56293

Please sign in to comment.