diff --git a/src/azure/chatgpt.rs b/src/azure/chatgpt.rs
index 41779b3..0611c2d 100644
--- a/src/azure/chatgpt.rs
+++ b/src/azure/chatgpt.rs
@@ -1,7 +1,6 @@
 use std::collections::HashMap;
 use std::ops::Not;
 use std::path::Path;
-use std::path::PathBuf;
 use std::rc::Rc;
 use std::str;
 
@@ -100,7 +99,7 @@ impl ChatGPT {
         messages.insert(0, ChatRequestMessage::new_message(Role::System, message))
     }
 
-    pub async fn add_user_message(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
+    pub async fn add_user_message(&mut self, message: String, files: &[&Path]) -> Result<(), Exception> {
         let image_urls = image_urls(files).await?;
         self.add_message(ChatRequestMessage::new_user_message(message, image_urls));
         Ok(())
@@ -237,16 +236,11 @@ async fn read_sse(response: Response, tx: Sender) -> Result<(), Exception> {
     Ok(())
 }
 
-async fn image_urls(files: Option<Vec<PathBuf>>) -> Result<Option<Vec<String>>, Exception> {
-    let image_urls = if let Some(paths) = files {
-        let mut image_urls = Vec::with_capacity(paths.len());
-        for path in paths {
-            image_urls.push(base64_image_url(&path).await?)
-        }
-        Some(image_urls)
-    } else {
-        None
-    };
+async fn image_urls(files: &[&Path]) -> Result<Vec<String>, Exception> {
+    let mut image_urls = Vec::with_capacity(files.len());
+    for file in files {
+        image_urls.push(base64_image_url(file).await?)
+    }
     Ok(image_urls)
 }
 
diff --git a/src/azure/chatgpt_api.rs b/src/azure/chatgpt_api.rs
index fbdcbbb..2a67494 100644
--- a/src/azure/chatgpt_api.rs
+++ b/src/azure/chatgpt_api.rs
@@ -69,21 +69,19 @@ impl ChatRequestMessage {
         }
     }
 
-    pub fn new_user_message(message: String, image_urls: Option<Vec<String>>) -> Self {
+    pub fn new_user_message(message: String, image_urls: Vec<String>) -> Self {
         let mut content = vec![];
         content.push(Content {
             r#type: "text".to_string(),
             text: Some(message),
             image_url: None,
         });
-        if let Some(image_urls) = image_urls {
-            for url in image_urls {
-                content.push(Content {
-                    r#type: "image_url".to_string(),
-                    text: None,
-                    image_url: Some(ImageUrl { url }),
-                });
-            }
+        for url in image_urls {
+            content.push(Content {
+                r#type: "image_url".to_string(),
+                text: None,
+                image_url: Some(ImageUrl { url }),
+            });
         }
         ChatRequestMessage {
             role: Role::User,
diff --git a/src/command/chat.rs b/src/command/chat.rs
index 5dcd084..7583937 100644
--- a/src/command/chat.rs
+++ b/src/command/chat.rs
@@ -1,4 +1,4 @@
-use std::mem;
+use std::path::Path;
 use std::path::PathBuf;
 
 use clap::Args;
@@ -25,6 +25,18 @@ impl Chat {
         let config = llm::load(self.conf.as_deref()).await?;
         let mut model = config.create(&self.model, Some(ConsolePrinter))?;
 
+        let welcome_text = r#"
+---
+# Welcome to Puppet Chat
+---
+# Usage Instructions:
+
+- Type /quit to quit the application.
+
+- Type /file {file} to add a file.
+---
+"#;
+        console::print(welcome_text).await?;
         let reader = BufReader::new(stdin());
         let mut lines = reader.lines();
         let mut files: Vec<PathBuf> = vec![];
@@ -41,8 +53,10 @@ impl Chat {
             println!("added file, path={}", file.to_string_lossy());
             files.push(file);
         } else {
-            let files = mem::take(&mut files).into_iter().map(Some).collect();
-            model.add_user_message(line, files).await?;
+            let data: Vec<&Path> = files.iter().map(|p| p.as_path()).collect();
+            model.add_user_message(line, &data).await?;
+            files.clear();
+            model.chat().await?;
         }
             }
         }
 
diff --git a/src/command/complete.rs b/src/command/complete.rs
index af86dbf..28c9127 100644
--- a/src/command/complete.rs
+++ b/src/command/complete.rs
@@ -55,7 +55,7 @@ impl Complete {
             }
             state = self.process_line(state, line, &mut model, &mut message, &mut files).await?;
         }
-        add_message(&mut model, state, message, files.into_iter().map(Some).collect()).await?;
+        add_message(&mut model, state, message, files).await?;
         let assistant_message = model.chat().await?;
 
         let mut prompt = fs::OpenOptions::new().append(true).open(&self.prompt).await?;
@@ -83,11 +83,10 @@ impl Complete {
             }
             return Ok(ParserState::System);
         } else if line.starts_with("# user") {
-            let files = mem::take(files).into_iter().map(Some).collect();
-            add_message(model, state, mem::take(message), files).await?;
+            add_message(model, state, mem::take(message), mem::take(files)).await?;
             return Ok(ParserState::User);
         } else if line.starts_with("# assistant") {
-            add_message(model, state, mem::take(message), None).await?;
+            add_message(model, state, mem::take(message), vec![]).await?;
             return Ok(ParserState::Assistant);
         } else if line.starts_with("> file: ") {
             if !matches!(state, ParserState::User) {
@@ -139,12 +138,7 @@ fn extension(file: &Path) -> Result<&str, Exception> {
     Ok(extension)
 }
 
-async fn add_message(
-    model: &mut llm::Model,
-    state: ParserState,
-    message: String,
-    files: Option<Vec<PathBuf>>,
-) -> Result<(), Exception> {
+async fn add_message(model: &mut llm::Model, state: ParserState, message: String, files: Vec<PathBuf>) -> Result<(), Exception> {
     match state {
         ParserState::System => {
             info!("set system message: {}", message);
@@ -152,7 +146,8 @@
         }
         ParserState::User => {
             info!("add user message: {}", message);
-            model.add_user_message(message, files).await?;
+            let data: Vec<&Path> = files.iter().map(|p| p.as_path()).collect();
+            model.add_user_message(message, &data).await?;
         }
         ParserState::Assistant => {
             info!("add assistent message: {}", message);
diff --git a/src/gcloud/gemini.rs b/src/gcloud/gemini.rs
index a00b973..a10c038 100644
--- a/src/gcloud/gemini.rs
+++ b/src/gcloud/gemini.rs
@@ -1,7 +1,6 @@
 use std::mem;
 use std::ops::Not;
 use std::path::Path;
-use std::path::PathBuf;
 use std::rc::Rc;
 use std::str;
 
@@ -79,9 +78,9 @@ impl Gemini {
         self.system_instruction = Some(Rc::new(Content::new_model_text(text)));
     }
 
-    pub async fn add_user_text(&mut self, text: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
+    pub async fn add_user_text(&mut self, text: String, files: &[&Path]) -> Result<(), Exception> {
         let data = inline_datas(files).await?;
-        if data.is_some() {
+        if !data.is_empty() {
             self.tools = None; // function call is not supported with inline data
         }
         self.add_message(Content::new_user_text(text, data));
@@ -220,16 +219,11 @@ fn is_valid_json(content: &str) -> bool {
     result.is_ok()
 }
 
-async fn inline_datas(files: Option<Vec<PathBuf>>) -> Result<Option<Vec<InlineData>>, Exception> {
-    let data = if let Some(paths) = files {
-        let mut data = Vec::with_capacity(paths.len());
-        for path in paths {
-            data.push(inline_data(&path).await?);
-        }
-        Some(data)
-    } else {
-        None
-    };
+async fn inline_datas(files: &[&Path]) -> Result<Vec<InlineData>, Exception> {
+    let mut data = Vec::with_capacity(files.len());
+    for file in files {
+        data.push(inline_data(file).await?);
+    }
     Ok(data)
 }
 
diff --git a/src/gcloud/gemini_api.rs b/src/gcloud/gemini_api.rs
index a1b69dd..0fe6aee 100644
--- a/src/gcloud/gemini_api.rs
+++ b/src/gcloud/gemini_api.rs
@@ -23,20 +23,15 @@ pub struct Content {
 }
 
 impl Content {
-    pub fn new_user_text(text: String, data: Option<Vec<InlineData>>) -> Self {
+    pub fn new_user_text(text: String, datas: Vec<InlineData>) -> Self {
         let mut parts: Vec<Part> = vec![];
-        if let Some(data) = data {
-            parts.append(
-                &mut data
-                    .into_iter()
-                    .map(|d| Part {
-                        text: None,
-                        inline_data: Some(d),
-                        function_call: None,
-                        function_response: None,
-                    })
-                    .collect(),
-            );
+        for data in datas {
+            parts.push(Part {
+                text: None,
+                inline_data: Some(data),
+                function_call: None,
+                function_response: None,
+            });
         }
         parts.push(Part {
             text: Some(text),
diff --git a/src/llm.rs b/src/llm.rs
index 36fca08..dce8794 100644
--- a/src/llm.rs
+++ b/src/llm.rs
@@ -1,5 +1,4 @@
 use std::path::Path;
-use std::path::PathBuf;
 
 use tokio::fs;
 use tracing::info;
@@ -64,7 +63,7 @@ impl Model {
         }
     }
 
-    pub async fn add_user_message(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
+    pub async fn add_user_message(&mut self, message: String, files: &[&Path]) -> Result<(), Exception> {
         match self {
             Model::ChatGPT(model) => model.add_user_message(message, files).await,
             Model::Gemini(model) => model.add_user_text(message, files).await,
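
Usage sketch (reviewer note, not part of the patch): with the Option<Vec<PathBuf>> parameters replaced by &[&Path], a caller that holds owned PathBuf values borrows them at the call site and passes an empty slice instead of None when there are no attachments. The identifiers below mirror this diff (model is an llm::Model, files a previously collected Vec<PathBuf>); the prompt strings are placeholders.

    // `files: Vec<PathBuf>` collected earlier; `model` is an llm::Model as in this diff.
    let refs: Vec<&Path> = files.iter().map(|p| p.as_path()).collect();
    model.add_user_message("describe the attached images".to_string(), &refs).await?;
    let reply = model.chat().await?;
    // With no attachments, an empty slice now stands in for the old None:
    model.add_user_message("a text-only follow-up".to_string(), &[]).await?;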