From eb0cc8da036625f5a9a6d7ccbb4deef3b0c7b1dd Mon Sep 17 00:00:00 2001 From: neo <1100909+neowu@users.noreply.github.com> Date: Tue, 16 Jul 2024 10:56:44 +0800 Subject: [PATCH] support glob in complete file including --- Cargo.lock | 7 ++ Cargo.toml | 1 + src/command/complete.rs | 167 +++++++++++++++++++++++++++------------- 3 files changed, 120 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ecfe5cc..29bf81d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -332,6 +332,12 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "h2" version = "0.4.5" @@ -787,6 +793,7 @@ dependencies = [ "bytes", "clap", "clap_complete", + "glob", "rand", "regex", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index 1cda909..55e0989 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,4 @@ base64 = "0" uuid = { version = "1", features = ["v4"] } bytes = "1" regex = "1" +glob = "0" diff --git a/src/command/complete.rs b/src/command/complete.rs index bea5f56..25ce42f 100644 --- a/src/command/complete.rs +++ b/src/command/complete.rs @@ -1,8 +1,12 @@ use std::mem; +use std::path::Path; use std::path::PathBuf; use std::str::FromStr; use clap::Args; +use glob::glob; +use glob::GlobError; +use glob::PatternError; use regex::Regex; use tokio::fs; use tokio::io::AsyncBufReadExt; @@ -11,7 +15,6 @@ use tokio::io::BufReader; use tracing::info; use crate::llm; -use crate::llm::ChatListener; use crate::llm::ChatOption; use crate::llm::ConsolePrinter; use crate::util::exception::Exception; @@ -50,75 +53,109 @@ impl Complete { if line.is_empty() { continue; } - - if line.starts_with("# system") { - if !message.is_empty() { - return Err(Exception::ValidationError("system message must be at first".to_string())); - } - state = ParserState::System; - if let Some(option) = parse_option(&line) { - info!("option: {:?}", option); - model.option(option); - } - } else if line.starts_with("# user") { - add_message(&mut model, &state, mem::take(&mut message), mem::take(&mut files)).await?; - state = ParserState::User; - } else if line.starts_with("# assistant") { - add_message(&mut model, &state, mem::take(&mut message), mem::take(&mut files)).await?; - state = ParserState::Assistant; - } else if line.starts_with("> file: ") { - let file = self.prompt.with_file_name(line.strip_prefix("> file: ").unwrap()); - let extension = file - .extension() - .ok_or_else(|| Exception::ValidationError(format!("file must have extension, path={}", file.to_string_lossy())))? - .to_str() - .unwrap(); - if extension == "txt" { - message.push_str(&format!("> start of file: {}\n", &file.to_string_lossy())); - message.push_str(&fs::read_to_string(&file).await?); - message.push_str(&format!("> end of file: {}\n", &file.to_string_lossy())); - } else { - info!("file: {}", file.to_string_lossy()); - files.push(file); - } - } else { - message.push_str(&line); - message.push('\n'); - } + state = self.process_line(state, line, &mut model, &mut message, &mut files).await?; } + add_message(&mut model, state, message, files.into_iter().map(Some).collect()).await?; - add_message(&mut model, &state, message, files).await?; - let message = model.chat().await?; - + let assistant_message = model.chat().await?; let mut prompt = fs::OpenOptions::new().append(true).open(&self.prompt).await?; prompt.write_all(format!("\n# assistant ({})\n\n", self.name).as_bytes()).await?; - prompt.write_all(message.as_bytes()).await?; - + prompt.write_all(assistant_message.as_bytes()).await?; + prompt.write_all(b"\n").await?; Ok(()) } + + async fn process_line( + &self, + state: ParserState, + line: String, + model: &mut llm::Model, + message: &mut String, + files: &mut Vec, + ) -> Result { + if line.starts_with("# system") { + if !message.is_empty() { + return Err(Exception::ValidationError("system message must be at first".to_string())); + } + if let Some(option) = parse_option(&line) { + info!("option: {:?}", option); + model.option(option); + } + return Ok(ParserState::System); + } else if line.starts_with("# user") { + let files = mem::take(files).into_iter().map(Some).collect(); + add_message(model, state, mem::take(message), files).await?; + return Ok(ParserState::User); + } else if line.starts_with("# assistant") { + add_message(model, state, mem::take(message), None).await?; + return Ok(ParserState::Assistant); + } else if line.starts_with("> file: ") { + if !matches!(state, ParserState::User) { + return Err(Exception::ValidationError(format!( + "file can only be included in user message, line={line}" + ))); + } + + let pattern = self.pattern(line.strip_prefix("> file: ").unwrap()); + for entry in glob(&pattern)? { + let entry = entry?; + let extension = extension(&entry)?; + match extension { + "txt" | "md" => { + message.push_str(&fs::read_to_string(entry).await?); + } + "java" | "rs" => { + message.push_str(&format!("```{} (path: {})\n", language(extension)?, entry.to_string_lossy())); + message.push_str(&fs::read_to_string(entry).await?); + message.push_str("```\n"); + } + _ => { + info!("file: {}", entry.to_string_lossy()); + files.push(entry); + } + } + } + } else { + message.push_str(&line); + message.push('\n'); + } + Ok(state) + } + + fn pattern(&self, pattern: &str) -> String { + if !pattern.starts_with('/') { + return format!("{}/{}", self.prompt.parent().unwrap().to_string_lossy(), pattern); + } + pattern.to_string() + } +} + +fn extension(file: &Path) -> Result<&str, Exception> { + let extension = file + .extension() + .ok_or_else(|| Exception::ValidationError(format!("file must have a valid extension, path={}", file.to_string_lossy())))? + .to_str() + .unwrap(); + Ok(extension) } -async fn add_message(model: &mut llm::Model, state: &ParserState, message: String, files: Vec) -> Result<(), Exception> -where - L: ChatListener, -{ +async fn add_message( + model: &mut llm::Model, + state: ParserState, + message: String, + files: Option>, +) -> Result<(), Exception> { match state { ParserState::System => { - info!("system message: {}", message); + info!("set system message: {}", message); model.system_message(message); } ParserState::User => { - info!("user message: {}", message); - model.add_user_message(message, files.into_iter().map(Some).collect()).await?; + info!("add user message: {}", message); + model.add_user_message(message, files).await?; } ParserState::Assistant => { - if !files.is_empty() { - return Err(Exception::ValidationError(format!( - "cannot include file in assistant message, files={:?}", - files - ))); - } - info!("assistent message: {}", message); + info!("add assistent message: {}", message); model.add_assistant_message(message); } } @@ -135,6 +172,26 @@ fn parse_option(line: &str) -> Option { } } +fn language(extenstion: &str) -> Result<&'static str, Exception> { + match extenstion { + "java" => Ok("java"), + "rs" => Ok("rust"), + _ => Err(Exception::ValidationError(format!("unsupported extension, ext={}", extenstion))), + } +} + +impl From for Exception { + fn from(err: PatternError) -> Self { + Exception::unexpected(err) + } +} + +impl From for Exception { + fn from(err: GlobError) -> Self { + Exception::unexpected(err) + } +} + #[cfg(test)] mod tests { #[test]