Skip to content

Commit

Permalink
refactor error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Apr 22, 2024
1 parent f29726c commit 9afcf22
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 55 deletions.
13 changes: 8 additions & 5 deletions src/bot/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,18 @@ impl FunctionStore {
self.implementations.insert(name, Arc::new(implementation));
}

pub async fn call_function(&self, name: &str, args: serde_json::Value) -> Result<serde_json::Value, Exception> {
info!("call function, name={name}, args={args}");
let function = self.get(name)?;
let response = tokio::spawn(async move { function(args) }).await?;
pub async fn call_function(&self, name: String, args: serde_json::Value) -> Result<serde_json::Value, Exception> {
let function = self.get(&name)?;
let response = tokio::spawn(async move {
info!("call function, name={name}, args={args}");
function(args)
})
.await?;
Ok(response)
}

pub async fn call_functions(&self, functions: Vec<(String, String, serde_json::Value)>) -> Result<Vec<(String, serde_json::Value)>, Exception> {
let mut handles = vec![];
let mut handles = Vec::with_capacity(functions.len());
for (id, name, args) in functions {
let function = self.get(&name)?;
handles.push(tokio::spawn(async move {
Expand Down
2 changes: 1 addition & 1 deletion src/command/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl ChatHandler for ConsoleHandler {
print_flush(&data).unwrap();
}
ChatEvent::Error(error) => {
println!("Error: {}", error);
println!("\nError: {}", error);
}
ChatEvent::End(usage) => {
println!();
Expand Down
4 changes: 3 additions & 1 deletion src/gcloud/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ pub struct GenerateContentResponse {

#[derive(Debug, Deserialize)]
pub struct Candidate {
pub content: Content,
pub content: Option<Content>,
#[serde(rename = "finishReason")]
pub finish_reason: Option<String>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
Expand Down
56 changes: 37 additions & 19 deletions src/gcloud/vertex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use base64::Engine;
use futures::StreamExt;
use reqwest::Response;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tracing::info;

Expand Down Expand Up @@ -67,7 +68,7 @@ impl Vertex {
let mut result = self.process(Content::new_text_with_inline_data(message, data), handler).await?;

while let Some(function_call) = result {
let function_response = self.function_store.call_function(&function_call.name, function_call.args).await?;
let function_response = self.function_store.call_function(function_call.name.clone(), function_call.args).await?;
let content = Content::new_function_response(function_call.name, function_response);
result = self.process(content, handler).await?;
}
Expand Down Expand Up @@ -103,35 +104,52 @@ impl Vertex {

let response = self.call_api().await?;

let (tx, mut rx) = channel(64);
let (tx, rx) = channel(64);
let handle = tokio::spawn(read_response_stream(response, tx));
let function_call = self.process_response(rx, handler).await;
let _ = tokio::try_join!(handle)?;

tokio::spawn(process_response_stream(response, tx));
Ok(function_call)
}

async fn process_response(&mut self, mut rx: Receiver<GenerateContentResponse>, handler: &impl ChatHandler) -> Option<FunctionCall> {
let mut model_message = String::new();
while let Some(response) = rx.recv().await {
let response = response?;

if let Some(usage) = response.usage_metadata {
self.usage.request_tokens += usage.prompt_token_count;
self.usage.response_tokens += usage.candidates_token_count;
}

let part = response.candidates.into_iter().next().unwrap().content.parts.into_iter().next().unwrap();

if let Some(function_call) = part.function_call {
self.add_message(Content::new_function_call(function_call.clone()));
return Ok(Some(function_call));
} else if let Some(text) = part.text {
model_message.push_str(&text);
handler.on_event(ChatEvent::Delta(text));
let candidate = response.candidates.into_iter().next().unwrap();
match candidate.content {
Some(content) => {
let part = content.parts.into_iter().next().unwrap();

if let Some(function_call) = part.function_call {
self.add_message(Content::new_function_call(function_call.clone()));
return Some(function_call);
} else if let Some(text) = part.text {
model_message.push_str(&text);
handler.on_event(ChatEvent::Delta(text));
}
}
None => {
handler.on_event(ChatEvent::Error(format!(
"response ended, finish_reason={}",
candidate.finish_reason.unwrap_or("".to_string())
)));
}
}
}

if !model_message.is_empty() {
self.add_message(Content::new_text(Role::Model, model_message));
}

let usage = mem::take(&mut self.usage);
handler.on_event(ChatEvent::End(usage));
Ok(None)

None
}

fn add_message(&mut self, content: Content) {
Expand Down Expand Up @@ -174,7 +192,7 @@ impl Vertex {
}
}

async fn process_response_stream(response: Response, tx: Sender<Result<GenerateContentResponse, Exception>>) {
async fn read_response_stream(response: Response, tx: Sender<GenerateContentResponse>) -> Result<(), Exception> {
let stream = &mut response.bytes_stream();

let mut buffer = String::new();
Expand All @@ -188,16 +206,16 @@ async fn process_response_stream(response: Response, tx: Sender<Result<GenerateC
continue;
}

let content: GenerateContentResponse = json::from_json(&buffer[1..]).unwrap();
tx.send(Ok(content)).await.unwrap();
let content: GenerateContentResponse = json::from_json(&buffer[1..])?;
tx.send(content).await?;
buffer.clear();
}
Err(err) => {
tx.send(Err(Exception::new(err.to_string()))).await.unwrap();
break;
return Err(Exception::new(err.to_string()));
}
}
}
Ok(())
}

fn is_valid_json(content: &str) -> bool {
Expand Down
64 changes: 35 additions & 29 deletions src/openai/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use reqwest_eventsource::CannotCloneRequestError;
use reqwest_eventsource::Event;
use reqwest_eventsource::EventSource;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;

use crate::bot::function::FunctionStore;
Expand Down Expand Up @@ -65,12 +66,15 @@ impl ChatGPT {

pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> {
self.add_message(ChatRequestMessage::new_message(Role::User, message));
let result = self.process(handler).await;
if let Ok(Some(InternalEvent::FunctionCall(calls))) = result {
let result = self.process(handler).await?;
if let Some(calls) = result {
self.add_message(ChatRequestMessage::new_function_call(&calls));

let mut functions = Vec::with_capacity(calls.len());
for (_, (id, name, args)) in calls {
functions.push((id, name, json::from_json::<serde_json::Value>(&args)?))
}

let results = self.function_store.call_functions(functions).await?;
for result in results {
let function_response = ChatRequestMessage::new_function_response(result.0, json::to_json(&result.1)?);
Expand All @@ -85,14 +89,19 @@ impl ChatGPT {
Rc::get_mut(&mut self.messages).unwrap().push(message);
}

async fn process(&mut self, handler: &impl ChatHandler) -> Result<Option<InternalEvent>, Exception> {
async fn process(&mut self, handler: &impl ChatHandler) -> Result<Option<FunctionCall>, Exception> {
let source = self.call_api().await?;

let (tx, mut rx) = channel(64);
tokio::spawn(async move {
process_event_source(source, tx).await;
});
let (tx, rx) = channel(64);
let handle = tokio::spawn(read_event_source(source, tx));

let function_call = self.process_event(rx, handler).await;
let _ = tokio::try_join!(handle)?;

Ok(function_call)
}

async fn process_event(&mut self, mut rx: Receiver<InternalEvent>, handler: &impl ChatHandler) -> Option<FunctionCall> {
let mut assistant_message = String::new();
while let Some(event) = rx.recv().await {
match event {
Expand All @@ -103,17 +112,14 @@ impl ChatGPT {
handler.on_event(event);
}
InternalEvent::FunctionCall(calls) => {
self.add_message(ChatRequestMessage::new_function_call(&calls));
return Ok(Some(InternalEvent::FunctionCall(calls)));
return Some(calls);
}
}
}

if !assistant_message.is_empty() {
self.add_message(ChatRequestMessage::new_message(Role::Assistant, assistant_message));
}

Ok(None)
None
}

async fn call_api(&mut self) -> Result<EventSource, Exception> {
Expand Down Expand Up @@ -147,7 +153,7 @@ impl From<CannotCloneRequestError> for Exception {
}
}

async fn process_event_source(mut source: EventSource, tx: Sender<InternalEvent>) {
async fn read_event_source(mut source: EventSource, tx: Sender<InternalEvent>) -> Result<(), Exception> {
let mut function_calls: FunctionCall = HashMap::new();
while let Some(event) = source.next().await {
match event {
Expand All @@ -160,34 +166,34 @@ async fn process_event_source(mut source: EventSource, tx: Sender<InternalEvent>
break;
}

let response: ChatResponse = json::from_json(&data).unwrap();
if response.choices.is_empty() {
continue;
}
let response: ChatResponse = json::from_json(&data)?;

let choice = response.choices.first().unwrap();
let delta = choice.delta.as_ref().unwrap();
if let Some(choice) = response.choices.into_iter().next() {
let delta = choice.delta.unwrap();

if let Some(tool_calls) = delta.tool_calls.as_ref() {
let call = tool_calls.first().unwrap();
if let Some(name) = &call.function.name {
function_calls.insert(call.index, (call.id.as_ref().unwrap().to_string(), name.to_string(), String::new()));
if let Some(tool_calls) = delta.tool_calls {
let call = tool_calls.into_iter().next().unwrap();
if let Some(name) = call.function.name {
function_calls.insert(call.index, (call.id.unwrap(), name, String::new()));
}
function_calls.get_mut(&call.index).unwrap().2.push_str(&call.function.arguments)
} else if let Some(value) = delta.content {
tx.send(InternalEvent::Event(ChatEvent::Delta(value))).await?;
}
function_calls.get_mut(&call.index).unwrap().2.push_str(&call.function.arguments)
} else if let Some(value) = delta.content.as_ref() {
tx.send(InternalEvent::Event(ChatEvent::Delta(value.to_string()))).await.unwrap();
}
}
Err(err) => {
tx.send(InternalEvent::Event(ChatEvent::Error(err.to_string()))).await.unwrap();
source.close();
return Err(Exception::new(err.to_string()));
}
}
}
if !function_calls.is_empty() {
tx.send(InternalEvent::FunctionCall(function_calls)).await.unwrap();
tx.send(InternalEvent::FunctionCall(function_calls)).await?;
} else {
// chatgpt doesn't support token usage with stream mode
tx.send(InternalEvent::Event(ChatEvent::End(Usage::default()))).await.unwrap();
tx.send(InternalEvent::Event(ChatEvent::End(Usage::default()))).await?;
}

Ok(())
}
7 changes: 7 additions & 0 deletions src/util/exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::error::Error;
use std::fmt;
use std::io;

use tokio::sync::mpsc::error::SendError;
use tokio::task::JoinError;

pub struct Exception {
Expand Down Expand Up @@ -70,3 +71,9 @@ impl From<JoinError> for Exception {
Exception::new(err.to_string())
}
}

impl<T> From<SendError<T>> for Exception {
fn from(err: SendError<T>) -> Self {
Exception::new(err.to_string())
}
}

0 comments on commit 9afcf22

Please sign in to comment.