diff --git a/src/azure/chatgpt.rs b/src/azure/chatgpt.rs
index c67ac36..716638c 100644
--- a/src/azure/chatgpt.rs
+++ b/src/azure/chatgpt.rs
@@ -34,6 +34,7 @@ pub struct ChatGPT {
     messages: Rc<Vec<ChatRequestMessage>>,
     tools: Option<Rc<[Tool]>>,
     function_store: FunctionStore,
+    last_assistant_message: String,
     pub listener: Option<Box<dyn ChatListener>>,
 }
 
@@ -58,11 +59,12 @@ impl ChatGPT {
             messages: Rc::new(vec![]),
             tools,
             function_store,
+            last_assistant_message: String::new(),
             listener: None,
         }
     }
 
-    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
+    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<String, Exception> {
         let image_urls = image_urls(files).await?;
         self.add_message(ChatRequestMessage::new_user_message(message, image_urls));
 
@@ -82,7 +84,7 @@ impl ChatGPT {
             }
             self.process().await?;
         }
-        Ok(())
+        Ok(self.last_assistant_message.to_string())
     }
 
     pub fn system_message(&mut self, message: String) {
@@ -143,6 +145,7 @@ impl ChatGPT {
         }
 
         if !assistant_message.is_empty() {
+            self.last_assistant_message = assistant_message.to_string();
             self.add_message(ChatRequestMessage::new_message(Role::Assistant, assistant_message));
         }
 
diff --git a/src/command.rs b/src/command.rs
index 1fdbb17..7c48dfe 100644
--- a/src/command.rs
+++ b/src/command.rs
@@ -1,3 +1,4 @@
 pub mod chat;
+pub mod complete;
 pub mod generate_zsh_completion;
 pub mod speak;
diff --git a/src/command/complete.rs b/src/command/complete.rs
new file mode 100644
index 0000000..849a65f
--- /dev/null
+++ b/src/command/complete.rs
@@ -0,0 +1,101 @@
+use std::io::Write;
+use std::path::PathBuf;
+
+use clap::Args;
+use tokio::fs;
+use tokio::io::AsyncBufReadExt;
+use tokio::io::AsyncWriteExt;
+use tokio::io::BufReader;
+use tracing::info;
+
+use crate::llm;
+use crate::llm::ChatEvent;
+use crate::llm::ChatListener;
+use crate::util::exception::Exception;
+
+#[derive(Args)]
+pub struct Complete {
+    #[arg(help = "prompt file")]
+    prompt: PathBuf,
+
+    #[arg(long, help = "conf path")]
+    conf: PathBuf,
+
+    #[arg(long, help = "model name")]
+    name: String,
+}
+
+struct Listener;
+
+impl ChatListener for Listener {
+    fn on_event(&self, event: ChatEvent) {
+        match event {
+            ChatEvent::Delta(data) => {
+                print!("{data}");
+                let _ = std::io::stdout().flush();
+            }
+            ChatEvent::End(usage) => {
+                println!();
+                info!(
+                    "usage, request_tokens={}, response_tokens={}",
+                    usage.request_tokens, usage.response_tokens
+                );
+            }
+        }
+    }
+}
+
+impl Complete {
+    pub async fn execute(&self) -> Result<(), Exception> {
+        let config = llm::load(&self.conf).await?;
+        let mut model = config.create(&self.name)?;
+        model.listener(Box::new(Listener));
+
+        let prompt = fs::OpenOptions::new().read(true).open(&self.prompt).await?;
+        let reader = BufReader::new(prompt);
+        let mut lines = reader.lines();
+
+        let mut files: Vec<PathBuf> = vec![];
+        let mut message = String::new();
+        let mut on_system_message = false;
+        loop {
+            let Some(line) = lines.next_line().await? else { break };
+
+            if line.is_empty() {
+                continue;
+            }
+
+            if line.starts_with("# system") {
+                if !message.is_empty() {
+                    return Err(Exception::ValidationError("system message must be at first".to_string()));
+                }
+                on_system_message = true;
+            } else if line.starts_with("---") || line.starts_with("# file: ") {
+                if on_system_message {
+                    info!("system message: {}", message);
+                    model.system_message(message);
+                    message = String::new();
+                    on_system_message = false;
+                }
+                if line.starts_with("# file: ") {
+                    let file = PathBuf::from(line.strip_prefix("# file: ").unwrap().to_string());
+                    info!("file: {}", file.to_string_lossy());
+                    files.push(file);
+                }
+            } else {
+                message.push_str(&line);
+                message.push('\n');
+            }
+        }
+
+        info!("prompt: {}", message);
+        let files = files.into_iter().map(Some).collect();
+        let message = model.chat(message, files).await?;
+
+        let mut prompt = fs::OpenOptions::new().append(true).open(&self.prompt).await?;
+        prompt.write_all(b"\n---\n\n").await?;
+        prompt.write_all(message.as_bytes()).await?;
+
+        Ok(())
+    }
+}
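For reference, a prompt file consumed by the new complete command could look like the example below. The "# system", "# file: ", and "---" markers are exactly the ones the parser above matches on; the wording and the path are hypothetical:

    # system
    You are a concise assistant.

    # file: ./screenshot.png

    Describe the attached image in one paragraph.

A "# system" block has to come first (any text accumulated before it triggers the ValidationError), and it is the next "---" or "# file: " line that flushes the block to model.system_message; everything else accumulates into the user message.
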
diff --git a/src/gcloud/gemini.rs b/src/gcloud/gemini.rs
index 5c7289f..547fb22 100644
--- a/src/gcloud/gemini.rs
+++ b/src/gcloud/gemini.rs
@@ -34,10 +34,11 @@ use crate::util::json;
 
 pub struct Gemini {
     url: String,
     messages: Rc<Vec<Content>>,
-    system_message: Option<Rc<Content>>,
+    system_instruction: Option<Rc<Content>>,
     tools: Option<Rc<[Tool]>>,
     function_store: FunctionStore,
     usage: Usage,
+    last_model_message: String,
     pub listener: Option<Box<dyn ChatListener>>,
 }
@@ -47,17 +48,18 @@ impl Gemini {
         Gemini {
             url,
             messages: Rc::new(vec![]),
-            system_message: None,
+            system_instruction: None,
             tools: function_store.declarations.is_empty().not().then_some(Rc::from(vec![Tool {
                 function_declarations: function_store.declarations.to_vec(),
             }])),
             function_store,
             usage: Usage::default(),
+            last_model_message: String::with_capacity(1024),
             listener: None,
         }
     }
 
-    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
+    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<String, Exception> {
         let data = inline_datas(files).await?;
         self.add_message(Content::new_user_text(message, data));
 
@@ -67,11 +69,11 @@ impl Gemini {
             self.add_message(Content::new_function_response(function_call.name, function_response));
             result = self.process().await?;
         }
-        Ok(())
+        Ok(self.last_model_message.to_string())
     }
 
-    pub fn system_message(&mut self, message: String) {
-        self.system_message = Some(Rc::new(Content::new_model_text(message)));
+    pub fn system_instruction(&mut self, message: String) {
+        self.system_instruction = Some(Rc::new(Content::new_model_text(message)));
     }
 
     async fn process(&mut self) -> Result<Option<FunctionCall>, Exception> {
@@ -126,6 +128,7 @@ impl Gemini {
         }
 
         if !model_message.is_empty() {
+            self.last_model_message = model_message.to_string();
             self.add_message(Content::new_model_text(model_message));
         }
 
@@ -144,7 +147,7 @@ impl Gemini {
     async fn call_api(&self) -> Result<Response, Exception> {
         let request = StreamGenerateContent {
             contents: Rc::clone(&self.messages),
-            system_instruction: self.system_message.clone(),
+            system_instruction: self.system_instruction.clone(),
             generation_config: GenerationConfig {
                 temperature: 1.0,
                 top_p: 0.95,
diff --git a/src/llm.rs b/src/llm.rs
index 2b16c25..e01804f 100644
--- a/src/llm.rs
+++ b/src/llm.rs
@@ -34,7 +34,7 @@ pub enum Model {
 }
 
 impl Model {
-    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
+    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<String, Exception> {
         match self {
             Model::ChatGPT(model) => model.chat(message, files).await,
             Model::Gemini(model) => model.chat(message, files).await,
@@ -51,7 +51,7 @@ impl Model {
     pub fn system_message(&mut self, message: String) {
         match self {
             Model::ChatGPT(model) => model.system_message(message),
-            Model::Gemini(model) => model.system_message(message),
+            Model::Gemini(model) => model.system_instruction(message),
         }
     }
 }
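Since Model::chat now returns the final assistant message instead of (), a caller can capture the reply directly rather than reassembling it from listener events. A minimal sketch, assuming a hypothetical conf file "puppet.json" and model name "gpt4o":

    // load the model config and pick a model, then chat;
    // chat() now yields the last assistant message as a String
    let config = llm::load(Path::new("puppet.json")).await?;
    let mut model = config.create("gpt4o")?;
    let reply = model.chat("hello".to_string(), None).await?;
    println!("{reply}");
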
diff --git a/src/main.rs b/src/main.rs
index d3e44ba..5bdf291 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,7 @@
 use clap::Parser;
 use clap::Subcommand;
 use command::chat::Chat;
+use command::complete::Complete;
 use command::generate_zsh_completion::GenerateZshCompletion;
 use command::speak::Speak;
 use util::exception::Exception;
@@ -28,6 +29,8 @@ pub enum Command {
     Chat(Chat),
     #[command(about = "speak")]
     Speech(Speak),
+    #[command(about = "complete")]
+    Complete(Complete),
     #[command(about = "generate zsh completion")]
     GenerateZshCompletion(GenerateZshCompletion),
 }
@@ -39,6 +42,7 @@ async fn main() -> Result<(), Exception> {
     match cli.command {
         Some(Command::Chat(command)) => command.execute().await,
         Some(Command::Speech(command)) => command.execute().await,
+        Some(Command::Complete(command)) => command.execute().await,
         Some(Command::GenerateZshCompletion(command)) => command.execute(),
         None => panic!("not implemented"),
     }
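End to end, an invocation of the new subcommand might look like this (the binary name and the file paths are assumptions, not part of this diff):

    $ puppet complete prompt.txt --conf ./puppet.json --name gpt4o

The reply streams to stdout through ChatEvent::Delta while it is generated; the same text, as returned by model.chat, is then appended to prompt.txt after a "---" separator, so the file can be edited and re-submitted to continue the exchange.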