diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 959b4ce..0000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,49 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "type": "lldb", - "request": "launch", - "name": "Debug executable 'puppet'", - "cargo": { - "args": [ - "build", - "--bin=puppet", - "--package=puppet" - ], - "filter": { - "name": "puppet", - "kind": "bin" - } - }, - "args": [ - "chat", - "--conf", - "./env/conf.json" - ], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in executable 'puppet'", - "cargo": { - "args": [ - "test", - "--no-run", - "--bin=puppet", - "--package=puppet" - ], - "filter": { - "name": "puppet", - "kind": "bin" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - } - ] -} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index ad7705f..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "workbench.colorCustomizations": { - "activityBar.activeBackground": "#c8adc3", - "activityBar.background": "#c8adc3", - "activityBar.foreground": "#15202b", - "activityBar.inactiveForeground": "#15202b99", - "activityBarBadge.background": "#6b724d", - "activityBarBadge.foreground": "#e7e7e7", - "commandCenter.border": "#15202b99", - "sash.hoverBorder": "#c8adc3", - "statusBar.background": "#b48ead", - "statusBar.foreground": "#15202b", - "statusBarItem.hoverBackground": "#a06f97", - "statusBarItem.remoteBackground": "#b48ead", - "statusBarItem.remoteForeground": "#15202b", - "titleBar.activeBackground": "#b48ead", - "titleBar.activeForeground": "#15202b", - "titleBar.inactiveBackground": "#b48ead99", - "titleBar.inactiveForeground": "#15202b99" - }, - "peacock.color": "#B48EAD", - "cSpell.words": [ - "gcloud", - "openai" - ], - "rust-analyzer.showUnlinkedFileNotification": false -} \ No newline at end of file diff --git a/src/azure/chatgpt.rs b/src/azure/chatgpt.rs index d414cd9..41779b3 100644 --- a/src/azure/chatgpt.rs +++ b/src/azure/chatgpt.rs @@ -29,21 +29,24 @@ use crate::util::exception::Exception; use crate::util::http_client; use crate::util::json; -pub struct ChatGPT { +pub struct ChatGPT +where + L: ChatListener, +{ url: String, api_key: String, messages: Rc>, tools: Option>, function_store: FunctionStore, + listener: Option, pub option: Option, - pub listener: Option>, last_assistant_message: String, } type FunctionCall = HashMap; -impl ChatGPT { - pub fn new(endpoint: String, model: String, api_key: String, function_store: FunctionStore) -> Self { +impl ChatGPT { + pub fn new(endpoint: String, model: String, api_key: String, function_store: FunctionStore, listener: Option) -> Self { let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-06-01"); let tools: Option> = function_store.declarations.is_empty().not().then_some( function_store @@ -61,8 +64,8 @@ impl ChatGPT { messages: Rc::new(vec![]), tools, function_store, + listener, last_assistant_message: String::new(), - listener: None, option: None, } } @@ -141,7 +144,7 @@ impl ChatGPT { assistant_message.push_str(&value); if let Some(listener) = self.listener.as_ref() { - listener.on_event(ChatEvent::Delta(value)); + listener.on_event(ChatEvent::Delta(value)).await?; } } } @@ -163,7 +166,7 @@ impl ChatGPT { Ok(Some(function_calls)) } else { if let Some(listener) = self.listener.as_ref() { - listener.on_event(ChatEvent::End(usage)); + listener.on_event(ChatEvent::End(usage)).await?; } Ok(None) } diff --git a/src/command/chat.rs b/src/command/chat.rs index 81f7508..f988fef 100644 --- a/src/command/chat.rs +++ b/src/command/chat.rs @@ -1,4 +1,3 @@ -use std::io::Write; use std::mem; use std::path::PathBuf; @@ -6,11 +5,10 @@ use clap::Args; use tokio::io::stdin; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; -use tracing::info; use crate::llm; -use crate::llm::ChatEvent; -use crate::llm::ChatListener; +use crate::llm::ConsolePrinter; +use crate::util::console; use crate::util::exception::Exception; #[derive(Args)] @@ -22,38 +20,19 @@ pub struct Chat { name: String, } -struct ConsoleHandler; - -impl ChatListener for ConsoleHandler { - fn on_event(&self, event: ChatEvent) { - match event { - ChatEvent::Delta(data) => { - print_flush(&data).unwrap(); - } - ChatEvent::End(usage) => { - println!(); - info!( - "usage, request_tokens={}, response_tokens={}", - usage.request_tokens, usage.response_tokens - ); - } - } - } -} - impl Chat { pub async fn execute(&self) -> Result<(), Exception> { let config = llm::load(&self.conf).await?; - let mut model = config.create(&self.name)?; - model.listener(Box::new(ConsoleHandler)); + let mut model = config.create(&self.name, Some(ConsolePrinter))?; let reader = BufReader::new(stdin()); let mut lines = reader.lines(); - let mut files: Vec = vec![]; loop { - print_flush("> ")?; - let Some(line) = lines.next_line().await? else { break }; + console::print("> ").await?; + let Some(line) = lines.next_line().await? else { + break; + }; if line.starts_with("/quit") { break; } @@ -71,9 +50,3 @@ impl Chat { Ok(()) } } - -fn print_flush(message: &str) -> Result<(), Exception> { - print!("{message}"); - std::io::stdout().flush()?; - Ok(()) -} diff --git a/src/command/complete.rs b/src/command/complete.rs index 668d015..bea5f56 100644 --- a/src/command/complete.rs +++ b/src/command/complete.rs @@ -1,4 +1,3 @@ -use std::io::Write; use std::mem; use std::path::PathBuf; use std::str::FromStr; @@ -12,9 +11,9 @@ use tokio::io::BufReader; use tracing::info; use crate::llm; -use crate::llm::ChatEvent; use crate::llm::ChatListener; use crate::llm::ChatOption; +use crate::llm::ConsolePrinter; use crate::util::exception::Exception; #[derive(Args)] @@ -29,26 +28,6 @@ pub struct Complete { name: String, } -struct Listener; - -impl ChatListener for Listener { - fn on_event(&self, event: ChatEvent) { - match event { - ChatEvent::Delta(data) => { - print!("{data}"); - let _ = std::io::stdout().flush(); - } - ChatEvent::End(usage) => { - println!(); - info!( - "usage, request_tokens={}, response_tokens={}", - usage.request_tokens, usage.response_tokens - ); - } - } - } -} - enum ParserState { System, User, @@ -58,8 +37,7 @@ enum ParserState { impl Complete { pub async fn execute(&self) -> Result<(), Exception> { let config = llm::load(&self.conf).await?; - let mut model = config.create(&self.name)?; - model.listener(Box::new(Listener)); + let mut model = config.create(&self.name, Some(ConsolePrinter))?; let prompt = fs::OpenOptions::new().read(true).open(&self.prompt).await?; let reader = BufReader::new(prompt); @@ -68,9 +46,7 @@ impl Complete { let mut files: Vec = vec![]; let mut message = String::new(); let mut state = ParserState::User; - loop { - let Some(line) = lines.next_line().await? else { break }; - + while let Some(line) = lines.next_line().await? { if line.is_empty() { continue; } @@ -122,7 +98,10 @@ impl Complete { } } -async fn add_message(model: &mut llm::Model, state: &ParserState, message: String, files: Vec) -> Result<(), Exception> { +async fn add_message(model: &mut llm::Model, state: &ParserState, message: String, files: Vec) -> Result<(), Exception> +where + L: ChatListener, +{ match state { ParserState::System => { info!("system message: {}", message); diff --git a/src/gcloud/gemini.rs b/src/gcloud/gemini.rs index 4656343..a00b973 100644 --- a/src/gcloud/gemini.rs +++ b/src/gcloud/gemini.rs @@ -32,20 +32,23 @@ use crate::util::exception::Exception; use crate::util::http_client; use crate::util::json; -pub struct Gemini { +pub struct Gemini +where + L: ChatListener, +{ url: String, messages: Rc>, system_instruction: Option>, tools: Option>, function_store: FunctionStore, + listener: Option, pub option: Option, - pub listener: Option>, last_model_message: String, usage: Usage, } -impl Gemini { - pub fn new(endpoint: String, project: String, location: String, model: String, function_store: FunctionStore) -> Self { +impl Gemini { + pub fn new(endpoint: String, project: String, location: String, model: String, function_store: FunctionStore, listener: Option) -> Self { let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent"); Gemini { url, @@ -55,8 +58,8 @@ impl Gemini { function_declarations: function_store.declarations.to_vec(), }])), function_store, + listener, option: None, - listener: None, last_model_message: String::with_capacity(1024), usage: Usage::default(), } @@ -129,7 +132,7 @@ impl Gemini { } else if let Some(text) = part.text { model_message.push_str(&text); if let Some(listener) = self.listener.as_ref() { - listener.on_event(ChatEvent::Delta(text)); + listener.on_event(ChatEvent::Delta(text)).await?; } } } @@ -147,7 +150,7 @@ impl Gemini { let usage = mem::take(&mut self.usage); if let Some(listener) = self.listener.as_ref() { - listener.on_event(ChatEvent::End(usage)); + listener.on_event(ChatEvent::End(usage)).await?; } Ok(None) diff --git a/src/llm.rs b/src/llm.rs index 84a6040..d45f892 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -7,6 +7,7 @@ use tracing::info; use crate::azure::chatgpt::ChatGPT; use crate::gcloud::gemini::Gemini; use crate::llm::config::Config; +use crate::util::console; use crate::util::exception::Exception; use crate::util::json; @@ -14,7 +15,7 @@ pub mod config; pub mod function; pub trait ChatListener { - fn on_event(&self, event: ChatEvent); + async fn on_event(&self, event: ChatEvent) -> Result<(), Exception>; } pub enum ChatEvent { @@ -33,12 +34,15 @@ pub struct Usage { pub response_tokens: i32, } -pub enum Model { - ChatGPT(ChatGPT), - Gemini(Gemini), +pub enum Model +where + L: ChatListener, +{ + ChatGPT(ChatGPT), + Gemini(Gemini), } -impl Model { +impl Model { pub async fn chat(&mut self) -> Result { match self { Model::ChatGPT(model) => model.chat().await, @@ -46,13 +50,6 @@ impl Model { } } - pub fn listener(&mut self, listener: Box) { - match self { - Model::ChatGPT(model) => model.listener = Some(listener), - Model::Gemini(model) => model.listener = Some(listener), - } - } - pub fn system_message(&mut self, message: String) { match self { Model::ChatGPT(model) => model.system_message(message), @@ -88,3 +85,24 @@ pub async fn load(path: &Path) -> Result { let config: Config = json::from_json(&content)?; Ok(config) } + +pub struct ConsolePrinter; + +impl ChatListener for ConsolePrinter { + async fn on_event(&self, event: ChatEvent) -> Result<(), Exception> { + match event { + ChatEvent::Delta(data) => { + console::print(&data).await?; + Ok(()) + } + ChatEvent::End(usage) => { + console::print("\n").await?; + info!( + "usage, request_tokens={}, response_tokens={}", + usage.request_tokens, usage.response_tokens + ); + Ok(()) + } + } + } +} diff --git a/src/llm/config.rs b/src/llm/config.rs index f91e2d7..b72611d 100644 --- a/src/llm/config.rs +++ b/src/llm/config.rs @@ -4,6 +4,7 @@ use rand::Rng; use serde::Deserialize; use serde_json::json; +use super::ChatListener; use crate::azure::chatgpt::ChatGPT; use crate::gcloud::gemini::Gemini; use crate::llm::function::Function; @@ -27,7 +28,10 @@ pub struct ModelConfig { } impl Config { - pub fn create(&self, name: &str) -> Result { + pub fn create(&self, name: &str, listener: Option) -> Result, Exception> + where + L: ChatListener, + { let config = self .models .get(name) @@ -41,6 +45,7 @@ impl Config { config.params.get("model").unwrap().to_string(), config.params.get("api_key").unwrap().to_string(), function_store, + listener, )), Provider::GCloud => Model::Gemini(Gemini::new( config.endpoint.to_string(), @@ -48,6 +53,7 @@ impl Config { config.params.get("location").unwrap().to_string(), config.params.get("model").unwrap().to_string(), function_store, + listener, )), }; diff --git a/src/util.rs b/src/util.rs index 3b0844d..f99495f 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,3 +1,4 @@ +pub mod console; pub mod exception; pub mod http_client; pub mod json; diff --git a/src/util/console.rs b/src/util/console.rs new file mode 100644 index 0000000..4efb461 --- /dev/null +++ b/src/util/console.rs @@ -0,0 +1,11 @@ +use tokio::io::stdout; +use tokio::io::AsyncWriteExt; + +use super::exception::Exception; + +pub async fn print(text: &str) -> Result<(), Exception> { + let out = &mut stdout(); + out.write_all(text.as_bytes()).await?; + out.flush().await?; + Ok(()) +}