Skip to content

Commit

Permalink
tweak complete
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jul 12, 2024
1 parent 93c8793 commit 31b6b54
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/target
/env
tests
/prompts
39 changes: 39 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ rand = "0"
base64 = "0"
uuid = { version = "1", features = ["v4"] }
bytes = "1"
regex = "1"
7 changes: 5 additions & 2 deletions src/azure/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::azure::chatgpt_api::Tool;
use crate::llm::function::FunctionStore;
use crate::llm::ChatEvent;
use crate::llm::ChatListener;
use crate::llm::ChatOption;
use crate::llm::Usage;
use crate::util::exception::Exception;
use crate::util::http_client;
Expand All @@ -34,8 +35,9 @@ pub struct ChatGPT {
messages: Rc<Vec<ChatRequestMessage>>,
tools: Option<Rc<[Tool]>>,
function_store: FunctionStore,
last_assistant_message: String,
pub option: Option<ChatOption>,
pub listener: Option<Box<dyn ChatListener>>,
last_assistant_message: String,
}

type FunctionCall = HashMap<i64, (String, String, String)>;
Expand All @@ -61,6 +63,7 @@ impl ChatGPT {
function_store,
last_assistant_message: String::new(),
listener: None,
option: None,
}
}

Expand Down Expand Up @@ -162,7 +165,7 @@ impl ChatGPT {
async fn call_api(&mut self) -> Result<Response, Exception> {
let request = ChatRequest {
messages: Rc::clone(&self.messages),
temperature: 0.7,
temperature: self.option.as_ref().map_or(0.7, |option| option.temperature),
top_p: 0.95,
stream: true,
// stream_options: Some(StreamOptions { include_usage: true }),
Expand Down
28 changes: 27 additions & 1 deletion src/command/complete.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;

use clap::Args;
use regex::Regex;
use tokio::fs;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
Expand All @@ -11,11 +13,12 @@ use tracing::info;
use crate::llm;
use crate::llm::ChatEvent;
use crate::llm::ChatListener;
use crate::llm::ChatOption;
use crate::util::exception::Exception;

#[derive(Args)]
pub struct Complete {
#[arg(help = "prompt file")]
#[arg(help = "prompt file path")]
prompt: PathBuf,

#[arg(long, help = "conf path")]
Expand Down Expand Up @@ -70,6 +73,10 @@ impl Complete {
return Err(Exception::ValidationError("system message must be at first".to_string()));
}
on_system_message = true;
if let Some(option) = parse_option(&line) {
info!("option: {:?}", option);
model.option(option);
}
} else if line.starts_with("# prompt") {
if on_system_message {
info!("system message: {}", message);
Expand Down Expand Up @@ -109,3 +116,22 @@ impl Complete {
Ok(())
}
}

/// Extracts a `ChatOption` from a prompt-file option line, e.g.
/// `"# system, temperature=2.0, top_p=0.95"`.
///
/// Returns `None` when the line carries no `temperature=` setting.
/// Accepts both integer (`temperature=2`) and decimal (`temperature=2.0`)
/// forms; the original decimal-only form still parses identically.
fn parse_option(line: &str) -> Option<ChatOption> {
    // NOTE: `captures` searches anywhere in the haystack, so no `.*` anchors
    // are needed. Pattern is tiny and this runs once per system line, so
    // per-call compilation is acceptable here.
    let regex = Regex::new(r"temperature=(\d+(?:\.\d+)?)").unwrap();
    let capture = regex.captures(line)?;
    // The pattern guarantees the capture is a valid f32 literal (digits with
    // an optional fractional part), so parsing cannot fail.
    let temperature = f32::from_str(&capture[1]).expect("regex-matched number must parse as f32");
    Some(ChatOption { temperature })
}

#[cfg(test)]
mod tests {
    /// Verifies both the positive path (temperature present on the line)
    /// and the negative path (no temperature -> None).
    #[test]
    fn parse_option() {
        // option embedded among other settings is extracted
        let option = super::parse_option("# system, temperature=2.0, top_p=0.95");
        assert_eq!(option.unwrap().temperature, 2.0);

        // a line without a temperature setting yields no option
        assert!(super::parse_option("# system").is_none());
    }
}
13 changes: 8 additions & 5 deletions src/gcloud/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::gcloud::gemini_api::GenerateContentResponse;
use crate::llm::function::FunctionStore;
use crate::llm::ChatEvent;
use crate::llm::ChatListener;
use crate::llm::ChatOption;
use crate::llm::Usage;
use crate::util::exception::Exception;
use crate::util::http_client;
Expand All @@ -37,9 +38,10 @@ pub struct Gemini {
system_instruction: Option<Rc<Content>>,
tools: Option<Rc<[Tool]>>,
function_store: FunctionStore,
usage: Usage,
last_model_message: String,
pub option: Option<ChatOption>,
pub listener: Option<Box<dyn ChatListener>>,
last_model_message: String,
usage: Usage,
}

impl Gemini {
Expand All @@ -53,9 +55,10 @@ impl Gemini {
function_declarations: function_store.declarations.to_vec(),
}])),
function_store,
usage: Usage::default(),
last_model_message: String::with_capacity(1024),
option: None,
listener: None,
last_model_message: String::with_capacity(1024),
usage: Usage::default(),
}
}

Expand Down Expand Up @@ -152,7 +155,7 @@ impl Gemini {
contents: Rc::clone(&self.messages),
system_instruction: self.system_instruction.clone(),
generation_config: GenerationConfig {
temperature: 1.0,
temperature: self.option.as_ref().map_or(1.0, |option| option.temperature),
top_p: 0.95,
max_output_tokens: 4096,
},
Expand Down
12 changes: 12 additions & 0 deletions src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ pub enum ChatEvent {
End(Usage),
}

/// Per-request tuning options for a chat completion call.
///
/// Currently carries only the sampling temperature; when no option is set,
/// each provider falls back to its own default (0.7 for ChatGPT, 1.0 for
/// Gemini).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ChatOption {
    // Sampling temperature passed through to the provider API.
    pub temperature: f32,
}

#[derive(Default)]
pub struct Usage {
pub request_tokens: i32,
Expand Down Expand Up @@ -54,6 +59,13 @@ impl Model {
Model::Gemini(model) => model.system_instruction(message),
}
}

/// Applies per-request chat options (currently the sampling temperature)
/// to the underlying provider client.
///
/// Replaces any previously set option on the selected model.
pub fn option(&mut self, option: ChatOption) {
    match self {
        Model::ChatGPT(model) => model.option = Some(option),
        Model::Gemini(model) => model.option = Some(option),
    }
}
}

pub async fn load(path: &Path) -> Result<Config, Exception> {
Expand Down

0 comments on commit 31b6b54

Please sign in to comment.