Skip to content

Commit

Permalink
make chat listener async
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jul 15, 2024
1 parent a64c354 commit 989845e
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 165 deletions.
49 changes: 0 additions & 49 deletions .vscode/launch.json

This file was deleted.

27 changes: 0 additions & 27 deletions .vscode/settings.json

This file was deleted.

17 changes: 10 additions & 7 deletions src/azure/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,24 @@ use crate::util::exception::Exception;
use crate::util::http_client;
use crate::util::json;

pub struct ChatGPT {
pub struct ChatGPT<L>
where
L: ChatListener,
{
url: String,
api_key: String,
messages: Rc<Vec<ChatRequestMessage>>,
tools: Option<Rc<[Tool]>>,
function_store: FunctionStore,
listener: Option<L>,
pub option: Option<ChatOption>,
pub listener: Option<Box<dyn ChatListener>>,
last_assistant_message: String,
}

type FunctionCall = HashMap<i64, (String, String, String)>;

impl ChatGPT {
pub fn new(endpoint: String, model: String, api_key: String, function_store: FunctionStore) -> Self {
impl<L: ChatListener> ChatGPT<L> {
pub fn new(endpoint: String, model: String, api_key: String, function_store: FunctionStore, listener: Option<L>) -> Self {
let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-06-01");
let tools: Option<Rc<[Tool]>> = function_store.declarations.is_empty().not().then_some(
function_store
Expand All @@ -61,8 +64,8 @@ impl ChatGPT {
messages: Rc::new(vec![]),
tools,
function_store,
listener,
last_assistant_message: String::new(),
listener: None,
option: None,
}
}
Expand Down Expand Up @@ -141,7 +144,7 @@ impl ChatGPT {
assistant_message.push_str(&value);

if let Some(listener) = self.listener.as_ref() {
listener.on_event(ChatEvent::Delta(value));
listener.on_event(ChatEvent::Delta(value)).await?;
}
}
}
Expand All @@ -163,7 +166,7 @@ impl ChatGPT {
Ok(Some(function_calls))
} else {
if let Some(listener) = self.listener.as_ref() {
listener.on_event(ChatEvent::End(usage));
listener.on_event(ChatEvent::End(usage)).await?;
}
Ok(None)
}
Expand Down
41 changes: 7 additions & 34 deletions src/command/chat.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use std::io::Write;
use std::mem;
use std::path::PathBuf;

use clap::Args;
use tokio::io::stdin;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tracing::info;

use crate::llm;
use crate::llm::ChatEvent;
use crate::llm::ChatListener;
use crate::llm::ConsolePrinter;
use crate::util::console;
use crate::util::exception::Exception;

#[derive(Args)]
Expand All @@ -22,38 +20,19 @@ pub struct Chat {
name: String,
}

struct ConsoleHandler;

impl ChatListener for ConsoleHandler {
fn on_event(&self, event: ChatEvent) {
match event {
ChatEvent::Delta(data) => {
print_flush(&data).unwrap();
}
ChatEvent::End(usage) => {
println!();
info!(
"usage, request_tokens={}, response_tokens={}",
usage.request_tokens, usage.response_tokens
);
}
}
}
}

impl Chat {
pub async fn execute(&self) -> Result<(), Exception> {
let config = llm::load(&self.conf).await?;
let mut model = config.create(&self.name)?;
model.listener(Box::new(ConsoleHandler));
let mut model = config.create(&self.name, Some(ConsolePrinter))?;

let reader = BufReader::new(stdin());
let mut lines = reader.lines();

let mut files: Vec<PathBuf> = vec![];
loop {
print_flush("> ")?;
let Some(line) = lines.next_line().await? else { break };
console::print("> ").await?;
let Some(line) = lines.next_line().await? else {
break;
};
if line.starts_with("/quit") {
break;
}
Expand All @@ -71,9 +50,3 @@ impl Chat {
Ok(())
}
}

fn print_flush(message: &str) -> Result<(), Exception> {
print!("{message}");
std::io::stdout().flush()?;
Ok(())
}
35 changes: 7 additions & 28 deletions src/command/complete.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::io::Write;
use std::mem;
use std::path::PathBuf;
use std::str::FromStr;
Expand All @@ -12,9 +11,9 @@ use tokio::io::BufReader;
use tracing::info;

use crate::llm;
use crate::llm::ChatEvent;
use crate::llm::ChatListener;
use crate::llm::ChatOption;
use crate::llm::ConsolePrinter;
use crate::util::exception::Exception;

#[derive(Args)]
Expand All @@ -29,26 +28,6 @@ pub struct Complete {
name: String,
}

struct Listener;

impl ChatListener for Listener {
fn on_event(&self, event: ChatEvent) {
match event {
ChatEvent::Delta(data) => {
print!("{data}");
let _ = std::io::stdout().flush();
}
ChatEvent::End(usage) => {
println!();
info!(
"usage, request_tokens={}, response_tokens={}",
usage.request_tokens, usage.response_tokens
);
}
}
}
}

enum ParserState {
System,
User,
Expand All @@ -58,8 +37,7 @@ enum ParserState {
impl Complete {
pub async fn execute(&self) -> Result<(), Exception> {
let config = llm::load(&self.conf).await?;
let mut model = config.create(&self.name)?;
model.listener(Box::new(Listener));
let mut model = config.create(&self.name, Some(ConsolePrinter))?;

let prompt = fs::OpenOptions::new().read(true).open(&self.prompt).await?;
let reader = BufReader::new(prompt);
Expand All @@ -68,9 +46,7 @@ impl Complete {
let mut files: Vec<PathBuf> = vec![];
let mut message = String::new();
let mut state = ParserState::User;
loop {
let Some(line) = lines.next_line().await? else { break };

while let Some(line) = lines.next_line().await? {
if line.is_empty() {
continue;
}
Expand Down Expand Up @@ -122,7 +98,10 @@ impl Complete {
}
}

async fn add_message(model: &mut llm::Model, state: &ParserState, message: String, files: Vec<PathBuf>) -> Result<(), Exception> {
async fn add_message<L>(model: &mut llm::Model<L>, state: &ParserState, message: String, files: Vec<PathBuf>) -> Result<(), Exception>
where
L: ChatListener,
{
match state {
ParserState::System => {
info!("system message: {}", message);
Expand Down
17 changes: 10 additions & 7 deletions src/gcloud/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,23 @@ use crate::util::exception::Exception;
use crate::util::http_client;
use crate::util::json;

pub struct Gemini {
pub struct Gemini<L>
where
L: ChatListener,
{
url: String,
messages: Rc<Vec<Content>>,
system_instruction: Option<Rc<Content>>,
tools: Option<Rc<[Tool]>>,
function_store: FunctionStore,
listener: Option<L>,
pub option: Option<ChatOption>,
pub listener: Option<Box<dyn ChatListener>>,
last_model_message: String,
usage: Usage,
}

impl Gemini {
pub fn new(endpoint: String, project: String, location: String, model: String, function_store: FunctionStore) -> Self {
impl<L: ChatListener> Gemini<L> {
pub fn new(endpoint: String, project: String, location: String, model: String, function_store: FunctionStore, listener: Option<L>) -> Self {
let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent");
Gemini {
url,
Expand All @@ -55,8 +58,8 @@ impl Gemini {
function_declarations: function_store.declarations.to_vec(),
}])),
function_store,
listener,
option: None,
listener: None,
last_model_message: String::with_capacity(1024),
usage: Usage::default(),
}
Expand Down Expand Up @@ -129,7 +132,7 @@ impl Gemini {
} else if let Some(text) = part.text {
model_message.push_str(&text);
if let Some(listener) = self.listener.as_ref() {
listener.on_event(ChatEvent::Delta(text));
listener.on_event(ChatEvent::Delta(text)).await?;
}
}
}
Expand All @@ -147,7 +150,7 @@ impl Gemini {

let usage = mem::take(&mut self.usage);
if let Some(listener) = self.listener.as_ref() {
listener.on_event(ChatEvent::End(usage));
listener.on_event(ChatEvent::End(usage)).await?;
}

Ok(None)
Expand Down
Loading

0 comments on commit 989845e

Please sign in to comment.