Commit a62bb69: support complete
neowu committed Jul 12, 2024
1 parent 8374340 commit a62bb69
Showing 6 changed files with 123 additions and 11 deletions.
7 changes: 5 additions & 2 deletions src/azure/chatgpt.rs
@@ -34,6 +34,7 @@ pub struct ChatGPT {
     messages: Rc<Vec<ChatRequestMessage>>,
     tools: Option<Rc<[Tool]>>,
     function_store: FunctionStore,
+    last_assistant_message: String,
     pub listener: Option<Box<dyn ChatListener>>,
 }

@@ -58,11 +59,12 @@ impl ChatGPT {
             messages: Rc::new(vec![]),
             tools,
             function_store,
+            last_assistant_message: String::new(),
             listener: None,
         }
     }

-    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
+    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<String, Exception> {
         let image_urls = image_urls(files).await?;
         self.add_message(ChatRequestMessage::new_user_message(message, image_urls));
@@ -82,7 +84,7 @@ impl ChatGPT {
             }
             self.process().await?;
         }
-        Ok(())
+        Ok(self.last_assistant_message.to_string())
     }

     pub fn system_message(&mut self, message: String) {
@@ -143,6 +145,7 @@ impl ChatGPT {
         }

         if !assistant_message.is_empty() {
+            self.last_assistant_message = assistant_message.to_string();
             self.add_message(ChatRequestMessage::new_message(Role::Assistant, assistant_message));
         }
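With `chat` now returning `Result<String, Exception>`, callers receive the final assistant message directly instead of `()`. A minimal sketch of a call site under the new signature (construction of the mutable `model` value is elided and assumed):

    let reply = model.chat("summarize this".to_string(), None).await?;
    println!("{reply}");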
1 change: 1 addition & 0 deletions src/command.rs
@@ -1,3 +1,4 @@
 pub mod chat;
+pub mod complete;
 pub mod generate_zsh_completion;
 pub mod speak;
101 changes: 101 additions & 0 deletions src/command/complete.rs
@@ -0,0 +1,101 @@
+use std::io::Write;
+use std::path::PathBuf;
+
+use clap::Args;
+use tokio::fs;
+use tokio::io::AsyncBufReadExt;
+use tokio::io::AsyncWriteExt;
+use tokio::io::BufReader;
+use tracing::info;
+
+use crate::llm;
+use crate::llm::ChatEvent;
+use crate::llm::ChatListener;
+use crate::util::exception::Exception;
+
+#[derive(Args)]
+pub struct Complete {
+    #[arg(help = "prompt file")]
+    prompt: PathBuf,
+
+    #[arg(long, help = "conf path")]
+    conf: PathBuf,
+
+    #[arg(long, help = "model name")]
+    name: String,
+}
+
+struct Listener;
+
+impl ChatListener for Listener {
+    fn on_event(&self, event: ChatEvent) {
+        match event {
+            ChatEvent::Delta(data) => {
+                print!("{data}");
+                let _ = std::io::stdout().flush();
+            }
+            ChatEvent::End(usage) => {
+                println!();
+                info!(
+                    "usage, request_tokens={}, response_tokens={}",
+                    usage.request_tokens, usage.response_tokens
+                );
+            }
+        }
+    }
+}
+
+impl Complete {
+    pub async fn execute(&self) -> Result<(), Exception> {
+        let config = llm::load(&self.conf).await?;
+        let mut model = config.create(&self.name)?;
+        model.listener(Box::new(Listener));
+
+        let prompt = fs::OpenOptions::new().read(true).open(&self.prompt).await?;
+        let reader = BufReader::new(prompt);
+        let mut lines = reader.lines();
+
+        let mut files: Vec<PathBuf> = vec![];
+        let mut message = String::new();
+        let mut on_system_message = false;
+        loop {
+            let Some(line) = lines.next_line().await? else { break };
+
+            if line.is_empty() {
+                continue;
+            }
+
+            if line.starts_with("# system") {
+                if !message.is_empty() {
+                    return Err(Exception::ValidationError("system message must be at first".to_string()));
+                }
+                on_system_message = true;
+            } else if line.starts_with("---") || line.starts_with("# file: ") {
+                if on_system_message {
+                    info!("system message: {}", message);
+                    model.system_message(message);
+                    message = String::new();
+                    on_system_message = false;
+                }
+                if line.starts_with("# file: ") {
+                    let file = PathBuf::from(line.strip_prefix("# file: ").unwrap().to_string());
+                    info!("file: {}", file.to_string_lossy());
+                    files.push(file);
+                }
+            } else {
+                message.push_str(&line);
+                message.push('\n');
+            }
+        }
+
+        info!("prompt: {}", message);
+        let files = files.into_iter().map(Some).collect();
+        let message = model.chat(message, files).await?;
+
+        let mut prompt = fs::OpenOptions::new().append(true).open(&self.prompt).await?;
+        prompt.write_all(b"\n---\n\n").await?;
+        prompt.write_all(message.as_bytes()).await?;
+
+        Ok(())
+    }
+}
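The parser above implies a simple prompt-file format: an optional `# system` block must come first, a line starting with `---` or `# file: ` terminates that block, `# file: ` lines attach files, and every other non-empty line accumulates into the user message; the model's reply is then appended to the same file after a `---` separator. A hypothetical prompt file (path and wording are illustrative, not from the commit):

    # system
    You are a concise code reviewer.

    # file: ./src/lib.rs

    Review the attached file and list potential bugs.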
17 changes: 10 additions & 7 deletions src/gcloud/gemini.rs
@@ -34,10 +34,11 @@ use crate::util::json;
 pub struct Gemini {
     url: String,
     messages: Rc<Vec<Content>>,
-    system_message: Option<Rc<Content>>,
+    system_instruction: Option<Rc<Content>>,
     tools: Option<Rc<[Tool]>>,
     function_store: FunctionStore,
     usage: Usage,
+    last_model_message: String,
     pub listener: Option<Box<dyn ChatListener>>,
 }

@@ -47,17 +48,18 @@ impl Gemini {
         Gemini {
             url,
             messages: Rc::new(vec![]),
-            system_message: None,
+            system_instruction: None,
             tools: function_store.declarations.is_empty().not().then_some(Rc::from(vec![Tool {
                 function_declarations: function_store.declarations.to_vec(),
             }])),
             function_store,
             usage: Usage::default(),
+            last_model_message: String::with_capacity(1024),
             listener: None,
         }
     }

-    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
+    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<String, Exception> {
         let data = inline_datas(files).await?;
         self.add_message(Content::new_user_text(message, data));
@@ -67,11 +69,11 @@ impl Gemini {
             self.add_message(Content::new_function_response(function_call.name, function_response));
             result = self.process().await?;
         }
-        Ok(())
+        Ok(self.last_model_message.to_string())
     }

-    pub fn system_message(&mut self, message: String) {
-        self.system_message = Some(Rc::new(Content::new_model_text(message)));
+    pub fn system_instruction(&mut self, message: String) {
+        self.system_instruction = Some(Rc::new(Content::new_model_text(message)));
     }

     async fn process(&mut self) -> Result<Option<FunctionCall>, Exception> {
@@ -126,6 +128,7 @@ impl Gemini {
         }

         if !model_message.is_empty() {
+            self.last_model_message = model_message.to_string();
             self.add_message(Content::new_model_text(model_message));
         }
@@ -144,7 +147,7 @@ impl Gemini {
     async fn call_api(&self) -> Result<Response, Exception> {
         let request = StreamGenerateContent {
             contents: Rc::clone(&self.messages),
-            system_instruction: self.system_message.clone(),
+            system_instruction: self.system_instruction.clone(),
             generation_config: GenerationConfig {
                 temperature: 1.0,
                 top_p: 0.95,
4 changes: 2 additions & 2 deletions src/llm.rs
@@ -34,7 +34,7 @@ pub enum Model {
 }

 impl Model {
-    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
+    pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<String, Exception> {
         match self {
             Model::ChatGPT(model) => model.chat(message, files).await,
             Model::Gemini(model) => model.chat(message, files).await,
@@ -51,7 +51,7 @@ impl Model {
     pub fn system_message(&mut self, message: String) {
         match self {
             Model::ChatGPT(model) => model.system_message(message),
-            Model::Gemini(model) => model.system_message(message),
+            Model::Gemini(model) => model.system_instruction(message),
         }
     }
 }
4 changes: 4 additions & 0 deletions src/main.rs
@@ -1,6 +1,7 @@
use clap::Parser;
use clap::Subcommand;
use command::chat::Chat;
+use command::complete::Complete;
use command::generate_zsh_completion::GenerateZshCompletion;
use command::speak::Speak;
use util::exception::Exception;
@@ -28,6 +29,8 @@ pub enum Command {
     Chat(Chat),
     #[command(about = "speak")]
     Speech(Speak),
+    #[command(about = "complete")]
+    Complete(Complete),
     #[command(about = "generate zsh completion")]
     GenerateZshCompletion(GenerateZshCompletion),
 }
@@ -39,6 +42,7 @@ async fn main() -> Result<(), Exception> {
     match cli.command {
         Some(Command::Chat(command)) => command.execute().await,
         Some(Command::Speech(command)) => command.execute().await,
+        Some(Command::Complete(command)) => command.execute().await,
         Some(Command::GenerateZshCompletion(command)) => command.execute(),
         None => panic!("not implemented"),
     }
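Wired into clap as above, the new subcommand takes a positional prompt file plus the `--conf` and `--name` flags declared on `Complete`; a hypothetical invocation (the binary name, paths, and model name are placeholders, not from the commit):

    $ <binary> complete ./prompt.txt --conf ./llm.json --name gpt-4o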
