Skip to content

Commit

Permalink
refactor function call
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jul 19, 2024
1 parent 5c2e957 commit d984b2a
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 54 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "puppet"
version = "0.2.0"
version = "0.3.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
48 changes: 27 additions & 21 deletions src/azure/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use crate::azure::chatgpt_api::ChatRequestMessage;
use crate::azure::chatgpt_api::ChatStreamResponse;
use crate::azure::chatgpt_api::Role;
use crate::azure::chatgpt_api::Tool;
use crate::llm::function::FunctionImplementations;
use crate::llm::function::FunctionObject;
use crate::llm::function::FunctionStore;
use crate::llm::ChatOption;
use crate::util::console;
Expand All @@ -32,20 +34,24 @@ pub struct ChatGPT {
api_key: String,
messages: Rc<Vec<ChatRequestMessage>>,
tools: Option<Rc<[Tool]>>,
function_store: FunctionStore,
implementations: FunctionImplementations,
pub option: Option<ChatOption>,
}

impl ChatGPT {
pub fn new(endpoint: String, model: String, api_key: String, function_store: FunctionStore) -> Self {
let FunctionStore {
declarations,
implementations,
} = function_store;

let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-06-01");
let tools: Option<Rc<[Tool]>> = function_store.declarations.is_empty().not().then_some(
function_store
.declarations
.iter()
.map(|f| Tool {
r#type: "function".to_string(),
function: Rc::clone(f),
let tools: Option<Rc<[Tool]>> = declarations.is_empty().not().then_some(
declarations
.into_iter()
.map(|function| Tool {
r#type: "function",
function,
})
.collect(),
);
Expand All @@ -54,7 +60,7 @@ impl ChatGPT {
api_key,
messages: Rc::new(vec![]),
tools,
function_store,
implementations,
option: None,
}
}
Expand Down Expand Up @@ -96,10 +102,6 @@ impl ChatGPT {
self.add_message(ChatRequestMessage::new_message(Role::Assistant, message));
}

fn add_message(&mut self, message: ChatRequestMessage) {
Rc::get_mut(&mut self.messages).unwrap().push(message);
}

async fn process(&mut self) -> Result<(), Exception> {
loop {
let http_response = self.call_api().await?;
Expand All @@ -114,17 +116,17 @@ impl ChatGPT {
if let Some(calls) = message.tool_calls {
let mut functions = Vec::with_capacity(calls.len());
for call in calls.iter() {
functions.push((
call.id.to_string(),
call.function.name.to_string(),
json::from_json::<serde_json::Value>(&call.function.arguments)?,
))
functions.push(FunctionObject {
id: call.id.to_string(),
name: call.function.name.to_string(),
value: json::from_json::<serde_json::Value>(&call.function.arguments)?,
})
}
self.add_message(ChatRequestMessage::new_function_call(calls));

let results = self.function_store.call_functions(functions).await?;
for (id, _, result) in results {
self.add_message(ChatRequestMessage::new_function_response(id, json::to_json(&result)?));
let results = self.implementations.call_functions(functions).await?;
for result in results {
self.add_message(ChatRequestMessage::new_function_response(result.id, json::to_json(&result.value)?));
}
} else {
self.add_message(ChatRequestMessage::new_message(Role::Assistant, message.content.unwrap()));
Expand Down Expand Up @@ -170,6 +172,10 @@ impl ChatGPT {

Ok(response)
}

fn add_message(&mut self, message: ChatRequestMessage) {
Rc::get_mut(&mut self.messages).unwrap().push(message);
}
}

async fn read_sse_response(mut http_response: Response) -> Result<ChatResponse, Exception> {
Expand Down
4 changes: 2 additions & 2 deletions src/azure/chatgpt_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ impl ChatRequestMessage {

#[derive(Debug, Serialize)]
pub struct Tool {
pub r#type: String,
pub function: Rc<Function>,
pub r#type: &'static str,
pub function: Function,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
23 changes: 17 additions & 6 deletions src/gcloud/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use crate::gcloud::gemini_api::Candidate;
use crate::gcloud::gemini_api::GenerateContentStreamResponse;
use crate::gcloud::gemini_api::Role;
use crate::gcloud::gemini_api::UsageMetadata;
use crate::llm::function::FunctionImplementations;
use crate::llm::function::FunctionObject;
use crate::llm::function::FunctionStore;
use crate::llm::ChatOption;
use crate::util::console;
Expand All @@ -33,21 +35,26 @@ pub struct Gemini {
contents: Rc<Vec<Content>>,
system_instruction: Option<Rc<Content>>,
tools: Option<Rc<[Tool]>>,
function_store: FunctionStore,
implementations: FunctionImplementations,
pub option: Option<ChatOption>,
}

impl Gemini {
pub fn new(endpoint: String, project: String, location: String, model: String, function_store: FunctionStore) -> Self {
let FunctionStore {
declarations,
implementations,
} = function_store;

let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent");
Gemini {
url,
contents: Rc::new(vec![]),
system_instruction: None,
tools: function_store.declarations.is_empty().not().then_some(Rc::from(vec![Tool {
function_declarations: function_store.declarations.to_vec(),
tools: declarations.is_empty().not().then_some(Rc::from(vec![Tool {
function_declarations: declarations,
}])),
function_store,
implementations,
option: None,
}
}
Expand Down Expand Up @@ -89,14 +96,18 @@ impl Gemini {
let mut functions = vec![];
for (i, part) in candidate.content.parts.iter().enumerate() {
if let Some(ref call) = part.function_call {
functions.push((i.to_string(), call.name.to_string(), call.args.clone()));
functions.push(FunctionObject {
id: i.to_string(),
name: call.name.to_string(),
value: call.args.clone(),
});
}
}

self.add_content(candidate.content);

if !functions.is_empty() {
let function_result = self.function_store.call_functions(functions).await?;
let function_result = self.implementations.call_functions(functions).await?;
self.add_content(Content::new_function_response(function_result));
} else {
return Ok(());
Expand Down
12 changes: 8 additions & 4 deletions src/gcloud/gemini_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use serde::Deserialize;
use serde::Serialize;

use crate::llm::function::Function;
use crate::llm::function::FunctionObject;

#[derive(Debug, Serialize)]
pub struct StreamGenerateContent {
Expand Down Expand Up @@ -54,16 +55,19 @@ impl Content {
}
}

pub fn new_function_response(results: Vec<(String, String, serde_json::Value)>) -> Self {
pub fn new_function_response(results: Vec<FunctionObject>) -> Self {
Self {
role: Role::User,
parts: results
.into_iter()
.map(|r| Part {
.map(|result| Part {
text: None,
inline_data: None,
function_call: None,
function_response: Some(FunctionResponse { name: r.1, response: r.2 }),
function_response: Some(FunctionResponse {
name: result.name,
response: result.value,
}),
})
.collect(),
}
Expand All @@ -73,7 +77,7 @@ impl Content {
#[derive(Debug, Serialize)]
pub struct Tool {
#[serde(rename = "functionDeclarations")]
pub function_declarations: Vec<Rc<Function>>,
pub function_declarations: Vec<Function>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
49 changes: 30 additions & 19 deletions src/llm/function.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;

use serde::Serialize;
Expand All @@ -19,41 +18,53 @@ pub struct Function {

pub type FunctionImplementation = dyn Fn(&serde_json::Value) -> serde_json::Value + Send + Sync;

// pub struct FunctionObject {
// pub id: String,
// pub name: String,
// pub value: serde_json::Value,
// }
pub struct FunctionObject {
pub id: String,
pub name: String,
pub value: serde_json::Value,
}

pub struct FunctionStore {
pub declarations: Vec<Rc<Function>>,
pub implementations: HashMap<String, Arc<Box<FunctionImplementation>>>,
pub declarations: Vec<Function>,
pub implementations: FunctionImplementations,
}

pub struct FunctionImplementations {
implementations: HashMap<String, Arc<Box<FunctionImplementation>>>,
}

impl FunctionStore {
pub fn new() -> Self {
FunctionStore {
declarations: vec![],
implementations: HashMap::new(),
implementations: FunctionImplementations {
implementations: HashMap::new(),
},
}
}

pub fn add(&mut self, function: Function, implementation: Box<FunctionImplementation>) {
let name = function.name.to_string();
self.declarations.push(Rc::new(function));
self.implementations.insert(name, Arc::new(implementation));
self.declarations.push(function);
self.implementations.implementations.insert(name, Arc::new(implementation));
}
}

pub async fn call_functions(
&self,
functions: Vec<(String, String, serde_json::Value)>,
) -> Result<Vec<(String, String, serde_json::Value)>, Exception> {
impl FunctionImplementations {
pub async fn call_functions(&self, functions: Vec<FunctionObject>) -> Result<Vec<FunctionObject>, Exception> {
let mut handles = JoinSet::new();
for (id, name, args) in functions {
let function = self.get(&name)?;
for function_param in functions {
let function = self.get(&function_param.name)?;
handles.spawn(async move {
info!("call function, id={id}, name={name}, args={args}");
(id, name, function(&args))
info!(
"call function, id={}, name={}, args={}",
function_param.id, function_param.name, function_param.value
);
FunctionObject {
id: function_param.id,
name: function_param.name,
value: function(&function_param.value),
}
});
}
let mut results = vec![];
Expand Down

0 comments on commit d984b2a

Please sign in to comment.