Skip to content

Commit

Permalink
make complete support session
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jul 15, 2024
1 parent 31b6b54 commit a64c354
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 37 deletions.
15 changes: 11 additions & 4 deletions src/azure/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ impl ChatGPT {
}
}

pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<String, Exception> {
let image_urls = image_urls(files).await?;
self.add_message(ChatRequestMessage::new_user_message(message, image_urls));

pub async fn chat(&mut self) -> Result<String, Exception> {
let result = self.process().await?;
if let Some(calls) = result {
self.add_message(ChatRequestMessage::new_function_call(&calls));
Expand Down Expand Up @@ -100,6 +97,16 @@ impl ChatGPT {
messages.insert(0, ChatRequestMessage::new_message(Role::System, message))
}

/// Appends a user message to the conversation, resolving any attached
/// files into image urls first.
///
/// Returns an error when reading/encoding the attached files fails.
pub async fn add_user_message(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
    let urls = image_urls(files).await?;
    let request = ChatRequestMessage::new_user_message(message, urls);
    self.add_message(request);
    Ok(())
}

/// Appends an assistant message to the conversation history.
pub fn add_assistant_message(&mut self, message: String) {
    let request = ChatRequestMessage::new_message(Role::Assistant, message);
    self.add_message(request);
}

fn add_message(&mut self, message: ChatRequestMessage) {
Rc::get_mut(&mut self.messages).unwrap().push(message);
}
Expand Down
3 changes: 2 additions & 1 deletion src/command/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ impl Chat {
files.push(file);
} else {
let files = mem::take(&mut files).into_iter().map(Some).collect();
model.chat(line, files).await?;
model.add_user_message(line, files).await?;
model.chat().await?;
}
}

Expand Down
61 changes: 45 additions & 16 deletions src/command/complete.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::io::Write;
use std::mem;
use std::path::PathBuf;
use std::str::FromStr;

Expand Down Expand Up @@ -48,6 +49,12 @@ impl ChatListener for Listener {
}
}

/// Which prompt-file section is currently being accumulated; headers such as
/// "# user" and "# assistant" switch the state (see the parsing loop below).
enum ParserState {
    // system instructions; must appear before any other content
    System,
    // a user message, may have files attached via "> file:" lines
    User,
    // a previously generated assistant reply, replayed into the session
    Assistant,
}

impl Complete {
pub async fn execute(&self) -> Result<(), Exception> {
let config = llm::load(&self.conf).await?;
Expand All @@ -60,7 +67,7 @@ impl Complete {

let mut files: Vec<PathBuf> = vec![];
let mut message = String::new();
let mut on_system_message = false;
let mut state = ParserState::User;
loop {
let Some(line) = lines.next_line().await? else { break };

Expand All @@ -72,20 +79,17 @@ impl Complete {
if !message.is_empty() {
return Err(Exception::ValidationError("system message must be at first".to_string()));
}
on_system_message = true;
state = ParserState::System;
if let Some(option) = parse_option(&line) {
info!("option: {:?}", option);
model.option(option);
}
} else if line.starts_with("# prompt") {
if on_system_message {
info!("system message: {}", message);
model.system_message(message);
message = String::new();
on_system_message = false;
}
} else if line.starts_with("# anwser") {
break;
} else if line.starts_with("# user") {
add_message(&mut model, &state, mem::take(&mut message), mem::take(&mut files)).await?;
state = ParserState::User;
} else if line.starts_with("# assistant") {
add_message(&mut model, &state, mem::take(&mut message), mem::take(&mut files)).await?;
state = ParserState::Assistant;
} else if line.starts_with("> file: ") {
let file = self.prompt.with_file_name(line.strip_prefix("> file: ").unwrap());
let extension = file
Expand All @@ -94,7 +98,9 @@ impl Complete {
.to_str()
.unwrap();
if extension == "txt" {
message.push_str(&fs::read_to_string(file).await?)
message.push_str(&format!("> start of file: {}\n", &file.to_string_lossy()));
message.push_str(&fs::read_to_string(&file).await?);
message.push_str(&format!("> end of file: {}\n", &file.to_string_lossy()));
} else {
info!("file: {}", file.to_string_lossy());
files.push(file);
Expand All @@ -105,18 +111,41 @@ impl Complete {
}
}

info!("prompt: {}", message);
let files = files.into_iter().map(Some).collect();
let message = model.chat(message, files).await?;
add_message(&mut model, &state, message, files).await?;
let message = model.chat().await?;

let mut prompt = fs::OpenOptions::new().append(true).open(&self.prompt).await?;
prompt.write_all(format!("\n# anwser ({})\n\n", self.name).as_bytes()).await?;
prompt.write_all(format!("\n# assistant ({})\n\n", self.name).as_bytes()).await?;
prompt.write_all(message.as_bytes()).await?;

Ok(())
}
}

/// Flushes one accumulated prompt-file section into the model conversation.
///
/// `state` selects how the section text is recorded: as the system message,
/// as a user message (optionally carrying attached files), or as an
/// assistant message. Attaching files to an assistant section is a
/// validation error, since assistant messages cannot include files.
async fn add_message(model: &mut llm::Model, state: &ParserState, message: String, files: Vec<PathBuf>) -> Result<(), Exception> {
    // nothing accumulated (e.g. a header immediately following another
    // header) — skip instead of recording an empty message
    if message.is_empty() && files.is_empty() {
        return Ok(());
    }
    match state {
        ParserState::System => {
            info!("system message: {}", message);
            model.system_message(message);
        }
        ParserState::User => {
            info!("user message: {}", message);
            model.add_user_message(message, files.into_iter().map(Some).collect()).await?;
        }
        ParserState::Assistant => {
            if !files.is_empty() {
                return Err(Exception::ValidationError(format!(
                    "cannot include file in assistant message, files={:?}",
                    files
                )));
            }
            // fixed typo: log previously read "assistent message"
            info!("assistant message: {}", message);
            model.add_assistant_message(message);
        }
    }
    Ok(())
}

fn parse_option(line: &str) -> Option<ChatOption> {
let regex = Regex::new(r".*temperature=(\d+\.\d+).*").unwrap();
if let Some(capture) = regex.captures(line) {
Expand Down
25 changes: 16 additions & 9 deletions src/gcloud/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,7 @@ impl Gemini {
}
}

pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<String, Exception> {
let data = inline_datas(files).await?;
if data.is_some() {
self.tools = None; // function call is not supported with inline data
}
self.add_message(Content::new_user_text(message, data));

pub async fn chat(&mut self) -> Result<String, Exception> {
let mut result = self.process().await?;
while let Some(function_call) = result {
let function_response = self.function_store.call_function(function_call.name.clone(), function_call.args).await?;
Expand All @@ -78,8 +72,21 @@ impl Gemini {
Ok(self.last_model_message.to_string())
}

pub fn system_instruction(&mut self, message: String) {
self.system_instruction = Some(Rc::new(Content::new_model_text(message)));
pub fn system_instruction(&mut self, text: String) {
self.system_instruction = Some(Rc::new(Content::new_model_text(text)));
}

/// Appends a user message to the conversation, converting any attached
/// files into inline data first.
///
/// Returns an error when loading the attached files fails.
pub async fn add_user_text(&mut self, text: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
    let data = inline_datas(files).await?;
    if data.is_some() {
        // function call is not supported with inline data
        self.tools = None;
    }
    let content = Content::new_user_text(text, data);
    self.add_message(content);
    Ok(())
}

/// Appends a model (assistant-side) message to the conversation.
pub fn add_model_text(&mut self, text: String) {
    let content = Content::new_model_text(text);
    self.add_message(content);
}

async fn process(&mut self) -> Result<Option<FunctionCall>, Exception> {
Expand Down
8 changes: 4 additions & 4 deletions src/gcloud/gemini_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct Content {
}

impl Content {
pub fn new_user_text(message: String, data: Option<Vec<InlineData>>) -> Self {
pub fn new_user_text(text: String, data: Option<Vec<InlineData>>) -> Self {
let mut parts: Vec<Part> = vec![];
if let Some(data) = data {
parts.append(
Expand All @@ -39,19 +39,19 @@ impl Content {
);
}
parts.push(Part {
text: Some(message),
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
});
Self { role: Role::User, parts }
}

pub fn new_model_text(message: String) -> Self {
pub fn new_model_text(text: String) -> Self {
Self {
role: Role::Model,
parts: vec![Part {
text: Some(message),
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
Expand Down
20 changes: 17 additions & 3 deletions src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ pub enum Model {
}

impl Model {
pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<String, Exception> {
pub async fn chat(&mut self) -> Result<String, Exception> {
match self {
Model::ChatGPT(model) => model.chat(message, files).await,
Model::Gemini(model) => model.chat(message, files).await,
Model::ChatGPT(model) => model.chat().await,
Model::Gemini(model) => model.chat().await,
}
}

Expand All @@ -66,6 +66,20 @@ impl Model {
Model::Gemini(model) => model.option = Some(option),
}
}

/// Records a user message (with optional file attachments) on whichever
/// backend this model wraps.
pub async fn add_user_message(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
    match self {
        Model::ChatGPT(chat_gpt) => chat_gpt.add_user_message(message, files).await,
        Model::Gemini(gemini) => gemini.add_user_text(message, files).await,
    }
}

/// Records an assistant (model-side) message on whichever backend this
/// model wraps.
pub fn add_assistant_message(&mut self, message: String) {
    match self {
        Model::ChatGPT(chat_gpt) => chat_gpt.add_assistant_message(message),
        Model::Gemini(gemini) => gemini.add_model_text(message),
    }
}
}

pub async fn load(path: &Path) -> Result<Config, Exception> {
Expand Down

0 comments on commit a64c354

Please sign in to comment.