Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Mar 21, 2024
1 parent 3c351b9 commit ee821d5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
11 changes: 6 additions & 5 deletions src/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;

use futures::stream::StreamExt;
use reqwest_eventsource::Event;
use reqwest_eventsource::EventSource;
use tokio::sync::mpsc::channel;

use crate::openai::chat_completion::ChatRequest;
Expand Down Expand Up @@ -64,19 +65,19 @@ impl ChatGPT {
}

pub async fn chat(&mut self, message: &str, handler: &dyn ChatHandler) -> Result<(), Box<dyn Error>> {
let result = self.send_request(ChatRequestMessage::new(Role::User, message), handler).await;
let result = self.process(ChatRequestMessage::new(Role::User, message), handler).await;
if let Ok(Some(InternalEvent::FunctionCall { name, arguments })) = result {
let function = Arc::clone(self.function_implementations.get(&name).unwrap());

let result = tokio::spawn(async move { function(arguments) }).await?;

self.send_request(ChatRequestMessage::new_function(name, result), handler).await?;
self.process(ChatRequestMessage::new_function(name, result), handler).await?;
}
Ok(())
}

async fn send_request(&mut self, message: ChatRequestMessage, handler: &dyn ChatHandler) -> Result<Option<InternalEvent>, Box<dyn Error>> {
let mut source = self.post_sse(message).await?;
async fn process(&mut self, message: ChatRequestMessage, handler: &dyn ChatHandler) -> Result<Option<InternalEvent>, Box<dyn Error>> {
let mut source = self.call_api(message).await?;

let (tx, mut rx) = channel(64);
tokio::spawn(async move {
Expand Down Expand Up @@ -162,7 +163,7 @@ impl ChatGPT {
Ok(None)
}

async fn post_sse(&mut self, message: ChatRequestMessage) -> Result<reqwest_eventsource::EventSource, Box<dyn Error>> {
async fn call_api(&mut self, message: ChatRequestMessage) -> Result<EventSource, Box<dyn Error>> {
let mut request = ChatRequest::new();
request.messages = mem::take(&mut self.messages);
request.messages.push(message);
Expand Down
6 changes: 4 additions & 2 deletions src/openai.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::error::Error;
use std::fmt;
use std::sync::OnceLock;

use reqwest_eventsource::EventSource;
use serde::de::DeserializeOwned;
use serde::Serialize;

use crate::util::exception::Exception;
use crate::util::json;

pub mod chat_completion;

Expand Down Expand Up @@ -51,12 +53,12 @@ impl Client {

pub async fn post_sse<Request>(&self, request: &Request) -> Result<EventSource, Box<dyn Error>>
where
Request: Serialize,
Request: Serialize + fmt::Debug,
{
let endpoint = &self.endpoint;
let model = &self.model;
let url = format!("{endpoint}/openai/deployments/{model}/chat/completions?api-version=2024-02-15-preview");
let body = serde_json::to_string(request)?;
let body = json::to_json(&request)?;

let request = http_client()
.post(url)
Expand Down
11 changes: 11 additions & 0 deletions src/util/json.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::fmt;

use crate::util::exception::Exception;
use serde::de;
use serde::Serialize;

pub fn from_json<'a, T>(json: &'a str) -> Result<T, Exception>
where
Expand All @@ -8,3 +11,11 @@ where
let result = serde_json::from_str(json);
result.map_err(|err| Exception::new(&format!("failed to deserialize json, error={err}, json={json}")))
}

pub fn to_json<T>(object: &T) -> Result<String, Exception>
where
T: Serialize + fmt::Debug,
{
let result = serde_json::to_string(object);
result.map_err(|err| Exception::new(&format!("failed to serialize json, error={err}, object={object:?}")))
}

0 comments on commit ee821d5

Please sign in to comment.