From 31b6b540f37f08b8f18c3c5bae067b20b1472196 Mon Sep 17 00:00:00 2001
From: neo <1100909+neowu@users.noreply.github.com>
Date: Fri, 12 Jul 2024 16:21:15 +0800
Subject: [PATCH] tweak complete

---
 .gitignore              |  2 +-
 Cargo.lock              | 39 +++++++++++++++++++++++++++++++++++++++
 Cargo.toml              |  1 +
 src/azure/chatgpt.rs    |  7 +++++--
 src/command/complete.rs | 28 +++++++++++++++++++++++++++-
 src/gcloud/gemini.rs    | 13 ++++++++-----
 src/llm.rs              | 12 ++++++++++++
 7 files changed, 93 insertions(+), 9 deletions(-)

diff --git a/.gitignore b/.gitignore
index 968c470..15e9261 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,3 @@
 /target
 /env
-tests
+/prompts
diff --git a/Cargo.lock b/Cargo.lock
index 613ab1e..ecfe5cc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -17,6 +17,15 @@ version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
 
+[[package]]
+name = "aho-corasick"
+version = "1.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "anstream"
 version = "0.6.14"
@@ -779,6 +788,7 @@ dependencies = [
  "clap",
  "clap_complete",
  "rand",
+ "regex",
  "reqwest",
  "serde",
  "serde_json",
@@ -836,6 +846,35 @@ dependencies = [
  "bitflags 2.6.0",
 ]
 
+[[package]]
+name = "regex"
+version = "1.10.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-automata",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.4.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-syntax"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
+
 [[package]]
 name = "reqwest"
 version = "0.12.5"
diff --git a/Cargo.toml b/Cargo.toml
index 8a4c9a7..1cda909 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,3 +18,4 @@ rand = "0"
 base64 = "0"
 uuid = { version = "1", features = ["v4"] }
 bytes = "1"
+regex = "1"
diff --git a/src/azure/chatgpt.rs b/src/azure/chatgpt.rs
index 936383a..ed8ce12 100644
--- a/src/azure/chatgpt.rs
+++ b/src/azure/chatgpt.rs
@@ -23,6 +23,7 @@ use crate::azure::chatgpt_api::Tool;
 use crate::llm::function::FunctionStore;
 use crate::llm::ChatEvent;
 use crate::llm::ChatListener;
+use crate::llm::ChatOption;
 use crate::llm::Usage;
 use crate::util::exception::Exception;
 use crate::util::http_client;
@@ -34,8 +35,9 @@ pub struct ChatGPT {
     messages: Rc<Vec<ChatRequestMessage>>,
     tools: Option<Rc<[Tool]>>,
     function_store: FunctionStore,
-    last_assistant_message: String,
+    pub option: Option<ChatOption>,
     pub listener: Option<Box<dyn ChatListener>>,
+    last_assistant_message: String,
 }
 
 type FunctionCall = HashMap<i64, (String, String, String)>;
@@ -61,6 +63,7 @@ impl ChatGPT {
             function_store,
             last_assistant_message: String::new(),
             listener: None,
+            option: None,
         }
     }
 
@@ -162,7 +165,7 @@ impl ChatGPT {
     async fn call_api(&mut self) -> Result<(), Exception> {
         let request = ChatRequest {
             messages: Rc::clone(&self.messages),
-            temperature: 0.7,
+            temperature: self.option.as_ref().map_or(0.7, |option| option.temperature),
             top_p: 0.95,
             stream: true,
             // stream_options: Some(StreamOptions { include_usage: true }),
diff --git a/src/command/complete.rs b/src/command/complete.rs
index 52eba4a..9a47d72 100644
--- a/src/command/complete.rs
+++ b/src/command/complete.rs
@@ -1,7 +1,9 @@
 use std::io::Write;
 use std::path::PathBuf;
+use std::str::FromStr;
 
 use clap::Args;
+use regex::Regex;
 use tokio::fs;
 use tokio::io::AsyncBufReadExt;
 use tokio::io::AsyncWriteExt;
@@ -11,11 +13,12 @@ use tracing::info;
 use crate::llm;
 use crate::llm::ChatEvent;
 use crate::llm::ChatListener;
+use crate::llm::ChatOption;
 use crate::util::exception::Exception;
 
 #[derive(Args)]
 pub struct Complete {
-    #[arg(help = "prompt file")]
+    #[arg(help = "prompt file path")]
     prompt: PathBuf,
 
     #[arg(long, help = "conf path")]
@@ -70,6 +73,10 @@ impl Complete {
                     return Err(Exception::ValidationError("system message must be at first".to_string()));
                 }
                 on_system_message = true;
+                if let Some(option) = parse_option(&line) {
+                    info!("option: {:?}", option);
+                    model.option(option);
+                }
             } else if line.starts_with("# prompt") {
                 if on_system_message {
                     info!("system message: {}", message);
@@ -109,3 +116,22 @@ impl Complete {
         Ok(())
     }
 }
+
+fn parse_option(line: &str) -> Option<ChatOption> {
+    let regex = Regex::new(r".*temperature=(\d+\.\d+).*").unwrap();
+    if let Some(capture) = regex.captures(line) {
+        let temperature = f32::from_str(&capture[1]).unwrap();
+        Some(ChatOption { temperature })
+    } else {
+        None
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn parse_option() {
+        let option = super::parse_option("# system, temperature=2.0, top_p=0.95");
+        assert_eq!(option.unwrap().temperature, 2.0);
+    }
+}
diff --git a/src/gcloud/gemini.rs b/src/gcloud/gemini.rs
index 5389844..3da6eed 100644
--- a/src/gcloud/gemini.rs
+++ b/src/gcloud/gemini.rs
@@ -26,6 +26,7 @@ use crate::gcloud::gemini_api::GenerateContentResponse;
 use crate::llm::function::FunctionStore;
 use crate::llm::ChatEvent;
 use crate::llm::ChatListener;
+use crate::llm::ChatOption;
 use crate::llm::Usage;
 use crate::util::exception::Exception;
 use crate::util::http_client;
@@ -37,9 +38,10 @@ pub struct Gemini {
     system_instruction: Option<Rc<Content>>,
     tools: Option<Rc<[Tool]>>,
     function_store: FunctionStore,
-    usage: Usage,
-    last_model_message: String,
+    pub option: Option<ChatOption>,
     pub listener: Option<Box<dyn ChatListener>>,
+    last_model_message: String,
+    usage: Usage,
 }
 
 impl Gemini {
@@ -53,9 +55,10 @@
                 function_declarations: function_store.declarations.to_vec(),
             }])),
             function_store,
-            usage: Usage::default(),
-            last_model_message: String::with_capacity(1024),
+            option: None,
             listener: None,
+            last_model_message: String::with_capacity(1024),
+            usage: Usage::default(),
         }
     }
 
@@ -152,7 +155,7 @@ async fn call_api(&mut self) -> Result<(), Exception> {
             contents: Rc::clone(&self.messages),
             system_instruction: self.system_instruction.clone(),
             generation_config: GenerationConfig {
-                temperature: 1.0,
+                temperature: self.option.as_ref().map_or(1.0, |option| option.temperature),
                 top_p: 0.95,
                 max_output_tokens: 4096,
             },
diff --git a/src/llm.rs b/src/llm.rs
index e01804f..0963d3c 100644
--- a/src/llm.rs
+++ b/src/llm.rs
@@ -22,6 +22,11 @@ pub enum ChatEvent {
     End(Usage),
 }
 
+#[derive(Debug)]
+pub struct ChatOption {
+    pub temperature: f32,
+}
+
 #[derive(Default)]
 pub struct Usage {
     pub request_tokens: i32,
@@ -54,6 +59,13 @@ impl Model {
             Model::Gemini(model) => model.system_instruction(message),
         }
     }
+
+    pub fn option(&mut self, option: ChatOption) {
+        match self {
+            Model::ChatGPT(model) => model.option = Some(option),
+            Model::Gemini(model) => model.option = Some(option),
+        }
+    }
 }
 
 pub async fn load(path: &Path) -> Result<Config, Exception> {
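
Usage note: with this patch, the `# system` marker line of a prompt file can
carry a sampling temperature, which parse_option() extracts with the regex
`temperature=(\d+\.\d+)` (a decimal such as 0.7 is required; other keys on the
same line, like top_p, are ignored for now). A hypothetical prompt file
showing the syntax, mirroring the string used in the unit test:

    # system, temperature=2.0, top_p=0.95
    You are a helpful assistant.

    # prompt
    Explain the borrow checker in one paragraph.

When no option is present, call_api() keeps the previous hard-coded defaults:
temperature 0.7 for ChatGPT and 1.0 for Gemini.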