Skip to content

Commit

Permalink
refactor syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Apr 18, 2024
1 parent 7f07333 commit 1eccfea
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 26 deletions.
31 changes: 14 additions & 17 deletions src/bot/function.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::collections::HashMap;
use std::sync::Arc;

use futures::future::join_all;
use futures::future::try_join_all;
use serde::Serialize;
use tokio::task::JoinHandle;
use tracing::info;

use crate::util::exception::Exception;
Expand Down Expand Up @@ -46,25 +45,23 @@ impl FunctionStore {
}

pub async fn call_functions(&self, functions: Vec<(String, String, serde_json::Value)>) -> Result<Vec<(String, serde_json::Value)>, Exception> {
let handles: Result<Vec<JoinHandle<_>>, _> = functions
.into_iter()
.map(|(id, name, args)| {
let mut handles = vec![];
for (id, name, args) in functions {
let function = self.get(&name)?;
handles.push(tokio::spawn(async move {
info!("call function, id={id}, name={name}, args={args}");
let function = self.get(&name)?;
Ok::<JoinHandle<_>, Exception>(tokio::spawn(async move { (id, function(args)) }))
})
.collect();

let results = join_all(handles?).await.into_iter().collect::<Result<Vec<_>, _>>()?;
(id, function(args))
}));
}
let results = try_join_all(handles).await?;
Ok(results)
}

fn get(&self, name: &str) -> Result<Arc<Box<FunctionImplementation>>, Exception> {
let function = Arc::clone(
self.implementations
.get(name)
.ok_or_else(|| Exception::new(format!("function not found, name={name}")))?,
);
Ok(function)
let function = self
.implementations
.get(name)
.ok_or_else(|| Exception::new(format!("function not found, name={name}")))?;
Ok(Arc::clone(function))
}
}
14 changes: 6 additions & 8 deletions src/gcloud/vertex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ impl Vertex {

let (tx, mut rx) = channel(64);

tokio::spawn(async move {
process_response_stream(response, tx).await;
});
tokio::spawn(process_response_stream(response, tx));

let mut model_message = String::new();
while let Some(response) = rx.recv().await {
Expand All @@ -119,12 +117,12 @@ impl Vertex {

let part = response.candidates.into_iter().next().unwrap().content.parts.into_iter().next().unwrap();

if let Some(function) = part.function_call {
self.add_message(Content::new_function_call(function.clone()));
return Ok(Some(function));
if let Some(function_call) = part.function_call {
self.add_message(Content::new_function_call(function_call.clone()));
return Ok(Some(function_call));
} else if let Some(text) = part.text {
handler.on_event(ChatEvent::Delta(text.clone()));
model_message.push_str(&text);
handler.on_event(ChatEvent::Delta(text));
}
}
if !model_message.is_empty() {
Expand Down Expand Up @@ -158,7 +156,7 @@ impl Vertex {

async fn post(&self, request: StreamGenerateContent) -> Result<Response, Exception> {
let body = json::to_json(&request)?;
info!("body={body}");
// info!("body={body}");
let response = http_client::http_client()
.post(&self.url)
.bearer_auth(token())
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub enum Command {

#[tokio::main]
async fn main() -> Result<(), Exception> {
tracing_subscriber::fmt::init();
tracing_subscriber::fmt().with_thread_ids(true).init();
let cli = Cli::parse();
match cli.command {
Some(Command::GenerateZshCompletion(command)) => command.execute(),
Expand Down

0 comments on commit 1eccfea

Please sign in to comment.