Skip to content

Commit

Permalink
support glob in complete file including
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jul 16, 2024
1 parent 989845e commit eb0cc8d
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 55 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ base64 = "0"
uuid = { version = "1", features = ["v4"] }
bytes = "1"
regex = "1"
glob = "0"
167 changes: 112 additions & 55 deletions src/command/complete.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use std::mem;
use std::path::Path;
use std::path::PathBuf;
use std::str::FromStr;

use clap::Args;
use glob::glob;
use glob::GlobError;
use glob::PatternError;
use regex::Regex;
use tokio::fs;
use tokio::io::AsyncBufReadExt;
Expand All @@ -11,7 +15,6 @@ use tokio::io::BufReader;
use tracing::info;

use crate::llm;
use crate::llm::ChatListener;
use crate::llm::ChatOption;
use crate::llm::ConsolePrinter;
use crate::util::exception::Exception;
Expand Down Expand Up @@ -50,75 +53,109 @@ impl Complete {
if line.is_empty() {
continue;
}

if line.starts_with("# system") {
if !message.is_empty() {
return Err(Exception::ValidationError("system message must be at first".to_string()));
}
state = ParserState::System;
if let Some(option) = parse_option(&line) {
info!("option: {:?}", option);
model.option(option);
}
} else if line.starts_with("# user") {
add_message(&mut model, &state, mem::take(&mut message), mem::take(&mut files)).await?;
state = ParserState::User;
} else if line.starts_with("# assistant") {
add_message(&mut model, &state, mem::take(&mut message), mem::take(&mut files)).await?;
state = ParserState::Assistant;
} else if line.starts_with("> file: ") {
let file = self.prompt.with_file_name(line.strip_prefix("> file: ").unwrap());
let extension = file
.extension()
.ok_or_else(|| Exception::ValidationError(format!("file must have extension, path={}", file.to_string_lossy())))?
.to_str()
.unwrap();
if extension == "txt" {
message.push_str(&format!("> start of file: {}\n", &file.to_string_lossy()));
message.push_str(&fs::read_to_string(&file).await?);
message.push_str(&format!("> end of file: {}\n", &file.to_string_lossy()));
} else {
info!("file: {}", file.to_string_lossy());
files.push(file);
}
} else {
message.push_str(&line);
message.push('\n');
}
state = self.process_line(state, line, &mut model, &mut message, &mut files).await?;
}
add_message(&mut model, state, message, files.into_iter().map(Some).collect()).await?;

add_message(&mut model, &state, message, files).await?;
let message = model.chat().await?;

let assistant_message = model.chat().await?;
let mut prompt = fs::OpenOptions::new().append(true).open(&self.prompt).await?;
prompt.write_all(format!("\n# assistant ({})\n\n", self.name).as_bytes()).await?;
prompt.write_all(message.as_bytes()).await?;

prompt.write_all(assistant_message.as_bytes()).await?;
prompt.write_all(b"\n").await?;
Ok(())
}

async fn process_line(
&self,
state: ParserState,
line: String,
model: &mut llm::Model<ConsolePrinter>,
message: &mut String,
files: &mut Vec<PathBuf>,
) -> Result<ParserState, Exception> {
if line.starts_with("# system") {
if !message.is_empty() {
return Err(Exception::ValidationError("system message must be at first".to_string()));
}
if let Some(option) = parse_option(&line) {
info!("option: {:?}", option);
model.option(option);
}
return Ok(ParserState::System);
} else if line.starts_with("# user") {
let files = mem::take(files).into_iter().map(Some).collect();
add_message(model, state, mem::take(message), files).await?;
return Ok(ParserState::User);
} else if line.starts_with("# assistant") {
add_message(model, state, mem::take(message), None).await?;
return Ok(ParserState::Assistant);
} else if line.starts_with("> file: ") {
if !matches!(state, ParserState::User) {
return Err(Exception::ValidationError(format!(
"file can only be included in user message, line={line}"
)));
}

let pattern = self.pattern(line.strip_prefix("> file: ").unwrap());
for entry in glob(&pattern)? {
let entry = entry?;
let extension = extension(&entry)?;
match extension {
"txt" | "md" => {
message.push_str(&fs::read_to_string(entry).await?);
}
"java" | "rs" => {
message.push_str(&format!("```{} (path: {})\n", language(extension)?, entry.to_string_lossy()));
message.push_str(&fs::read_to_string(entry).await?);
message.push_str("```\n");
}
_ => {
info!("file: {}", entry.to_string_lossy());
files.push(entry);
}
}
}
} else {
message.push_str(&line);
message.push('\n');
}
Ok(state)
}

fn pattern(&self, pattern: &str) -> String {
if !pattern.starts_with('/') {
return format!("{}/{}", self.prompt.parent().unwrap().to_string_lossy(), pattern);
}
pattern.to_string()
}
}

fn extension(file: &Path) -> Result<&str, Exception> {
let extension = file
.extension()
.ok_or_else(|| Exception::ValidationError(format!("file must have a valid extension, path={}", file.to_string_lossy())))?
.to_str()
.unwrap();
Ok(extension)
}

async fn add_message<L>(model: &mut llm::Model<L>, state: &ParserState, message: String, files: Vec<PathBuf>) -> Result<(), Exception>
where
L: ChatListener,
{
async fn add_message(
model: &mut llm::Model<ConsolePrinter>,
state: ParserState,
message: String,
files: Option<Vec<PathBuf>>,
) -> Result<(), Exception> {
match state {
ParserState::System => {
info!("system message: {}", message);
info!("set system message: {}", message);
model.system_message(message);
}
ParserState::User => {
info!("user message: {}", message);
model.add_user_message(message, files.into_iter().map(Some).collect()).await?;
info!("add user message: {}", message);
model.add_user_message(message, files).await?;
}
ParserState::Assistant => {
if !files.is_empty() {
return Err(Exception::ValidationError(format!(
"cannot include file in assistant message, files={:?}",
files
)));
}
info!("assistent message: {}", message);
info!("add assistent message: {}", message);
model.add_assistant_message(message);
}
}
Expand All @@ -135,6 +172,26 @@ fn parse_option(line: &str) -> Option<ChatOption> {
}
}

fn language(extenstion: &str) -> Result<&'static str, Exception> {
match extenstion {
"java" => Ok("java"),
"rs" => Ok("rust"),
_ => Err(Exception::ValidationError(format!("unsupported extension, ext={}", extenstion))),
}
}

impl From<PatternError> for Exception {
fn from(err: PatternError) -> Self {
Exception::unexpected(err)
}
}

impl From<GlobError> for Exception {
fn from(err: GlobError) -> Self {
Exception::unexpected(err)
}
}

#[cfg(test)]
mod tests {
#[test]
Expand Down

0 comments on commit eb0cc8d

Please sign in to comment.