diff --git a/src/llm.rs b/src/llm.rs
index 23d83c0..4b96651 100644
--- a/src/llm.rs
+++ b/src/llm.rs
@@ -15,12 +15,6 @@ use crate::util::json;
 pub mod config;
 pub mod function;
 
-#[derive(Default, Clone)]
-pub struct TokenUsage {
-    pub prompt_tokens: i32,
-    pub completion_tokens: i32,
-}
-
 #[derive(Debug)]
 pub struct ChatOption {
     pub temperature: f32,
diff --git a/src/llm/config.rs b/src/llm/config.rs
index d9f5f25..ab99cf0 100644
--- a/src/llm/config.rs
+++ b/src/llm/config.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use anyhow::anyhow;
 use anyhow::Context;
@@ -9,8 +10,9 @@ use serde::Deserialize;
 use serde_json::json;
 
 use super::function::FUNCTION_STORE;
-use crate::llm::function::Function;
 use crate::openai::chat::Chat;
+use crate::openai::chat_api::Function;
+use crate::openai::chat_api::Tool;
 
 #[derive(Deserialize, Debug)]
 pub struct Config {
@@ -32,9 +34,10 @@ impl Config {
 
         info!("create model, name={name}");
 
-        let functions = load_functions(config)?;
+        let tools = load_functions(config)?;
+        let tools = if tools.is_empty() { None } else { Some(Arc::from(tools)) };
 
-        let mut model = Chat::new(config.url.to_string(), config.api_key.to_string(), config.model.to_string(), functions);
+        let mut model = Chat::new(config.url.to_string(), config.api_key.to_string(), config.model.to_string(), tools);
 
         if let Some(message) = config.system_message.as_ref() {
             model.system_message(message.to_string());
@@ -44,26 +47,29 @@ impl Config {
     }
 }
 
-fn load_functions(config: &ModelConfig) -> Result<Vec<Function>> {
-    let mut declarations: Vec<Function> = vec![];
+fn load_functions(config: &ModelConfig) -> Result<Vec<Tool>> {
+    let mut declarations: Vec<Tool> = vec![];
     let mut function_store = FUNCTION_STORE.lock().unwrap();
     for function in &config.functions {
         info!("load function, name={function}");
         match function.as_str() {
             "get_random_number" => {
-                declarations.push(Function {
-                    name: "get_random_number",
-                    description: "generate random number",
-                    parameters: Some(serde_json::json!({
-                        "type": "object",
-                        "properties": {
-                            "max": {
-                                "type": "number",
-                                "description": "max of value"
-                            },
-                        },
-                        "required": ["max"]
-                    })),
+                declarations.push(Tool {
+                    r#type: "function",
+                    function: Function {
+                        name: "get_random_number",
+                        description: "generate random number",
+                        parameters: Some(serde_json::json!({
+                            "type": "object",
+                            "properties": {
+                                "max": {
+                                    "type": "number",
+                                    "description": "max of value"
+                                },
+                            },
+                            "required": ["max"]
+                        })),
+                    },
                 });
                 function_store.add(
                     "get_random_number",
@@ -79,10 +85,13 @@ fn load_functions(config: &ModelConfig) -> Result<Vec<Function>> {
                 )
             }
             "close_door" => {
-                declarations.push(Function {
-                    name: "close_door",
-                    description: "close door of home",
-                    parameters: None,
+                declarations.push(Tool {
+                    r#type: "function",
+                    function: Function {
+                        name: "close_door",
+                        description: "close door of home",
+                        parameters: None,
+                    },
                 });
                 function_store.add(
                     "close_door",
diff --git a/src/llm/function.rs b/src/llm/function.rs
index 5b4163d..442d37f 100644
--- a/src/llm/function.rs
+++ b/src/llm/function.rs
@@ -6,16 +6,6 @@ use anyhow::anyhow;
 use anyhow::Context;
 use anyhow::Result;
 use log::info;
-use serde::Serialize;
-
-// both openai and gemini shares same openai schema
-#[derive(Debug, Serialize)]
-pub struct Function {
-    pub name: &'static str,
-    pub description: &'static str,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub parameters: Option<serde_json::Value>,
-}
 
 pub type FunctionImplementation = dyn Fn(&serde_json::Value) -> serde_json::Value + Send + Sync;
 
diff --git a/src/openai.rs b/src/openai.rs
index 35738b1..896892a 100644
--- a/src/openai.rs
+++ b/src/openai.rs
@@ -1,2 +1,2 @@
 pub mod chat;
-mod chat_api;
+pub mod chat_api;
diff --git a/src/openai/chat.rs b/src/openai/chat.rs
index e872257..6ad72af 100644
--- a/src/openai/chat.rs
+++ b/src/openai/chat.rs
@@ -1,5 +1,4 @@
 use std::fs;
-use std::ops::Not;
 use std::path::Path;
 use std::str;
 use std::sync::Arc;
@@ -27,12 +26,10 @@ use super::chat_api::StreamOptions;
 use super::chat_api::Tool;
 use super::chat_api::ToolCall;
 use super::chat_api::Usage;
-use crate::llm::function::Function;
 use crate::llm::function::FunctionPayload;
 use crate::llm::function::FUNCTION_STORE;
 use crate::llm::ChatOption;
 use crate::llm::TextStream;
-use crate::llm::TokenUsage;
 use crate::util::http_client::ResponseExt;
 use crate::util::http_client::HTTP_CLIENT;
 use crate::util::json;
@@ -49,20 +46,11 @@ struct Context {
     messages: Arc<Vec<ChatRequestMessage>>,
     tools: Option<Arc<[Tool]>>,
     option: Option<ChatOption>,
-    usage: TokenUsage,
+    usage: Arc<Usage>,
 }
 
 impl Chat {
-    pub fn new(url: String, api_key: String, model: String, functions: Vec<Function>) -> Self {
-        let tools: Option<Arc<[Tool]>> = functions.is_empty().not().then_some(
-            functions
-                .into_iter()
-                .map(|function| Tool {
-                    r#type: "function",
-                    function,
-                })
-                .collect(),
-        );
+    pub fn new(url: String, api_key: String, model: String, tools: Option<Arc<[Tool]>>) -> Self {
         Chat {
             context: Arc::from(Mutex::new(Context {
                 url,
@@ -71,7 +59,7 @@ impl Chat {
                 messages: Arc::new(vec![]),
                 tools,
                 option: None,
-                usage: TokenUsage::default(),
+                usage: Arc::new(Usage::default()),
             })),
         }
     }
@@ -114,7 +102,7 @@ impl Chat {
         self.context.lock().unwrap().option = Some(option);
     }
 
-    pub fn usage(&self) -> TokenUsage {
+    pub fn usage(&self) -> Arc<Usage> {
         self.context.lock().unwrap().usage.clone()
     }
 }
@@ -131,8 +119,7 @@ async fn process(context: Arc<Mutex<Context>>, tx: mpsc::Sender<String>) -> Resu
     let response = read_sse_response(http_response, &tx).await?;
 
     let mut context = context.lock().unwrap();
-    context.usage.prompt_tokens += response.usage.prompt_tokens;
-    context.usage.completion_tokens += response.usage.completion_tokens;
+    context.usage = Arc::new(response.usage);
 
     let message = response.choices.into_iter().next().unwrap().message;
 
diff --git a/src/openai/chat_api.rs b/src/openai/chat_api.rs
index 50de3e1..d175043 100644
--- a/src/openai/chat_api.rs
+++ b/src/openai/chat_api.rs
@@ -3,8 +3,6 @@ use std::sync::Arc;
 use serde::Deserialize;
 use serde::Serialize;
 
-use crate::llm::function::Function;
-
 #[derive(Debug, Serialize)]
 pub struct ChatRequest {
     pub model: String,
@@ -120,6 +118,14 @@ pub struct Tool {
     pub function: Function,
 }
 
+#[derive(Debug, Serialize)]
+pub struct Function {
+    pub name: &'static str,
+    pub description: &'static str,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parameters: Option<serde_json::Value>,
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub enum Role {
     #[serde(rename = "user")]
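Note on the relocated schema types: the `Tool`/`Function` pair added to `src/openai/chat_api.rs` above is what serializes into the `tools` array of an OpenAI chat completion request. Below is a minimal, self-contained sketch of that serialization, assuming only `serde` (with the `derive` feature) and `serde_json` as dependencies; the struct definitions mirror the ones in this diff, the example values are copied from `load_functions()` in `src/llm/config.rs`, and the `main` wrapper is purely illustrative.

```rust
use serde::Serialize;

// Mirrors the `Tool` / `Function` structs added to src/openai/chat_api.rs in this diff.
#[derive(Debug, Serialize)]
struct Tool {
    r#type: &'static str, // serialized as "type"
    function: Function,
}

#[derive(Debug, Serialize)]
struct Function {
    name: &'static str,
    description: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<serde_json::Value>, // omitted entirely for parameter-less tools like "close_door"
}

fn main() {
    // Same shape as the "get_random_number" declaration built in load_functions().
    let tool = Tool {
        r#type: "function",
        function: Function {
            name: "get_random_number",
            description: "generate random number",
            parameters: Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "max": { "type": "number", "description": "max of value" }
                },
                "required": ["max"]
            })),
        },
    };

    // Prints the JSON object that would appear in the request's `tools` array:
    // {"type":"function","function":{"name":"get_random_number","description":"generate random number","parameters":{...}}}
    println!("{}", serde_json::to_string_pretty(&tool).unwrap());
}
```

With `Function` living next to `Tool` in `chat_api`, the wire schema stays in one module and `Chat::new` receives a pre-built `Option<Arc<[Tool]>>` instead of wrapping `Vec<Function>` itself, which is why `std::ops::Not` and the per-call `Tool` construction could be dropped from `src/openai/chat.rs`.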