Skip to content

Commit

Permalink
refactor rust syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Apr 17, 2024
1 parent 1d56293 commit 102bd6e
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 86 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
[dependencies]
clap = { version = "4.4.16", features = ["derive"] }
clap_complete = "4.4.6"
serde = { version = "1.0", features = ["derive"] }
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
hyper = { version = "1.2.0", features = ["full"] }
tokio = { version = "1.36.0", features = ["full"] }
Expand Down
4 changes: 2 additions & 2 deletions src/bot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use self::config::Config;
pub mod config;

pub trait ChatHandler {
fn on_event(&self, event: &ChatEvent);
fn on_event(&self, event: ChatEvent);
}

pub enum ChatEvent {
Expand Down Expand Up @@ -70,7 +70,7 @@ pub enum Bot {
}

impl Bot {
pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> {
pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
match self {
Bot::ChatGPT(bot) => bot.chat(message, handler).await,
Bot::Vertex(bot) => bot.chat(message, handler).await,
Expand Down
6 changes: 3 additions & 3 deletions src/command/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ pub struct Chat {
struct ConsoleHandler;

impl ChatHandler for ConsoleHandler {
fn on_event(&self, event: &ChatEvent) {
fn on_event(&self, event: ChatEvent) {
match event {
ChatEvent::Delta(data) => {
ChatEvent::Delta(ref data) => {
print_flush(data).unwrap();
}
ChatEvent::Error(error) => {
Expand All @@ -53,7 +53,7 @@ impl Chat {
break;
}

bot.chat(&line, &handler).await?;
bot.chat(line, &handler).await?;
}
Ok(())
}
Expand Down
28 changes: 12 additions & 16 deletions src/gcloud/api.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
use std::borrow::Cow;
use std::rc::Rc;

use serde::Deserialize;
use serde::Serialize;

use crate::bot::Function;

#[derive(Debug, Serialize)]
pub struct StreamGenerateContent<'a> {
#[serde(borrow)]
pub contents: Cow<'a, [Content]>,
pub struct StreamGenerateContent {
pub contents: Rc<Vec<Content>>,
#[serde(rename = "generationConfig")]
pub generation_config: GenerationConfig,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Cow<'a, [Tool]>>,
pub tools: Option<Rc<Vec<Tool>>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Content {
pub role: Role,
pub parts: Vec<Part>,
}

impl Content {
pub fn new_text(role: Role, message: &str) -> Self {
pub fn new_text(role: Role, message: String) -> Self {
Self {
role,
parts: vec![Part {
Expand All @@ -33,16 +32,13 @@ impl Content {
}
}

pub fn new_function_response(name: &str, response: serde_json::Value) -> Self {
pub fn new_function_response(name: String, response: serde_json::Value) -> Self {
Self {
role: Role::User,
parts: vec![Part {
text: None,
function_call: None,
function_response: Some(FunctionResponse {
name: name.to_string(),
response,
}),
function_response: Some(FunctionResponse { name, response }),
}],
}
}
Expand All @@ -59,21 +55,21 @@ impl Content {
}
}

#[derive(Debug, Serialize, Clone)]
#[derive(Debug, Serialize)]
pub struct Tool {
#[serde(rename = "functionDeclarations")]
pub function_declarations: Vec<Function>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize)]
pub enum Role {
#[serde(rename = "user")]
User,
#[serde(rename = "model")]
Model,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Part {
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
Expand Down Expand Up @@ -110,7 +106,7 @@ pub struct FunctionCall {
pub args: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
pub name: String,
pub response: serde_json::Value,
Expand Down
61 changes: 33 additions & 28 deletions src/gcloud/vertex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ use reqwest::Response;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Sender;

use std::borrow::Cow;

use std::env;
use std::rc::Rc;

use crate::bot::ChatEvent;
use crate::bot::ChatHandler;
Expand All @@ -28,8 +27,8 @@ pub struct Vertex {
project: String,
location: String,
model: String,
messages: Vec<Content>,
tools: Vec<Tool>,
messages: Rc<Vec<Content>>,
tools: Rc<Vec<Tool>>,
function_store: FunctionStore,
}

Expand All @@ -40,34 +39,36 @@ impl Vertex {
project,
location,
model,
messages: vec![],
tools: function_store
.declarations
.iter()
.map(|f| Tool {
function_declarations: vec![f.clone()],
})
.collect(),
messages: Rc::new(vec![]),
tools: Rc::new(
function_store
.declarations
.iter()
.map(|f| Tool {
function_declarations: vec![f.clone()],
})
.collect(),
),
function_store,
}
}

pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Exception> {
pub async fn chat(&mut self, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
let mut result = self.process(Content::new_text(Role::User, message), handler).await?;

while let Some(function_call) = result {
let function = self.function_store.get(&function_call.name)?;

let function_response = tokio::spawn(async move { function(function_call.args) }).await?;

let content = Content::new_function_response(&function_call.name, function_response);
let content = Content::new_function_response(function_call.name, function_response);
result = self.process(content, handler).await?;
}
Ok(())
}

async fn process(&mut self, content: Content, handler: &dyn ChatHandler) -> Result<Option<FunctionCall>, Exception> {
self.messages.push(content);
self.add_message(content);

let response = self.call_api().await?;

Expand All @@ -81,14 +82,14 @@ impl Vertex {
while let Some(response) = rx.recv().await {
match response {
Ok(response) => {
let part = response.candidates.first().unwrap().content.parts.first().unwrap();

if let Some(function) = part.function_call.as_ref() {
self.messages.push(Content::new_function_call(function.clone()));
return Ok(Some(function.clone()));
} else if let Some(text) = part.text.as_ref() {
handler.on_event(&ChatEvent::Delta(text.to_string()));
model_message.push_str(text);
let part = response.candidates.into_iter().next().unwrap().content.parts.into_iter().next().unwrap();

if let Some(function) = part.function_call {
self.add_message(Content::new_function_call(function.clone()));
return Ok(Some(function));
} else if let Some(text) = part.text {
handler.on_event(ChatEvent::Delta(text.clone()));
model_message.push_str(&text);
}
}
Err(err) => {
Expand All @@ -97,12 +98,16 @@ impl Vertex {
}
}
if !model_message.is_empty() {
self.messages.push(Content::new_text(Role::Model, &model_message));
self.add_message(Content::new_text(Role::Model, model_message));
}
handler.on_event(&ChatEvent::End);
handler.on_event(ChatEvent::End);
Ok(None)
}

fn add_message(&mut self, content: Content) {
Rc::get_mut(&mut self.messages).unwrap().push(content);
}

async fn call_api(&self) -> Result<Response, Exception> {
let has_function = !self.tools.is_empty();

Expand All @@ -113,19 +118,19 @@ impl Vertex {
let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent");

let request = StreamGenerateContent {
contents: Cow::from(&self.messages),
contents: Rc::clone(&self.messages),
generation_config: GenerationConfig {
temperature: 1.0,
top_p: 0.95,
max_output_tokens: 2048,
},
tools: has_function.then(|| Cow::from(&self.tools)),
tools: has_function.then(|| Rc::clone(&self.tools)),
};
let response = self.post(&url, &request).await?;
Ok(response)
}

async fn post(&self, url: &str, request: &StreamGenerateContent<'_>) -> Result<Response, Exception> {
async fn post(&self, url: &str, request: &StreamGenerateContent) -> Result<Response, Exception> {
let body = json::to_json(request)?;
let response = http_client::http_client()
.post(url)
Expand Down
23 changes: 11 additions & 12 deletions src/openai/api.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::rc::Rc;

use serde::Deserialize;
use serde::Serialize;

use crate::bot::Function;

#[derive(Debug, Serialize)]
pub struct ChatRequest<'a> {
#[serde(borrow)]
pub messages: Cow<'a, [ChatRequestMessage]>,
pub struct ChatRequest {
pub messages: Rc<Vec<ChatRequestMessage>>,
pub temperature: f32,
pub top_p: f32,
pub stream: bool,
Expand All @@ -21,10 +20,10 @@ pub struct ChatRequest<'a> {
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Cow<'a, [Tool]>>,
pub tools: Option<Rc<Vec<Tool>>>,
}

#[derive(Debug, Serialize, Clone)]
#[derive(Debug, Serialize)]
pub struct ChatRequestMessage {
pub role: Role,
pub content: Option<String>,
Expand All @@ -35,10 +34,10 @@ pub struct ChatRequestMessage {
}

impl ChatRequestMessage {
pub fn new_message(role: Role, message: &str) -> Self {
pub fn new_message(role: Role, message: String) -> Self {
ChatRequestMessage {
role,
content: Some(message.to_string()),
content: Some(message),
tool_call_id: None,
tool_calls: None,
}
Expand Down Expand Up @@ -76,13 +75,13 @@ impl ChatRequestMessage {
}
}

#[derive(Debug, Serialize, Clone)]
#[derive(Debug, Serialize)]
pub struct Tool {
pub r#type: String,
pub function: Function,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize)]
pub enum Role {
#[serde(rename = "user")]
User,
Expand Down Expand Up @@ -117,15 +116,15 @@ pub struct ChatResponseMessage {
pub tool_calls: Option<Vec<ToolCall>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ToolCall {
pub index: i64,
pub id: Option<String>,
pub r#type: Option<String>,
pub function: FunctionCall,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: Option<String>,
pub arguments: String,
Expand Down
Loading

0 comments on commit 102bd6e

Please sign in to comment.