Skip to content

Commit

Permalink
simplify openai api
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Nov 14, 2024
1 parent aad81b9 commit fa7da7a
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 59 deletions.
6 changes: 0 additions & 6 deletions src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ use crate::util::json;
pub mod config;
pub mod function;

#[derive(Default, Clone)]
pub struct TokenUsage {
pub prompt_tokens: i32,
pub completion_tokens: i32,
}

#[derive(Debug)]
pub struct ChatOption {
pub temperature: f32,
Expand Down
53 changes: 31 additions & 22 deletions src/llm/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;

use anyhow::anyhow;
use anyhow::Context;
Expand All @@ -9,8 +10,9 @@ use serde::Deserialize;
use serde_json::json;

use super::function::FUNCTION_STORE;
use crate::llm::function::Function;
use crate::openai::chat::Chat;
use crate::openai::chat_api::Function;
use crate::openai::chat_api::Tool;

#[derive(Deserialize, Debug)]
pub struct Config {
Expand All @@ -32,9 +34,10 @@ impl Config {

info!("create model, name={name}");

let functions = load_functions(config)?;
let tools = load_functions(config)?;
let tools = if tools.is_empty() { None } else { Some(Arc::from(tools)) };

let mut model = Chat::new(config.url.to_string(), config.api_key.to_string(), config.model.to_string(), functions);
let mut model = Chat::new(config.url.to_string(), config.api_key.to_string(), config.model.to_string(), tools);

if let Some(message) = config.system_message.as_ref() {
model.system_message(message.to_string());
Expand All @@ -44,26 +47,29 @@ impl Config {
}
}

fn load_functions(config: &ModelConfig) -> Result<Vec<Function>> {
let mut declarations: Vec<Function> = vec![];
fn load_functions(config: &ModelConfig) -> Result<Vec<Tool>> {
let mut declarations: Vec<Tool> = vec![];
let mut function_store = FUNCTION_STORE.lock().unwrap();
for function in &config.functions {
info!("load function, name={function}");
match function.as_str() {
"get_random_number" => {
declarations.push(Function {
name: "get_random_number",
description: "generate random number",
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"max": {
"type": "number",
"description": "max of value"
},
},
"required": ["max"]
})),
declarations.push(Tool {
r#type: "function",
function: Function {
name: "get_random_number",
description: "generate random number",
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"max": {
"type": "number",
"description": "max of value"
},
},
"required": ["max"]
})),
},
});
function_store.add(
"get_random_number",
Expand All @@ -79,10 +85,13 @@ fn load_functions(config: &ModelConfig) -> Result<Vec<Function>> {
)
}
"close_door" => {
declarations.push(Function {
name: "close_door",
description: "close door of home",
parameters: None,
declarations.push(Tool {
r#type: "function",
function: Function {
name: "close_door",
description: "close door of home",
parameters: None,
},
});
function_store.add(
"close_door",
Expand Down
10 changes: 0 additions & 10 deletions src/llm/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,6 @@ use anyhow::anyhow;
use anyhow::Context;
use anyhow::Result;
use log::info;
use serde::Serialize;

// both openai and gemini shares same openai schema
#[derive(Debug, Serialize)]
pub struct Function {
pub name: &'static str,
pub description: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<serde_json::Value>,
}

pub type FunctionImplementation = dyn Fn(&serde_json::Value) -> serde_json::Value + Send + Sync;

Expand Down
2 changes: 1 addition & 1 deletion src/openai.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pub mod chat;
mod chat_api;
pub mod chat_api;
23 changes: 5 additions & 18 deletions src/openai/chat.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::fs;
use std::ops::Not;
use std::path::Path;
use std::str;
use std::sync::Arc;
Expand Down Expand Up @@ -27,12 +26,10 @@ use super::chat_api::StreamOptions;
use super::chat_api::Tool;
use super::chat_api::ToolCall;
use super::chat_api::Usage;
use crate::llm::function::Function;
use crate::llm::function::FunctionPayload;
use crate::llm::function::FUNCTION_STORE;
use crate::llm::ChatOption;
use crate::llm::TextStream;
use crate::llm::TokenUsage;
use crate::util::http_client::ResponseExt;
use crate::util::http_client::HTTP_CLIENT;
use crate::util::json;
Expand All @@ -49,20 +46,11 @@ struct Context {
messages: Arc<Vec<ChatRequestMessage>>,
tools: Option<Arc<[Tool]>>,
option: Option<ChatOption>,
usage: TokenUsage,
usage: Arc<Usage>,
}

impl Chat {
pub fn new(url: String, api_key: String, model: String, functions: Vec<Function>) -> Self {
let tools: Option<Arc<[Tool]>> = functions.is_empty().not().then_some(
functions
.into_iter()
.map(|function| Tool {
r#type: "function",
function,
})
.collect(),
);
pub fn new(url: String, api_key: String, model: String, tools: Option<Arc<[Tool]>>) -> Self {
Chat {
context: Arc::from(Mutex::new(Context {
url,
Expand All @@ -71,7 +59,7 @@ impl Chat {
messages: Arc::new(vec![]),
tools,
option: None,
usage: TokenUsage::default(),
usage: Arc::new(Usage::default()),
})),
}
}
Expand Down Expand Up @@ -114,7 +102,7 @@ impl Chat {
self.context.lock().unwrap().option = Some(option);
}

pub fn usage(&self) -> TokenUsage {
pub fn usage(&self) -> Arc<Usage> {
self.context.lock().unwrap().usage.clone()
}
}
Expand All @@ -131,8 +119,7 @@ async fn process(context: Arc<Mutex<Context>>, tx: mpsc::Sender<String>) -> Resu
let response = read_sse_response(http_response, &tx).await?;

let mut context = context.lock().unwrap();
context.usage.prompt_tokens += response.usage.prompt_tokens;
context.usage.completion_tokens += response.usage.completion_tokens;
context.usage = Arc::new(response.usage);

let message = response.choices.into_iter().next().unwrap().message;

Expand Down
10 changes: 8 additions & 2 deletions src/openai/chat_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use std::sync::Arc;
use serde::Deserialize;
use serde::Serialize;

use crate::llm::function::Function;

#[derive(Debug, Serialize)]
pub struct ChatRequest {
pub model: String,
Expand Down Expand Up @@ -120,6 +118,14 @@ pub struct Tool {
pub function: Function,
}

#[derive(Debug, Serialize)]
pub struct Function {
pub name: &'static str,
pub description: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize)]
pub enum Role {
#[serde(rename = "user")]
Expand Down

0 comments on commit fa7da7a

Please sign in to comment.