Skip to content

Commit

Permalink
refactor syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Apr 19, 2024
1 parent 1eccfea commit f29726c
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/bot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub enum Bot {
}

impl Bot {
pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> {
match self {
Bot::ChatGPT(bot) => bot.chat(message, handler).await,
Bot::Vertex(bot) => bot.chat(message, handler).await,
Expand Down
4 changes: 2 additions & 2 deletions src/gcloud/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ use crate::bot::function::Function;
pub struct StreamGenerateContent {
pub contents: Rc<Vec<Content>>,
#[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
pub system_instruction: Rc<Option<Content>>,
pub system_instruction: Option<Rc<Content>>,
#[serde(rename = "generationConfig")]
pub generation_config: GenerationConfig,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Rc<Vec<Tool>>>,
pub tools: Option<Rc<[Tool]>>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
26 changes: 11 additions & 15 deletions src/gcloud/vertex.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::env;
use std::fs;
use std::mem;
use std::ops::Not;
use std::path::Path;
use std::rc::Rc;

Expand Down Expand Up @@ -31,8 +32,8 @@ use crate::util::json;
pub struct Vertex {
url: String,
messages: Rc<Vec<Content>>,
system_message: Rc<Option<Content>>,
tools: Rc<Vec<Tool>>,
system_message: Option<Rc<Content>>,
tools: Option<Rc<[Tool]>>,
function_store: FunctionStore,
data: Vec<InlineData>,
usage: Usage,
Expand All @@ -51,17 +52,17 @@ impl Vertex {
Vertex {
url,
messages: Rc::new(vec![]),
system_message: Rc::new(system_message.map(|message| Content::new_text(Role::Model, message))),
tools: Rc::new(vec![Tool {
system_message: system_message.map(|message| Rc::new(Content::new_text(Role::Model, message))),
tools: function_store.declarations.is_empty().not().then_some(Rc::from(vec![Tool {
function_declarations: function_store.declarations.to_vec(),
}]),
}])),
function_store,
data: vec![],
usage: Usage::default(),
}
}

pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> {
let data = mem::take(&mut self.data);
let mut result = self.process(Content::new_text_with_inline_data(message, data), handler).await?;

Expand Down Expand Up @@ -97,7 +98,7 @@ impl Vertex {
Ok(())
}

async fn process(&mut self, content: Content, handler: &dyn ChatHandler) -> Result<Option<FunctionCall>, Exception> {
async fn process(&mut self, content: Content, handler: &impl ChatHandler) -> Result<Option<FunctionCall>, Exception> {
self.add_message(content);

let response = self.call_api().await?;
Expand Down Expand Up @@ -138,23 +139,17 @@ impl Vertex {
}

async fn call_api(&self) -> Result<Response, Exception> {
let has_function = !self.tools.is_empty();

let request = StreamGenerateContent {
contents: Rc::clone(&self.messages),
system_instruction: Rc::clone(&self.system_message),
system_instruction: self.system_message.clone(),
generation_config: GenerationConfig {
temperature: 1.0,
top_p: 0.95,
max_output_tokens: 2048,
},
tools: has_function.then(|| Rc::clone(&self.tools)),
tools: self.tools.clone(),
};
let response = self.post(request).await?;
Ok(response)
}

async fn post(&self, request: StreamGenerateContent) -> Result<Response, Exception> {
let body = json::to_json(&request)?;
// info!("body={body}");
let response = http_client::http_client()
Expand All @@ -174,6 +169,7 @@ impl Vertex {
response.text().await?
)));
}

Ok(response)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/openai/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub struct ChatRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Rc<Vec<Tool>>>,
pub tools: Option<Rc<[Tool]>>,
}

#[derive(Debug, Serialize)]
Expand Down
52 changes: 21 additions & 31 deletions src/openai/chatgpt.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use std::collections::HashMap;
use std::fmt;
use std::ops::Not;
use std::rc::Rc;

use futures::stream::StreamExt;
use reqwest_eventsource::CannotCloneRequestError;
use reqwest_eventsource::Event;
use reqwest_eventsource::EventSource;
use serde::Serialize;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Sender;

Expand All @@ -27,7 +26,7 @@ pub struct ChatGPT {
url: String,
api_key: String,
messages: Rc<Vec<ChatRequestMessage>>,
tools: Rc<Vec<Tool>>,
tools: Option<Rc<[Tool]>>,
function_store: FunctionStore,
}

Expand All @@ -41,20 +40,21 @@ type FunctionCall = HashMap<i64, (String, String, String)>;
impl ChatGPT {
pub fn new(endpoint: String, model: String, api_key: String, system_message: Option<String>, function_store: FunctionStore) -> Self {
let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-02-01");
let tools: Option<Rc<[Tool]>> = function_store.declarations.is_empty().not().then_some(
function_store
.declarations
.iter()
.map(|f| Tool {
r#type: "function".to_string(),
function: f.clone(),
})
.collect(),
);
let mut chatgpt = ChatGPT {
url,
api_key,
messages: Rc::new(vec![]),
tools: Rc::new(
function_store
.declarations
.iter()
.map(|f| Tool {
r#type: "function".to_string(),
function: f.clone(),
})
.collect(),
),
tools,
function_store,
};
if let Some(message) = system_message {
Expand All @@ -63,14 +63,14 @@ impl ChatGPT {
chatgpt
}

pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> {
self.add_message(ChatRequestMessage::new_message(Role::User, message));
let result = self.process(handler).await;
if let Ok(Some(InternalEvent::FunctionCall(calls))) = result {
let functions = calls
.into_iter()
.map(|(_, (id, name, args))| (id, name, json::from_json(&args).unwrap()))
.collect();
let mut functions = Vec::with_capacity(calls.len());
for (_, (id, name, args)) in calls {
functions.push((id, name, json::from_json::<serde_json::Value>(&args)?))
}
let results = self.function_store.call_functions(functions).await?;
for result in results {
let function_response = ChatRequestMessage::new_function_response(result.0, json::to_json(&result.1)?);
Expand All @@ -85,7 +85,7 @@ impl ChatGPT {
Rc::get_mut(&mut self.messages).unwrap().push(message);
}

async fn process(&mut self, handler: &dyn ChatHandler) -> Result<Option<InternalEvent>, Exception> {
async fn process(&mut self, handler: &impl ChatHandler) -> Result<Option<InternalEvent>, Exception> {
let source = self.call_api().await?;

let (tx, mut rx) = channel(64);
Expand Down Expand Up @@ -117,8 +117,6 @@ impl ChatGPT {
}

async fn call_api(&mut self) -> Result<EventSource, Exception> {
let has_function = !self.tools.is_empty();

let request = ChatRequest {
messages: Rc::clone(&self.messages),
temperature: 0.8,
Expand All @@ -128,19 +126,11 @@ impl ChatGPT {
max_tokens: 800,
presence_penalty: 0.0,
frequency_penalty: 0.0,
tool_choice: has_function.then(|| "auto".to_string()),
tools: has_function.then(|| Rc::clone(&self.tools)),
tool_choice: self.tools.is_some().then_some("auto".to_string()),
tools: self.tools.clone(),
};
let source = self.post_sse(&request).await?;
Ok(source)
}

async fn post_sse<Request>(&self, request: &Request) -> Result<EventSource, Exception>
where
Request: Serialize + fmt::Debug,
{
let body = json::to_json(&request)?;

let request = http_client::http_client()
.post(&self.url)
.header("Content-Type", "application/json")
Expand Down

0 comments on commit f29726c

Please sign in to comment.