From bc65bb957076c793ed1fad9cd101174718dcd9e2 Mon Sep 17 00:00:00 2001 From: neo <1100909+neowu@users.noreply.github.com> Date: Wed, 10 Jul 2024 11:52:33 +0800 Subject: [PATCH] refactor chatgpt api --- src/azure/chatgpt.rs | 113 ++++++++++++++++----------------------- src/azure/chatgpt_api.rs | 21 ++++---- src/gcloud/gemini_api.rs | 1 - 3 files changed, 59 insertions(+), 76 deletions(-) diff --git a/src/azure/chatgpt.rs b/src/azure/chatgpt.rs index c896694..6872211 100644 --- a/src/azure/chatgpt.rs +++ b/src/azure/chatgpt.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::ops::Not; use std::rc::Rc; use std::str; +use std::str::Utf8Error; use bytes::Bytes; use reqwest::Response; @@ -10,8 +11,6 @@ use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Sender; use tracing::info; -use super::chatgpt_api::ImageContent; -use super::chatgpt_api::ImageUrl; use crate::azure::chatgpt_api::ChatRequest; use crate::azure::chatgpt_api::ChatRequestMessage; use crate::azure::chatgpt_api::ChatResponse; @@ -62,30 +61,7 @@ impl ChatGPT { } pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> { - if message == "1" { - self.add_message(ChatRequestMessage { - role: Role::User, - content: None, - image_content: Some(vec![ - ImageContent { - r#type: "text".to_string(), - text: Some("what is in picture".to_string()), - image_url: None, - }, - ImageContent { - r#type: "image_url".to_string(), - text: None, - image_url: Some(ImageUrl { - url: "https://learn.microsoft.com/en-us/azure/ai-services/computer-vision/media/quickstarts/presentation.png".to_string(), - }), - }, - ]), - tool_call_id: None, - tool_calls: None, - }); - } else { - self.add_message(ChatRequestMessage::new_message(Role::User, message)); - } + self.add_message(ChatRequestMessage::new_message(Role::User, message)); let result = self.process(handler).await?; if let Some(calls) = result { @@ -115,23 +91,51 @@ impl ChatGPT { let response = self.call_api().await?; let handle = tokio::spawn(read_sse(response, tx)); - self.process_event(rx, handler).await; - let function_call = handle.await??; + let function_call = self.process_response(rx, handler).await?; + handle.await??; Ok(function_call) } - async fn process_event(&mut self, mut rx: Receiver, handler: &impl ChatHandler) { + async fn process_response(&mut self, mut rx: Receiver, handler: &impl ChatHandler) -> Result, Exception> { + let mut function_calls: FunctionCall = HashMap::new(); + let mut usage = Usage::default(); let mut assistant_message = String::new(); - while let Some(event) = rx.recv().await { - if let ChatEvent::Delta(ref data) = event { - assistant_message.push_str(data); + + while let Some(response) = rx.recv().await { + if let Some(choice) = response.choices.into_iter().next() { + let delta = choice.delta.unwrap(); + + if let Some(tool_calls) = delta.tool_calls { + let call = tool_calls.into_iter().next().unwrap(); + if let Some(name) = call.function.name { + function_calls.insert(call.index, (call.id.unwrap(), name, String::new())); + } + function_calls.get_mut(&call.index).unwrap().2.push_str(&call.function.arguments) + } else if let Some(value) = delta.content { + assistant_message.push_str(&value); + handler.on_event(ChatEvent::Delta(value)); + } + } + + if let Some(value) = response.usage { + usage = Usage { + request_tokens: value.prompt_tokens, + response_tokens: value.completion_tokens, + }; } - handler.on_event(event); } + if !assistant_message.is_empty() { self.add_message(ChatRequestMessage::new_message(Role::Assistant, assistant_message)); } + + if !function_calls.is_empty() { + Ok(Some(function_calls)) + } else { + handler.on_event(ChatEvent::End(usage)); + Ok(None) + } } async fn call_api(&mut self) -> Result { @@ -173,45 +177,22 @@ impl ChatGPT { } } -async fn read_sse(response: Response, tx: Sender) -> Result, Exception> { - let mut function_calls: FunctionCall = HashMap::new(); - let mut usage = Usage::default(); - +async fn read_sse(response: Response, tx: Sender) -> Result<(), Exception> { let mut buffer = String::with_capacity(1024); let mut response = response; - 'outer: while let Some(chunk) = response.chunk().await? { - buffer.push_str(str::from_utf8(&chunk).unwrap()); + while let Some(chunk) = response.chunk().await? { + buffer.push_str(str::from_utf8(&chunk)?); while let Some(index) = buffer.find("\n\n") { if buffer.starts_with("data:") { let data = &buffer[6..index]; if data == "[DONE]" { - break 'outer; + return Ok(()); } let response: ChatResponse = json::from_json(data)?; - - if let Some(choice) = response.choices.into_iter().next() { - let delta = choice.delta.unwrap(); - - if let Some(tool_calls) = delta.tool_calls { - let call = tool_calls.into_iter().next().unwrap(); - if let Some(name) = call.function.name { - function_calls.insert(call.index, (call.id.unwrap(), name, String::new())); - } - function_calls.get_mut(&call.index).unwrap().2.push_str(&call.function.arguments) - } else if let Some(value) = delta.content { - tx.send(ChatEvent::Delta(value)).await?; - } - } - - if let Some(value) = response.usage { - usage = Usage { - request_tokens: value.prompt_tokens, - response_tokens: value.completion_tokens, - }; - } + tx.send(response).await?; buffer.replace_range(0..index + 2, ""); } else { @@ -219,11 +200,11 @@ async fn read_sse(response: Response, tx: Sender) -> Result for Exception { + fn from(err: Utf8Error) -> Self { + Exception::unexpected(err) } } diff --git a/src/azure/chatgpt_api.rs b/src/azure/chatgpt_api.rs index 9fb73b6..5faf2f3 100644 --- a/src/azure/chatgpt_api.rs +++ b/src/azure/chatgpt_api.rs @@ -29,9 +29,7 @@ pub struct ChatRequest { pub struct ChatRequestMessage { pub role: Role, #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - #[serde(rename = "content", skip_serializing_if = "Option::is_none")] - pub image_content: Option>, + pub content: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -39,7 +37,7 @@ pub struct ChatRequestMessage { } #[derive(Debug, Serialize)] -pub struct ImageContent { +pub struct Content { pub r#type: String, #[serde(skip_serializing_if = "Option::is_none")] pub text: Option, @@ -61,8 +59,11 @@ impl ChatRequestMessage { pub fn new_message(role: Role, message: String) -> Self { ChatRequestMessage { role, - content: Some(message), - image_content: None, + content: Some(vec![Content { + r#type: "text".to_string(), + text: Some(message), + image_url: None, + }]), tool_call_id: None, tool_calls: None, } @@ -71,8 +72,11 @@ impl ChatRequestMessage { pub fn new_function_response(id: String, result: String) -> Self { ChatRequestMessage { role: Role::Tool, - content: Some(result), - image_content: None, + content: Some(vec![Content { + r#type: "text".to_string(), + text: Some(result), + image_url: None, + }]), tool_call_id: Some(id), tool_calls: None, } @@ -82,7 +86,6 @@ impl ChatRequestMessage { ChatRequestMessage { role: Role::Assistant, content: None, - image_content: None, tool_call_id: None, tool_calls: Some( calls diff --git a/src/gcloud/gemini_api.rs b/src/gcloud/gemini_api.rs index c233325..2521199 100644 --- a/src/gcloud/gemini_api.rs +++ b/src/gcloud/gemini_api.rs @@ -102,7 +102,6 @@ pub struct Part { pub text: Option, #[serde(skip_serializing_if = "Option::is_none")] pub inline_data: Option, - #[serde(rename = "functionCall")] #[serde(skip_serializing_if = "Option::is_none")] pub function_call: Option,