diff --git a/src/bot/function.rs b/src/bot/function.rs index 6a3fbec..c6f4acb 100644 --- a/src/bot/function.rs +++ b/src/bot/function.rs @@ -37,15 +37,18 @@ impl FunctionStore { self.implementations.insert(name, Arc::new(implementation)); } - pub async fn call_function(&self, name: &str, args: serde_json::Value) -> Result { - info!("call function, name={name}, args={args}"); - let function = self.get(name)?; - let response = tokio::spawn(async move { function(args) }).await?; + pub async fn call_function(&self, name: String, args: serde_json::Value) -> Result { + let function = self.get(&name)?; + let response = tokio::spawn(async move { + info!("call function, name={name}, args={args}"); + function(args) + }) + .await?; Ok(response) } pub async fn call_functions(&self, functions: Vec<(String, String, serde_json::Value)>) -> Result, Exception> { - let mut handles = vec![]; + let mut handles = Vec::with_capacity(functions.len()); for (id, name, args) in functions { let function = self.get(&name)?; handles.push(tokio::spawn(async move { diff --git a/src/command/chat.rs b/src/command/chat.rs index 9d760f6..4c655fe 100644 --- a/src/command/chat.rs +++ b/src/command/chat.rs @@ -30,7 +30,7 @@ impl ChatHandler for ConsoleHandler { print_flush(&data).unwrap(); } ChatEvent::Error(error) => { - println!("Error: {}", error); + println!("\nError: {}", error); } ChatEvent::End(usage) => { println!(); diff --git a/src/gcloud/api.rs b/src/gcloud/api.rs index 4ec23a6..7809866 100644 --- a/src/gcloud/api.rs +++ b/src/gcloud/api.rs @@ -136,7 +136,9 @@ pub struct GenerateContentResponse { #[derive(Debug, Deserialize)] pub struct Candidate { - pub content: Content, + pub content: Option, + #[serde(rename = "finishReason")] + pub finish_reason: Option, } #[derive(Debug, Serialize, Deserialize, Clone)] diff --git a/src/gcloud/vertex.rs b/src/gcloud/vertex.rs index fdb1102..57b1dab 100644 --- a/src/gcloud/vertex.rs +++ b/src/gcloud/vertex.rs @@ -10,6 +10,7 @@ use base64::Engine; use futures::StreamExt; use reqwest::Response; use tokio::sync::mpsc::channel; +use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Sender; use tracing::info; @@ -67,7 +68,7 @@ impl Vertex { let mut result = self.process(Content::new_text_with_inline_data(message, data), handler).await?; while let Some(function_call) = result { - let function_response = self.function_store.call_function(&function_call.name, function_call.args).await?; + let function_response = self.function_store.call_function(function_call.name.clone(), function_call.args).await?; let content = Content::new_function_response(function_call.name, function_response); result = self.process(content, handler).await?; } @@ -103,35 +104,52 @@ impl Vertex { let response = self.call_api().await?; - let (tx, mut rx) = channel(64); + let (tx, rx) = channel(64); + let handle = tokio::spawn(read_response_stream(response, tx)); + let function_call = self.process_response(rx, handler).await; + let _ = tokio::try_join!(handle)?; - tokio::spawn(process_response_stream(response, tx)); + Ok(function_call) + } + async fn process_response(&mut self, mut rx: Receiver, handler: &impl ChatHandler) -> Option { let mut model_message = String::new(); while let Some(response) = rx.recv().await { - let response = response?; - if let Some(usage) = response.usage_metadata { self.usage.request_tokens += usage.prompt_token_count; self.usage.response_tokens += usage.candidates_token_count; } - let part = response.candidates.into_iter().next().unwrap().content.parts.into_iter().next().unwrap(); - - if let Some(function_call) = part.function_call { - self.add_message(Content::new_function_call(function_call.clone())); - return Ok(Some(function_call)); - } else if let Some(text) = part.text { - model_message.push_str(&text); - handler.on_event(ChatEvent::Delta(text)); + let candidate = response.candidates.into_iter().next().unwrap(); + match candidate.content { + Some(content) => { + let part = content.parts.into_iter().next().unwrap(); + + if let Some(function_call) = part.function_call { + self.add_message(Content::new_function_call(function_call.clone())); + return Some(function_call); + } else if let Some(text) = part.text { + model_message.push_str(&text); + handler.on_event(ChatEvent::Delta(text)); + } + } + None => { + handler.on_event(ChatEvent::Error(format!( + "response ended, finish_reason={}", + candidate.finish_reason.unwrap_or("".to_string()) + ))); + } } } + if !model_message.is_empty() { self.add_message(Content::new_text(Role::Model, model_message)); } + let usage = mem::take(&mut self.usage); handler.on_event(ChatEvent::End(usage)); - Ok(None) + + None } fn add_message(&mut self, content: Content) { @@ -174,7 +192,7 @@ impl Vertex { } } -async fn process_response_stream(response: Response, tx: Sender>) { +async fn read_response_stream(response: Response, tx: Sender) -> Result<(), Exception> { let stream = &mut response.bytes_stream(); let mut buffer = String::new(); @@ -188,16 +206,16 @@ async fn process_response_stream(response: Response, tx: Sender { - tx.send(Err(Exception::new(err.to_string()))).await.unwrap(); - break; + return Err(Exception::new(err.to_string())); } } } + Ok(()) } fn is_valid_json(content: &str) -> bool { diff --git a/src/openai/chatgpt.rs b/src/openai/chatgpt.rs index f2a4d0b..5fb39ad 100644 --- a/src/openai/chatgpt.rs +++ b/src/openai/chatgpt.rs @@ -7,6 +7,7 @@ use reqwest_eventsource::CannotCloneRequestError; use reqwest_eventsource::Event; use reqwest_eventsource::EventSource; use tokio::sync::mpsc::channel; +use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Sender; use crate::bot::function::FunctionStore; @@ -65,12 +66,15 @@ impl ChatGPT { pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> { self.add_message(ChatRequestMessage::new_message(Role::User, message)); - let result = self.process(handler).await; - if let Ok(Some(InternalEvent::FunctionCall(calls))) = result { + let result = self.process(handler).await?; + if let Some(calls) = result { + self.add_message(ChatRequestMessage::new_function_call(&calls)); + let mut functions = Vec::with_capacity(calls.len()); for (_, (id, name, args)) in calls { functions.push((id, name, json::from_json::(&args)?)) } + let results = self.function_store.call_functions(functions).await?; for result in results { let function_response = ChatRequestMessage::new_function_response(result.0, json::to_json(&result.1)?); @@ -85,14 +89,19 @@ impl ChatGPT { Rc::get_mut(&mut self.messages).unwrap().push(message); } - async fn process(&mut self, handler: &impl ChatHandler) -> Result, Exception> { + async fn process(&mut self, handler: &impl ChatHandler) -> Result, Exception> { let source = self.call_api().await?; - let (tx, mut rx) = channel(64); - tokio::spawn(async move { - process_event_source(source, tx).await; - }); + let (tx, rx) = channel(64); + let handle = tokio::spawn(read_event_source(source, tx)); + + let function_call = self.process_event(rx, handler).await; + let _ = tokio::try_join!(handle)?; + + Ok(function_call) + } + async fn process_event(&mut self, mut rx: Receiver, handler: &impl ChatHandler) -> Option { let mut assistant_message = String::new(); while let Some(event) = rx.recv().await { match event { @@ -103,17 +112,14 @@ impl ChatGPT { handler.on_event(event); } InternalEvent::FunctionCall(calls) => { - self.add_message(ChatRequestMessage::new_function_call(&calls)); - return Ok(Some(InternalEvent::FunctionCall(calls))); + return Some(calls); } } } - if !assistant_message.is_empty() { self.add_message(ChatRequestMessage::new_message(Role::Assistant, assistant_message)); } - - Ok(None) + None } async fn call_api(&mut self) -> Result { @@ -147,7 +153,7 @@ impl From for Exception { } } -async fn process_event_source(mut source: EventSource, tx: Sender) { +async fn read_event_source(mut source: EventSource, tx: Sender) -> Result<(), Exception> { let mut function_calls: FunctionCall = HashMap::new(); while let Some(event) = source.next().await { match event { @@ -160,34 +166,34 @@ async fn process_event_source(mut source: EventSource, tx: Sender break; } - let response: ChatResponse = json::from_json(&data).unwrap(); - if response.choices.is_empty() { - continue; - } + let response: ChatResponse = json::from_json(&data)?; - let choice = response.choices.first().unwrap(); - let delta = choice.delta.as_ref().unwrap(); + if let Some(choice) = response.choices.into_iter().next() { + let delta = choice.delta.unwrap(); - if let Some(tool_calls) = delta.tool_calls.as_ref() { - let call = tool_calls.first().unwrap(); - if let Some(name) = &call.function.name { - function_calls.insert(call.index, (call.id.as_ref().unwrap().to_string(), name.to_string(), String::new())); + if let Some(tool_calls) = delta.tool_calls { + let call = tool_calls.into_iter().next().unwrap(); + if let Some(name) = call.function.name { + function_calls.insert(call.index, (call.id.unwrap(), name, String::new())); + } + function_calls.get_mut(&call.index).unwrap().2.push_str(&call.function.arguments) + } else if let Some(value) = delta.content { + tx.send(InternalEvent::Event(ChatEvent::Delta(value))).await?; } - function_calls.get_mut(&call.index).unwrap().2.push_str(&call.function.arguments) - } else if let Some(value) = delta.content.as_ref() { - tx.send(InternalEvent::Event(ChatEvent::Delta(value.to_string()))).await.unwrap(); } } Err(err) => { - tx.send(InternalEvent::Event(ChatEvent::Error(err.to_string()))).await.unwrap(); source.close(); + return Err(Exception::new(err.to_string())); } } } if !function_calls.is_empty() { - tx.send(InternalEvent::FunctionCall(function_calls)).await.unwrap(); + tx.send(InternalEvent::FunctionCall(function_calls)).await?; } else { // chatgpt doesn't support token usage with stream mode - tx.send(InternalEvent::Event(ChatEvent::End(Usage::default()))).await.unwrap(); + tx.send(InternalEvent::Event(ChatEvent::End(Usage::default()))).await?; } + + Ok(()) } diff --git a/src/util/exception.rs b/src/util/exception.rs index f6a465e..882c8ce 100644 --- a/src/util/exception.rs +++ b/src/util/exception.rs @@ -3,6 +3,7 @@ use std::error::Error; use std::fmt; use std::io; +use tokio::sync::mpsc::error::SendError; use tokio::task::JoinError; pub struct Exception { @@ -70,3 +71,9 @@ impl From for Exception { Exception::new(err.to_string()) } } + +impl From> for Exception { + fn from(err: SendError) -> Self { + Exception::new(err.to_string()) + } +}