Skip to content

Commit

Permalink
refactor chatgpt api
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jul 10, 2024
1 parent 45c3f21 commit bc65bb9
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 76 deletions.
113 changes: 47 additions & 66 deletions src/azure/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::ops::Not;
use std::rc::Rc;
use std::str;
use std::str::Utf8Error;

use bytes::Bytes;
use reqwest::Response;
Expand All @@ -10,8 +11,6 @@ use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tracing::info;

use super::chatgpt_api::ImageContent;
use super::chatgpt_api::ImageUrl;
use crate::azure::chatgpt_api::ChatRequest;
use crate::azure::chatgpt_api::ChatRequestMessage;
use crate::azure::chatgpt_api::ChatResponse;
Expand Down Expand Up @@ -62,30 +61,7 @@ impl ChatGPT {
}

pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> {
if message == "1" {
self.add_message(ChatRequestMessage {
role: Role::User,
content: None,
image_content: Some(vec![
ImageContent {
r#type: "text".to_string(),
text: Some("what is in picture".to_string()),
image_url: None,
},
ImageContent {
r#type: "image_url".to_string(),
text: None,
image_url: Some(ImageUrl {
url: "https://learn.microsoft.com/en-us/azure/ai-services/computer-vision/media/quickstarts/presentation.png".to_string(),
}),
},
]),
tool_call_id: None,
tool_calls: None,
});
} else {
self.add_message(ChatRequestMessage::new_message(Role::User, message));
}
self.add_message(ChatRequestMessage::new_message(Role::User, message));

let result = self.process(handler).await?;
if let Some(calls) = result {
Expand Down Expand Up @@ -115,23 +91,51 @@ impl ChatGPT {

let response = self.call_api().await?;
let handle = tokio::spawn(read_sse(response, tx));
self.process_event(rx, handler).await;
let function_call = handle.await??;
let function_call = self.process_response(rx, handler).await?;
handle.await??;

Ok(function_call)
}

async fn process_event(&mut self, mut rx: Receiver<ChatEvent>, handler: &impl ChatHandler) {
async fn process_response(&mut self, mut rx: Receiver<ChatResponse>, handler: &impl ChatHandler) -> Result<Option<FunctionCall>, Exception> {
let mut function_calls: FunctionCall = HashMap::new();
let mut usage = Usage::default();
let mut assistant_message = String::new();
while let Some(event) = rx.recv().await {
if let ChatEvent::Delta(ref data) = event {
assistant_message.push_str(data);

while let Some(response) = rx.recv().await {
if let Some(choice) = response.choices.into_iter().next() {
let delta = choice.delta.unwrap();

if let Some(tool_calls) = delta.tool_calls {
let call = tool_calls.into_iter().next().unwrap();
if let Some(name) = call.function.name {
function_calls.insert(call.index, (call.id.unwrap(), name, String::new()));
}
function_calls.get_mut(&call.index).unwrap().2.push_str(&call.function.arguments)
} else if let Some(value) = delta.content {
assistant_message.push_str(&value);
handler.on_event(ChatEvent::Delta(value));
}
}

if let Some(value) = response.usage {
usage = Usage {
request_tokens: value.prompt_tokens,
response_tokens: value.completion_tokens,
};
}
handler.on_event(event);
}

if !assistant_message.is_empty() {
self.add_message(ChatRequestMessage::new_message(Role::Assistant, assistant_message));
}

if !function_calls.is_empty() {
Ok(Some(function_calls))
} else {
handler.on_event(ChatEvent::End(usage));
Ok(None)
}
}

async fn call_api(&mut self) -> Result<Response, Exception> {
Expand Down Expand Up @@ -173,57 +177,34 @@ impl ChatGPT {
}
}

async fn read_sse(response: Response, tx: Sender<ChatEvent>) -> Result<Option<FunctionCall>, Exception> {
let mut function_calls: FunctionCall = HashMap::new();
let mut usage = Usage::default();

async fn read_sse(response: Response, tx: Sender<ChatResponse>) -> Result<(), Exception> {
let mut buffer = String::with_capacity(1024);
let mut response = response;
'outer: while let Some(chunk) = response.chunk().await? {
buffer.push_str(str::from_utf8(&chunk).unwrap());
while let Some(chunk) = response.chunk().await? {
buffer.push_str(str::from_utf8(&chunk)?);

while let Some(index) = buffer.find("\n\n") {
if buffer.starts_with("data:") {
let data = &buffer[6..index];

if data == "[DONE]" {
break 'outer;
return Ok(());
}

let response: ChatResponse = json::from_json(data)?;

if let Some(choice) = response.choices.into_iter().next() {
let delta = choice.delta.unwrap();

if let Some(tool_calls) = delta.tool_calls {
let call = tool_calls.into_iter().next().unwrap();
if let Some(name) = call.function.name {
function_calls.insert(call.index, (call.id.unwrap(), name, String::new()));
}
function_calls.get_mut(&call.index).unwrap().2.push_str(&call.function.arguments)
} else if let Some(value) = delta.content {
tx.send(ChatEvent::Delta(value)).await?;
}
}

if let Some(value) = response.usage {
usage = Usage {
request_tokens: value.prompt_tokens,
response_tokens: value.completion_tokens,
};
}
tx.send(response).await?;

buffer.replace_range(0..index + 2, "");
} else {
return Err(Exception::unexpected(format!("unexpected sse message, buffer={}", buffer)));
}
}
}
Ok(())
}

if !function_calls.is_empty() {
Ok(Some(function_calls))
} else {
tx.send(ChatEvent::End(usage)).await?;
Ok(None)
impl From<Utf8Error> for Exception {
fn from(err: Utf8Error) -> Self {
Exception::unexpected(err)
}
}
21 changes: 12 additions & 9 deletions src/azure/chatgpt_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,15 @@ pub struct ChatRequest {
pub struct ChatRequestMessage {
pub role: Role,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(rename = "content", skip_serializing_if = "Option::is_none")]
pub image_content: Option<Vec<ImageContent>>,
pub content: Option<Vec<Content>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}

#[derive(Debug, Serialize)]
pub struct ImageContent {
pub struct Content {
pub r#type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
Expand All @@ -61,8 +59,11 @@ impl ChatRequestMessage {
pub fn new_message(role: Role, message: String) -> Self {
ChatRequestMessage {
role,
content: Some(message),
image_content: None,
content: Some(vec![Content {
r#type: "text".to_string(),
text: Some(message),
image_url: None,
}]),
tool_call_id: None,
tool_calls: None,
}
Expand All @@ -71,8 +72,11 @@ impl ChatRequestMessage {
pub fn new_function_response(id: String, result: String) -> Self {
ChatRequestMessage {
role: Role::Tool,
content: Some(result),
image_content: None,
content: Some(vec![Content {
r#type: "text".to_string(),
text: Some(result),
image_url: None,
}]),
tool_call_id: Some(id),
tool_calls: None,
}
Expand All @@ -82,7 +86,6 @@ impl ChatRequestMessage {
ChatRequestMessage {
role: Role::Assistant,
content: None,
image_content: None,
tool_call_id: None,
tool_calls: Some(
calls
Expand Down
1 change: 0 additions & 1 deletion src/gcloud/gemini_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ pub struct Part {
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub inline_data: Option<InlineData>,

#[serde(rename = "functionCall")]
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
Expand Down

0 comments on commit bc65bb9

Please sign in to comment.