Skip to content

Commit

Permalink
multi model prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Apr 17, 2024
1 parent cc0e43b commit cf98860
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 62 deletions.
15 changes: 11 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
clap = { version = "4.4.16", features = ["derive"] }
clap_complete = "4.4.6"
serde = { version = "1.0", features = ["derive", "rc"] }
Expand All @@ -14,9 +16,9 @@ hyper = { version = "1.2.0", features = ["full"] }
tokio = { version = "1.36.0", features = ["full"] }
reqwest = { version = "0.11.24", features = ["stream"] }
reqwest-eventsource = "0.5.0"
futures = "0.3.30"
rand = "0.8.5"
base64 = { version = "0.22.0" }

futures = "0.3.30"
axum = { version = "0.7.4", features = ["ws"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
axum-extra = { version = "0.9.2", features = ["typed-header"] }
Binary file added cat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 7 additions & 0 deletions src/bot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ impl Bot {
Bot::Vertex(bot) => bot.chat(message, handler).await,
}
}

pub async fn upload(&mut self, path: &Path, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
match self {
Bot::ChatGPT(_bot) => todo!("not impl"),
Bot::Vertex(bot) => bot.upload(path, message, handler).await,
}
}
}

pub fn load(path: &Path) -> Result<Config, Exception> {
Expand Down
6 changes: 4 additions & 2 deletions src/bot/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,17 @@ impl Config {
let bot = match config.r#type {
BotType::Azure => Bot::ChatGPT(ChatGPT::new(
config.endpoint.to_string(),
config.params.get("api_key").unwrap().to_string(),
config.params.get("model").unwrap().to_string(),
Option::None,
config.params.get("api_key").unwrap().to_string(),
config.system_message.clone(),
function_store,
)),
BotType::GCloud => Bot::Vertex(Vertex::new(
config.endpoint.to_string(),
config.params.get("project").unwrap().to_string(),
config.params.get("location").unwrap().to_string(),
config.params.get("model").unwrap().to_string(),
config.system_message.clone(),
function_store,
)),
};
Expand All @@ -51,6 +52,7 @@ impl Config {
/// Settings for one bot; `Config` turns these into a concrete `Bot` instance.
pub struct BotConfig {
// Base endpoint passed as the first argument to the bot constructor.
pub endpoint: String,
// Selects which bot implementation is created (`BotType::Azure` -> ChatGPT, `BotType::GCloud` -> Vertex).
pub r#type: BotType,
// Optional system message, cloned into the bot constructor.
pub system_message: Option<String>,
// Bot specific parameters; the loader unwraps "api_key"/"model" for Azure and
// "project"/"location"/"model" for GCloud, so a missing key panics at load time.
pub params: HashMap<String, String>,
// presumably the names of the functions made available to this bot via the function store - TODO confirm
pub functions: Vec<String>,
}
Expand Down
8 changes: 6 additions & 2 deletions src/command/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@ impl Chat {
if line == "/quit" {
break;
}

bot.chat(line, &handler).await?;
// "/upload <path>, <message>" sends a file with an accompanying message,
// any other input is forwarded as a plain chat message.
if let Some(arguments) = line.strip_prefix("/upload ") {
    // split on the first comma, so the separator itself is not sent as part of the message
    let Some((path, message)) = arguments.split_once(',') else {
        // malformed user input must not panic nor end the chat session
        println!("usage: /upload <path>, <message>");
        continue;
    };
    bot.upload(Path::new(path.trim()), message.trim().to_string(), &handler).await?;
} else {
    bot.chat(line, &handler).await?;
}
}
Ok(())
}
Expand Down
37 changes: 36 additions & 1 deletion src/gcloud/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::bot::Function;
#[derive(Debug, Serialize)]
pub struct StreamGenerateContent {
pub contents: Rc<Vec<Content>>,
#[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
pub system_instruction: Rc<Option<Content>>,
#[serde(rename = "generationConfig")]
pub generation_config: GenerationConfig,
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -25,7 +27,8 @@ impl Content {
Self {
role,
parts: vec![Part {
text: Some(message.to_string()),
text: Some(message),
inline_data: None,
function_call: None,
function_response: None,
}],
Expand All @@ -37,6 +40,7 @@ impl Content {
role: Role::User,
parts: vec![Part {
text: None,
inline_data: None,
function_call: None,
function_response: Some(FunctionResponse { name, response }),
}],
Expand All @@ -48,11 +52,32 @@ impl Content {
role: Role::Model,
parts: vec![Part {
text: None,
inline_data: None,
function_call: Some(function_call),
function_response: None,
}],
}
}

/// Builds a user message made of two parts: the inline file data first,
/// followed by the accompanying text `message`.
///
/// `mime_type` and `data` are stored as given in the `InlineData` of the first part.
pub fn new_inline_data(mime_type: String, data: String, message: String) -> Self {
    let file_part = Part {
        text: None,
        inline_data: Some(InlineData { mime_type, data }),
        function_call: None,
        function_response: None,
    };
    let text_part = Part {
        text: Some(message),
        inline_data: None,
        function_call: None,
        function_response: None,
    };
    Self {
        role: Role::User,
        parts: vec![file_part, text_part],
    }
}
}

#[derive(Debug, Serialize)]
Expand All @@ -73,6 +98,9 @@ pub enum Role {
pub struct Part {
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub inline_data: Option<InlineData>,

#[serde(rename = "functionCall")]
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
Expand All @@ -90,6 +118,13 @@ pub struct GenerationConfig {
pub max_output_tokens: i32,
}

/// Inline file payload of a `Part`, (de)serialized with the keys `mimeType` and `data`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineData {
    /// MIME type of the payload, e.g. `image/jpeg`.
    pub mime_type: String,
    /// Payload content as text; `Vertex::upload` passes base64-encoded file bytes here.
    pub data: String,
}

#[derive(Debug, Deserialize)]
pub struct GenerateContentResponse {
pub candidates: Vec<Candidate>,
Expand Down
84 changes: 49 additions & 35 deletions src/gcloud/vertex.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use futures::StreamExt;
use reqwest::Response;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Sender;

use std::env;
use std::fs;
use std::path::Path;
use std::rc::Rc;

use crate::bot::ChatEvent;
Expand All @@ -23,23 +27,27 @@ use super::api::StreamGenerateContent;
use super::api::Tool;

pub struct Vertex {
endpoint: String,
project: String,
location: String,
model: String,
url: String,
messages: Rc<Vec<Content>>,
system_message: Rc<Option<Content>>,
tools: Rc<Vec<Tool>>,
function_store: FunctionStore,
}

impl Vertex {
pub fn new(endpoint: String, project: String, location: String, model: String, function_store: FunctionStore) -> Self {
pub fn new(
endpoint: String,
project: String,
location: String,
model: String,
system_message: Option<String>,
function_store: FunctionStore,
) -> Self {
let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent");
Vertex {
endpoint,
project,
location,
model,
url,
messages: Rc::new(vec![]),
system_message: Rc::new(system_message.map(|message| Content::new_text(Role::Model, message))),
tools: Rc::new(
function_store
.declarations
Expand Down Expand Up @@ -67,6 +75,23 @@ impl Vertex {
Ok(())
}

/// Uploads the file at `path` as inline data together with `message`,
/// reporting the model reply through `handler`.
///
/// The MIME type is derived from the file extension (case-insensitive):
/// `jpg`/`jpeg`, `png` and `pdf` are supported.
///
/// # Errors
/// Returns an `Exception` if the path has no extension, the extension is not valid UTF-8
/// or not supported, the file cannot be read, or the request fails.
pub async fn upload(&mut self, path: &Path, message: String, handler: &dyn ChatHandler) -> Result<(), Exception> {
    let extension = path
        .extension()
        .ok_or_else(|| Exception::new(&format!("file must have extension, path={}", path.to_string_lossy())))?
        .to_str()
        // a non UTF-8 extension is bad user input, not a bug, so it must not panic
        .ok_or_else(|| Exception::new(&format!("file extension must be valid utf-8, path={}", path.to_string_lossy())))?
        .to_ascii_lowercase();

    // validate the extension before touching the disk, so unsupported files are never read
    let mime_type = match extension.as_str() {
        "jpg" | "jpeg" => "image/jpeg",
        "png" => "image/png",
        "pdf" => "application/pdf",
        _ => return Err(Exception::new("not supported extension")),
    };

    let content = fs::read(path)?;
    let data = BASE64_STANDARD.encode(content);
    // NOTE(review): a function call returned by process() is dropped here - confirm this is intended
    self.process(Content::new_inline_data(mime_type.to_string(), data, message), handler)
        .await?;
    Ok(())
}

async fn process(&mut self, content: Content, handler: &dyn ChatHandler) -> Result<Option<FunctionCall>, Exception> {
self.add_message(content);

Expand All @@ -80,21 +105,16 @@ impl Vertex {

let mut model_message = String::new();
while let Some(response) = rx.recv().await {
match response {
Ok(response) => {
let part = response.candidates.into_iter().next().unwrap().content.parts.into_iter().next().unwrap();

if let Some(function) = part.function_call {
self.add_message(Content::new_function_call(function.clone()));
return Ok(Some(function));
} else if let Some(text) = part.text {
handler.on_event(ChatEvent::Delta(text.clone()));
model_message.push_str(&text);
}
}
Err(err) => {
return Err(err);
}
let response = response?;

let part = response.candidates.into_iter().next().unwrap().content.parts.into_iter().next().unwrap();

if let Some(function) = part.function_call {
self.add_message(Content::new_function_call(function.clone()));
return Ok(Some(function));
} else if let Some(text) = part.text {
handler.on_event(ChatEvent::Delta(text.clone()));
model_message.push_str(&text);
}
}
if !model_message.is_empty() {
Expand All @@ -111,29 +131,24 @@ impl Vertex {
async fn call_api(&self) -> Result<Response, Exception> {
let has_function = !self.tools.is_empty();

let endpoint = &self.endpoint;
let project = &self.project;
let location = &self.location;
let model = &self.model;
let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent");

let request = StreamGenerateContent {
contents: Rc::clone(&self.messages),
system_instruction: Rc::clone(&self.system_message),
generation_config: GenerationConfig {
temperature: 1.0,
top_p: 0.95,
max_output_tokens: 2048,
},
tools: has_function.then(|| Rc::clone(&self.tools)),
};
let response = self.post(&url, &request).await?;
let response = self.post(request).await?;
Ok(response)
}

async fn post(&self, url: &str, request: &StreamGenerateContent) -> Result<Response, Exception> {
let body = json::to_json(request)?;
async fn post(&self, request: StreamGenerateContent) -> Result<Response, Exception> {
let body = json::to_json(&request)?;
let response = http_client::http_client()
.post(url)
.post(&self.url)
.bearer_auth(token())
.header("Content-Type", "application/json")
.header("Accept", "application/json")
Expand Down Expand Up @@ -172,7 +187,6 @@ async fn process_response_stream(response: Response, tx: Sender<Result<GenerateC
buffer.clear();
}
Err(err) => {
// tx.send(InternalEvent::Event(ChatEvent::Error(err.to_string()))).await.unwrap();
tx.send(Err(Exception::new(&err.to_string()))).await.unwrap();
break;
}
Expand Down
12 changes: 6 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ mod util;
#[command(about = "Puppet AI")]
pub struct Cli {
#[command(subcommand)]
command: Option<Commands>,
command: Option<Command>,
}

#[derive(Subcommand)]
#[command(arg_required_else_help(true))]
pub enum Commands {
pub enum Command {
Chat(Chat),
Server(Server),
GenerateZshCompletion(GenerateZshCompletion),
Expand All @@ -32,10 +32,10 @@ pub enum Commands {
async fn main() -> Result<(), Exception> {
tracing_subscriber::fmt::init();
let cli = Cli::parse();
match &cli.command {
Some(Commands::GenerateZshCompletion(command)) => command.execute(),
Some(Commands::Chat(command)) => command.execute().await,
Some(Commands::Server(command)) => command.execute().await,
match cli.command {
Some(Command::GenerateZshCompletion(command)) => command.execute(),
Some(Command::Chat(command)) => command.execute().await,
Some(Command::Server(command)) => command.execute().await,
None => panic!("not implemented"),
}
}
Loading

0 comments on commit cf98860

Please sign in to comment.