Skip to content

Commit

Permalink
refactor llm
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jul 12, 2024
1 parent bc65bb9 commit 8374340
Show file tree
Hide file tree
Showing 9 changed files with 409 additions and 452 deletions.
521 changes: 194 additions & 327 deletions Cargo.lock

Large diffs are not rendered by default.

89 changes: 66 additions & 23 deletions src/azure/chatgpt.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use std::collections::HashMap;
use std::ops::Not;
use std::path::Path;
use std::path::PathBuf;
use std::rc::Rc;
use std::str;
use std::str::Utf8Error;

use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::Bytes;
use reqwest::Response;
use tokio::fs;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
Expand All @@ -18,7 +22,7 @@ use crate::azure::chatgpt_api::Role;
use crate::azure::chatgpt_api::Tool;
use crate::llm::function::FunctionStore;
use crate::llm::ChatEvent;
use crate::llm::ChatHandler;
use crate::llm::ChatListener;
use crate::llm::Usage;
use crate::util::exception::Exception;
use crate::util::http_client;
Expand All @@ -30,12 +34,13 @@ pub struct ChatGPT {
messages: Rc<Vec<ChatRequestMessage>>,
tools: Option<Rc<[Tool]>>,
function_store: FunctionStore,
pub listener: Option<Box<dyn ChatListener>>,
}

type FunctionCall = HashMap<i64, (String, String, String)>;

impl ChatGPT {
pub fn new(endpoint: String, model: String, api_key: String, system_message: Option<String>, function_store: FunctionStore) -> Self {
pub fn new(endpoint: String, model: String, api_key: String, function_store: FunctionStore) -> Self {
let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-06-01");
let tools: Option<Rc<[Tool]>> = function_store.declarations.is_empty().not().then_some(
function_store
Expand All @@ -47,23 +52,21 @@ impl ChatGPT {
})
.collect(),
);
let mut chatgpt = ChatGPT {
ChatGPT {
url,
api_key,
messages: Rc::new(vec![]),
tools,
function_store,
};
if let Some(message) = system_message {
chatgpt.add_message(ChatRequestMessage::new_message(Role::System, message));
listener: None,
}
chatgpt
}

pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> {
self.add_message(ChatRequestMessage::new_message(Role::User, message));
pub async fn chat(&mut self, message: String, files: Option<Vec<PathBuf>>) -> Result<(), Exception> {
let image_urls = image_urls(files).await?;
self.add_message(ChatRequestMessage::new_user_message(message, image_urls));

let result = self.process(handler).await?;
let result = self.process().await?;
if let Some(calls) = result {
self.add_message(ChatRequestMessage::new_function_call(&calls));

Expand All @@ -77,30 +80,40 @@ impl ChatGPT {
let function_response = ChatRequestMessage::new_function_response(result.0, json::to_json(&result.1)?);
self.add_message(function_response);
}
self.process(handler).await?;
self.process().await?;
}
Ok(())
}

pub fn system_message(&mut self, message: String) {
let messages = Rc::get_mut(&mut self.messages).unwrap();
if let Some(message) = messages.first() {
if let Role::System = message.role {
messages.remove(0);
}
}
messages.insert(0, ChatRequestMessage::new_message(Role::System, message))
}

fn add_message(&mut self, message: ChatRequestMessage) {
Rc::get_mut(&mut self.messages).unwrap().push(message);
}

async fn process(&mut self, handler: &impl ChatHandler) -> Result<Option<FunctionCall>, Exception> {
async fn process(&mut self) -> Result<Option<FunctionCall>, Exception> {
let (tx, rx) = channel(64);

let response = self.call_api().await?;
let handle = tokio::spawn(read_sse(response, tx));
let function_call = self.process_response(rx, handler).await?;
let function_call = self.process_response(rx).await?;
handle.await??;

Ok(function_call)
}

async fn process_response(&mut self, mut rx: Receiver<ChatResponse>, handler: &impl ChatHandler) -> Result<Option<FunctionCall>, Exception> {
async fn process_response(&mut self, mut rx: Receiver<ChatResponse>) -> Result<Option<FunctionCall>, Exception> {
let mut function_calls: FunctionCall = HashMap::new();
let mut usage = Usage::default();
let mut assistant_message = String::new();
let mut usage = Usage::default();

while let Some(response) = rx.recv().await {
if let Some(choice) = response.choices.into_iter().next() {
Expand All @@ -114,7 +127,10 @@ impl ChatGPT {
function_calls.get_mut(&call.index).unwrap().2.push_str(&call.function.arguments)
} else if let Some(value) = delta.content {
assistant_message.push_str(&value);
handler.on_event(ChatEvent::Delta(value));

if let Some(listener) = self.listener.as_ref() {
listener.on_event(ChatEvent::Delta(value));
}
}
}

Expand All @@ -133,7 +149,9 @@ impl ChatGPT {
if !function_calls.is_empty() {
Ok(Some(function_calls))
} else {
handler.on_event(ChatEvent::End(usage));
if let Some(listener) = self.listener.as_ref() {
listener.on_event(ChatEvent::End(usage));
}
Ok(None)
}
}
Expand Down Expand Up @@ -165,7 +183,7 @@ impl ChatGPT {
let response = request.send().await?;
let status = response.status();
if status != 200 {
let body = str::from_utf8(&body).unwrap();
let body = str::from_utf8(&body)?;
info!("body={}", body);
let response_text = response.text().await?;
return Err(Exception::ExternalError(format!(
Expand Down Expand Up @@ -203,8 +221,33 @@ async fn read_sse(response: Response, tx: Sender<ChatResponse>) -> Result<(), Ex
Ok(())
}

impl From<Utf8Error> for Exception {
fn from(err: Utf8Error) -> Self {
Exception::unexpected(err)
}
async fn image_urls(files: Option<Vec<PathBuf>>) -> Result<Option<Vec<String>>, Exception> {
let image_urls = if let Some(paths) = files {
let mut image_urls = Vec::with_capacity(paths.len());
for path in paths {
image_urls.push(base64_image_url(&path).await?)
}
Some(image_urls)
} else {
None
};
Ok(image_urls)
}

async fn base64_image_url(path: &Path) -> Result<String, Exception> {
let extension = path
.extension()
.ok_or_else(|| Exception::ValidationError(format!("file must have extension, path={}", path.to_string_lossy())))?
.to_str()
.unwrap();
let content = fs::read(path).await?;
let mime_type = match extension {
"jpg" => Ok("image/jpeg".to_string()),
"png" => Ok("image/png".to_string()),
_ => Err(Exception::ValidationError(format!(
"not supported extension, path={}",
path.to_string_lossy()
))),
}?;
Ok(format!("data:{mime_type};base64,{}", BASE64_STANDARD.encode(content)))
}
24 changes: 24 additions & 0 deletions src/azure/chatgpt_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,30 @@ impl ChatRequestMessage {
}
}

pub fn new_user_message(message: String, image_urls: Option<Vec<String>>) -> Self {
let mut content = vec![];
content.push(Content {
r#type: "text".to_string(),
text: Some(message),
image_url: None,
});
if let Some(image_urls) = image_urls {
for url in image_urls {
content.push(Content {
r#type: "image_url".to_string(),
text: None,
image_url: Some(ImageUrl { url }),
});
}
}
ChatRequestMessage {
role: Role::User,
content: Some(content),
tool_call_id: None,
tool_calls: None,
}
}

pub fn new_function_response(id: String, result: String) -> Self {
ChatRequestMessage {
role: Role::Tool,
Expand Down
16 changes: 10 additions & 6 deletions src/command/chat.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io;
use std::io::Write;
use std::path::Path;
use std::mem;
use std::path::PathBuf;

use clap::Args;
Expand All @@ -11,7 +11,7 @@ use tracing::info;

use crate::llm;
use crate::llm::ChatEvent;
use crate::llm::ChatHandler;
use crate::llm::ChatListener;
use crate::util::exception::Exception;

#[derive(Args)]
Expand All @@ -25,7 +25,7 @@ pub struct Chat {

struct ConsoleHandler;

impl ChatHandler for ConsoleHandler {
impl ChatListener for ConsoleHandler {
fn on_event(&self, event: ChatEvent) {
match event {
ChatEvent::Delta(data) => {
Expand All @@ -46,21 +46,25 @@ impl Chat {
pub async fn execute(&self) -> Result<(), Exception> {
let config = llm::load(&self.conf).await?;
let mut model = config.create(&self.name)?;
let handler = ConsoleHandler {};
model.listener(Box::new(ConsoleHandler));

let reader = BufReader::new(stdin());
let mut lines = reader.lines();

let mut files: Vec<PathBuf> = vec![];
loop {
print_flush("> ")?;
let Some(line) = lines.next_line().await? else { break };
if line.starts_with("/quit") {
break;
}
if line.starts_with("/file ") {
model.file(Path::new(line.strip_prefix("/file ").unwrap()))?;
let file = PathBuf::from(line.strip_prefix("/file ").unwrap().to_string());
println!("added file, path={}", file.to_string_lossy());
files.push(file);
} else {
model.chat(line, &handler).await?;
let files = mem::take(&mut files).into_iter().map(Some).collect();
model.chat(line, files).await?;
}
}

Expand Down
Loading

0 comments on commit 8374340

Please sign in to comment.