From 7f0733305e8692458305011f4c3c48a78fe7ebee Mon Sep 17 00:00:00 2001 From: neo <1100909+neowu@users.noreply.github.com> Date: Thu, 18 Apr 2024 16:01:42 +0800 Subject: [PATCH] fix multiple functions --- src/bot/config.rs | 40 ++++++++++++++++++++++++++++++++++++++-- src/bot/function.rs | 3 ++- src/gcloud/vertex.rs | 13 ++++--------- src/openai/chatgpt.rs | 1 + 4 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/bot/config.rs b/src/bot/config.rs index a29abe0..7f48dcf 100644 --- a/src/bot/config.rs +++ b/src/bot/config.rs @@ -70,7 +70,7 @@ fn load_function_store(config: &BotConfig) -> FunctionStore { Function { name: "get_random_number".to_string(), description: "generate random number".to_string(), - parameters: serde_json::json!({ + parameters: Some(serde_json::json!({ "type": "object", "properties": { "max": { @@ -79,7 +79,7 @@ fn load_function_store(config: &BotConfig) -> FunctionStore { }, }, "required": ["max"] - }), + })), }, Box::new(|request| { let max = request.get("max").unwrap().as_i64().unwrap(); @@ -92,6 +92,42 @@ fn load_function_store(config: &BotConfig) -> FunctionStore { }), ); } + if let "close_door" = function.as_str() { + function_store.add( + Function { + name: "close_door".to_string(), + description: "close door of home".to_string(), + parameters: None, + }, + Box::new(|_request| { + json!({ + "success": true + }) + }), + ); + } + if let "close_window" = function.as_str() { + function_store.add( + Function { + name: "close_window".to_string(), + description: "close window of home with id".to_string(), + parameters: Some(serde_json::json!({ + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "id of window" + } + } + })), + }, + Box::new(|_request| { + json!({ + "success": true + }) + }), + ); + } } function_store } diff --git a/src/bot/function.rs b/src/bot/function.rs index 01bdaab..1b8bbe0 100644 --- a/src/bot/function.rs +++ b/src/bot/function.rs @@ -13,7 +13,8 @@ use crate::util::exception::Exception; pub struct Function { pub name: String, pub description: String, - pub parameters: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, } pub type FunctionImplementation = dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync; diff --git a/src/gcloud/vertex.rs b/src/gcloud/vertex.rs index 4ca7023..a08ddbd 100644 --- a/src/gcloud/vertex.rs +++ b/src/gcloud/vertex.rs @@ -52,15 +52,9 @@ impl Vertex { url, messages: Rc::new(vec![]), system_message: Rc::new(system_message.map(|message| Content::new_text(Role::Model, message))), - tools: Rc::new( - function_store - .declarations - .iter() - .map(|f| Tool { - function_declarations: vec![f.clone()], - }) - .collect(), - ), + tools: Rc::new(vec![Tool { + function_declarations: function_store.declarations.to_vec(), + }]), function_store, data: vec![], usage: Usage::default(), @@ -164,6 +158,7 @@ impl Vertex { async fn post(&self, request: StreamGenerateContent) -> Result { let body = json::to_json(&request)?; + info!("body={body}"); let response = http_client::http_client() .post(&self.url) .bearer_auth(token()) diff --git a/src/openai/chatgpt.rs b/src/openai/chatgpt.rs index 6b0a542..6b792cf 100644 --- a/src/openai/chatgpt.rs +++ b/src/openai/chatgpt.rs @@ -197,6 +197,7 @@ async fn process_event_source(mut source: EventSource, tx: Sender if !function_calls.is_empty() { tx.send(InternalEvent::FunctionCall(function_calls)).await.unwrap(); } else { + // chatgpt doesn't support token usage with stream mode tx.send(InternalEvent::Event(ChatEvent::End(Usage::default()))).await.unwrap(); } }