Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jul 16, 2024
1 parent 55c1d69 commit a809007
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 63 deletions.
18 changes: 6 additions & 12 deletions src/azure/chatgpt.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::collections::HashMap;
use std::ops::Not;
use std::path::Path;
use std::path::PathBuf;
use std::rc::Rc;
use std::str;

Expand Down Expand Up @@ -100,7 +99,7 @@ impl<L: ChatListener> ChatGPT<L> {
messages.insert(0, ChatRequestMessage::new_message(Role::System, message))
}

pub async fn add_user_message(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
pub async fn add_user_message(&mut self, message: String, files: &[&Path]) -> Result<(), Exception> {
let image_urls = image_urls(files).await?;
self.add_message(ChatRequestMessage::new_user_message(message, image_urls));
Ok(())
Expand Down Expand Up @@ -237,16 +236,11 @@ async fn read_sse(response: Response, tx: Sender<ChatResponse>) -> Result<(), Ex
Ok(())
}

async fn image_urls(files: Option<Vec<PathBuf>>) -> Result<Option<Vec<String>>, Exception> {
let image_urls = if let Some(paths) = files {
let mut image_urls = Vec::with_capacity(paths.len());
for path in paths {
image_urls.push(base64_image_url(&path).await?)
}
Some(image_urls)
} else {
None
};
async fn image_urls(files: &[&Path]) -> Result<Vec<String>, Exception> {
let mut image_urls = Vec::with_capacity(files.len());
for file in files {
image_urls.push(base64_image_url(file).await?)
}
Ok(image_urls)
}

Expand Down
16 changes: 7 additions & 9 deletions src/azure/chatgpt_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,19 @@ impl ChatRequestMessage {
}
}

pub fn new_user_message(message: String, image_urls: Option<Vec<String>>) -> Self {
pub fn new_user_message(message: String, image_urls: Vec<String>) -> Self {
let mut content = vec![];
content.push(Content {
r#type: "text".to_string(),
text: Some(message),
image_url: None,
});
if let Some(image_urls) = image_urls {
for url in image_urls {
content.push(Content {
r#type: "image_url".to_string(),
text: None,
image_url: Some(ImageUrl { url }),
});
}
for url in image_urls {
content.push(Content {
r#type: "image_url".to_string(),
text: None,
image_url: Some(ImageUrl { url }),
});
}
ChatRequestMessage {
role: Role::User,
Expand Down
20 changes: 17 additions & 3 deletions src/command/chat.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::mem;
use std::path::Path;
use std::path::PathBuf;

use clap::Args;
Expand All @@ -25,6 +25,18 @@ impl Chat {
let config = llm::load(self.conf.as_deref()).await?;
let mut model = config.create(&self.model, Some(ConsolePrinter))?;

let welcome_text = r#"
---
# Welcome to Puppet Chat
---
# Usage Instructions:
- Type /quit to quit the application.
- Type /file {file} to add a file.
---
"#;
console::print(welcome_text).await?;
let reader = BufReader::new(stdin());
let mut lines = reader.lines();
let mut files: Vec<PathBuf> = vec![];
Expand All @@ -41,8 +53,10 @@ impl Chat {
println!("added file, path={}", file.to_string_lossy());
files.push(file);
} else {
let files = mem::take(&mut files).into_iter().map(Some).collect();
model.add_user_message(line, files).await?;
let data: Vec<&Path> = files.iter().map(|p| p.as_path()).collect();
model.add_user_message(line, &data).await?;
files.clear();

model.chat().await?;
}
}
Expand Down
17 changes: 6 additions & 11 deletions src/command/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl Complete {
}
state = self.process_line(state, line, &mut model, &mut message, &mut files).await?;
}
add_message(&mut model, state, message, files.into_iter().map(Some).collect()).await?;
add_message(&mut model, state, message, files).await?;

let assistant_message = model.chat().await?;
let mut prompt = fs::OpenOptions::new().append(true).open(&self.prompt).await?;
Expand Down Expand Up @@ -83,11 +83,10 @@ impl Complete {
}
return Ok(ParserState::System);
} else if line.starts_with("# user") {
let files = mem::take(files).into_iter().map(Some).collect();
add_message(model, state, mem::take(message), files).await?;
add_message(model, state, mem::take(message), mem::take(files)).await?;
return Ok(ParserState::User);
} else if line.starts_with("# assistant") {
add_message(model, state, mem::take(message), None).await?;
add_message(model, state, mem::take(message), vec![]).await?;
return Ok(ParserState::Assistant);
} else if line.starts_with("> file: ") {
if !matches!(state, ParserState::User) {
Expand Down Expand Up @@ -139,20 +138,16 @@ fn extension(file: &Path) -> Result<&str, Exception> {
Ok(extension)
}

async fn add_message(
model: &mut llm::Model<ConsolePrinter>,
state: ParserState,
message: String,
files: Option<Vec<PathBuf>>,
) -> Result<(), Exception> {
async fn add_message(model: &mut llm::Model<ConsolePrinter>, state: ParserState, message: String, files: Vec<PathBuf>) -> Result<(), Exception> {
match state {
ParserState::System => {
info!("set system message: {}", message);
model.system_message(message);
}
ParserState::User => {
info!("add user message: {}", message);
model.add_user_message(message, files).await?;
let data: Vec<&Path> = files.iter().map(|p| p.as_path()).collect();
model.add_user_message(message, &data).await?;
}
ParserState::Assistant => {
info!("add assistent message: {}", message);
Expand Down
20 changes: 7 additions & 13 deletions src/gcloud/gemini.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::mem;
use std::ops::Not;
use std::path::Path;
use std::path::PathBuf;
use std::rc::Rc;
use std::str;

Expand Down Expand Up @@ -79,9 +78,9 @@ impl<L: ChatListener> Gemini<L> {
self.system_instruction = Some(Rc::new(Content::new_model_text(text)));
}

pub async fn add_user_text(&mut self, text: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
pub async fn add_user_text(&mut self, text: String, files: &[&Path]) -> Result<(), Exception> {
let data = inline_datas(files).await?;
if data.is_some() {
if !data.is_empty() {
self.tools = None; // function call is not supported with inline data
}
self.add_message(Content::new_user_text(text, data));
Expand Down Expand Up @@ -220,16 +219,11 @@ fn is_valid_json(content: &str) -> bool {
result.is_ok()
}

async fn inline_datas(files: Option<Vec<PathBuf>>) -> Result<Option<Vec<InlineData>>, Exception> {
let data = if let Some(paths) = files {
let mut data = Vec::with_capacity(paths.len());
for path in paths {
data.push(inline_data(&path).await?);
}
Some(data)
} else {
None
};
async fn inline_datas(files: &[&Path]) -> Result<Vec<InlineData>, Exception> {
let mut data = Vec::with_capacity(files.len());
for file in files {
data.push(inline_data(file).await?);
}
Ok(data)
}

Expand Down
21 changes: 8 additions & 13 deletions src/gcloud/gemini_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,15 @@ pub struct Content {
}

impl Content {
pub fn new_user_text(text: String, data: Option<Vec<InlineData>>) -> Self {
pub fn new_user_text(text: String, datas: Vec<InlineData>) -> Self {
let mut parts: Vec<Part> = vec![];
if let Some(data) = data {
parts.append(
&mut data
.into_iter()
.map(|d| Part {
text: None,
inline_data: Some(d),
function_call: None,
function_response: None,
})
.collect(),
);
for data in datas {
parts.push(Part {
text: None,
inline_data: Some(data),
function_call: None,
function_response: None,
});
}
parts.push(Part {
text: Some(text),
Expand Down
3 changes: 1 addition & 2 deletions src/llm.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::path::Path;
use std::path::PathBuf;

use tokio::fs;
use tracing::info;
Expand Down Expand Up @@ -64,7 +63,7 @@ impl<L: ChatListener> Model<L> {
}
}

pub async fn add_user_message(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
pub async fn add_user_message(&mut self, message: String, files: &[&Path]) -> Result<(), Exception> {
match self {
Model::ChatGPT(model) => model.add_user_message(message, files).await,
Model::Gemini(model) => model.add_user_text(message, files).await,
Expand Down

0 comments on commit a809007

Please sign in to comment.