refactor(flow): make from_substrait_* async & worker handle refactor (#4210)

* refactor: use oneshot to receive result

* refactor: make from_substrait_* async

* refactor: remove serde for plan & expr
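
The first bullet replaces the paired request/response mpsc channels, which were matched by a ReqId counter, with a per-call tokio::sync::oneshot sender attached to each request. Below is a minimal, standalone sketch of that pattern, assuming simplified placeholder Request/Response types and worker loop rather than the actual src/flow/src/adapter/worker.rs code:

```rust
use tokio::sync::{mpsc, oneshot};

#[derive(Debug)]
enum Request {
    Ping,
    Shutdown,
}

#[derive(Debug)]
enum Response {
    Pong,
}

// Each request optionally carries a oneshot sender; the worker replies through
// it, so no call-id bookkeeping or shared response channel is needed.
type Envelope = (Request, Option<oneshot::Sender<Response>>);

async fn worker(mut rx: mpsc::UnboundedReceiver<Envelope>) {
    while let Some((req, ret_tx)) = rx.recv().await {
        match req {
            Request::Ping => {
                if let Some(tx) = ret_tx {
                    // The receiver may already be dropped; ignore that case here.
                    let _ = tx.send(Response::Pong);
                }
            }
            Request::Shutdown => break,
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    let handle = tokio::spawn(worker(rx));

    // Call with a response: send the request together with a oneshot sender,
    // then await the matching oneshot receiver.
    let (ret_tx, ret_rx) = oneshot::channel();
    tx.send((Request::Ping, Some(ret_tx))).unwrap();
    println!("{:?}", ret_rx.await.unwrap());

    // Fire-and-forget call: no oneshot attached.
    tx.send((Request::Shutdown, None)).unwrap();
    handle.await.unwrap();
}
```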
discord9 authored Jun 27, 2024
1 parent 10b7a3d commit b6585e3
Showing 16 changed files with 223 additions and 241 deletions.
1 change: 1 addition & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions src/flow/Cargo.toml
@@ -10,6 +10,7 @@ workspace = true
[dependencies]
api.workspace = true
arrow-schema.workspace = true
async-recursion = "1.0"
async-trait.workspace = true
bytes.workspace = true
catalog.workspace = true
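
The new async-recursion dependency supports the second bullet: the from_substrait_* conversion functions recurse over plan and expression trees, and a recursive async fn is rejected by rustc because its future type would otherwise be infinitely sized. A toy sketch of the pattern, using a hypothetical depth function rather than the real from_substrait_* code:

```rust
use async_recursion::async_recursion;

// #[async_recursion] boxes the returned future, which breaks the
// infinitely-sized-type cycle that a plain recursive `async fn` would hit.
#[async_recursion]
async fn depth(n: u64) -> u64 {
    if n == 0 {
        0
    } else {
        1 + depth(n - 1).await
    }
}

#[tokio::main]
async fn main() {
    assert_eq!(depth(5).await, 5);
}
```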
152 changes: 52 additions & 100 deletions src/flow/src/adapter/worker.rs
@@ -15,14 +15,14 @@
//! For single-thread flow worker

use std::collections::{BTreeMap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use common_telemetry::info;
use enum_as_inner::EnumAsInner;
use hydroflow::scheduled::graph::Hydroflow;
use snafu::{ensure, OptionExt};
use tokio::sync::{broadcast, mpsc, Mutex};
use snafu::ensure;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};

use crate::adapter::error::{Error, FlowAlreadyExistSnafu, InternalSnafu, UnexpectedSnafu};
use crate::adapter::FlowId;
@@ -39,7 +39,7 @@ type ReqId = usize;
pub fn create_worker<'a>() -> (WorkerHandle, Worker<'a>) {
let (itc_client, itc_server) = create_inter_thread_call();
let worker_handle = WorkerHandle {
itc_client: Mutex::new(itc_client),
itc_client,
shutdown: AtomicBool::new(false),
};
let worker = Worker {
@@ -106,7 +106,7 @@ impl<'subgraph> ActiveDataflowState<'subgraph> {

#[derive(Debug)]
pub struct WorkerHandle {
itc_client: Mutex<InterThreadCallClient>,
itc_client: InterThreadCallClient,
shutdown: AtomicBool,
}

@@ -122,12 +122,7 @@ impl WorkerHandle {
}
);

let ret = self
.itc_client
.lock()
.await
.call_with_resp(create_reqs)
.await?;
let ret = self.itc_client.call_with_resp(create_reqs).await?;
ret.into_create().map_err(|ret| {
InternalSnafu {
reason: format!(
@@ -141,7 +136,8 @@ impl WorkerHandle {
/// remove task, return task id
pub async fn remove_flow(&self, flow_id: FlowId) -> Result<bool, Error> {
let req = Request::Remove { flow_id };
let ret = self.itc_client.lock().await.call_with_resp(req).await?;

let ret = self.itc_client.call_with_resp(req).await?;

ret.into_remove().map_err(|ret| {
InternalSnafu {
@@ -157,15 +153,12 @@
///
/// the returned error is unrecoverable, and the worker should be shutdown/rebooted
pub async fn run_available(&self, now: repr::Timestamp) -> Result<(), Error> {
self.itc_client
.lock()
.await
.call_no_resp(Request::RunAvail { now })
self.itc_client.call_no_resp(Request::RunAvail { now })
}

pub async fn contains_flow(&self, flow_id: FlowId) -> Result<bool, Error> {
let req = Request::ContainTask { flow_id };
let ret = self.itc_client.lock().await.call_with_resp(req).await?;
let ret = self.itc_client.call_with_resp(req).await?;

ret.into_contain_task().map_err(|ret| {
InternalSnafu {
@@ -178,23 +171,9 @@
}

/// shutdown the worker
pub async fn shutdown(&self) -> Result<(), Error> {
pub fn shutdown(&self) -> Result<(), Error> {
if !self.shutdown.fetch_or(true, Ordering::SeqCst) {
self.itc_client.lock().await.call_no_resp(Request::Shutdown)
} else {
UnexpectedSnafu {
reason: "Worker already shutdown",
}
.fail()
}
}

/// shutdown the worker
pub fn shutdown_blocking(&self) -> Result<(), Error> {
if !self.shutdown.fetch_or(true, Ordering::SeqCst) {
self.itc_client
.blocking_lock()
.call_no_resp(Request::Shutdown)
self.itc_client.call_no_resp(Request::Shutdown)
} else {
UnexpectedSnafu {
reason: "Worker already shutdown",
@@ -206,8 +185,7 @@

impl Drop for WorkerHandle {
fn drop(&mut self) {
let ret = futures::executor::block_on(async { self.shutdown().await });
if let Err(ret) = ret {
if let Err(ret) = self.shutdown() {
common_telemetry::error!(
ret;
"While dropping Worker Handle, failed to shutdown worker, worker might be in inconsistent state."
@@ -276,7 +254,7 @@ impl<'s> Worker<'s> {
/// Run the worker, blocking, until shutdown signal is received
pub fn run(&mut self) {
loop {
let (req_id, req) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() {
let (req, ret_tx) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() {
ret
} else {
common_telemetry::error!(
@@ -285,19 +263,26 @@
break;
};

let ret = self.handle_req(req_id, req);
match ret {
Ok(Some((id, resp))) => {
if let Err(err) = self.itc_server.blocking_lock().resp(id, resp) {
let ret = self.handle_req(req);
match (ret, ret_tx) {
(Ok(Some(resp)), Some(ret_tx)) => {
if let Err(err) = ret_tx.send(resp) {
common_telemetry::error!(
err;
"Worker's itc server has been closed unexpectedly, shutting down worker"
"Result receiver is dropped, can't send result"
);
break;
};
}
Ok(None) => continue,
Err(()) => {
(Ok(None), None) => continue,
(Ok(Some(resp)), None) => {
common_telemetry::error!(
"Expect no result for current request, but found {resp:?}"
)
}
(Ok(None), Some(_)) => {
common_telemetry::error!("Expect result for current request, but found nothing")
}
(Err(()), _) => {
break;
}
}
@@ -315,7 +300,7 @@
/// handle request, return response if any, Err if receive shutdown signal
///
/// return `Err(())` if receive shutdown request
fn handle_req(&mut self, req_id: ReqId, req: Request) -> Result<Option<(ReqId, Response)>, ()> {
fn handle_req(&mut self, req: Request) -> Result<Option<Response>, ()> {
let ret = match req {
Request::Create {
flow_id,
@@ -339,24 +324,21 @@
create_if_not_exists,
err_collector,
);
Some((
req_id,
Response::Create {
result: task_create_result,
},
))
Some(Response::Create {
result: task_create_result,
})
}
Request::Remove { flow_id } => {
let ret = self.remove_flow(flow_id);
Some((req_id, Response::Remove { result: ret }))
Some(Response::Remove { result: ret })
}
Request::RunAvail { now } => {
self.run_tick(now);
None
}
Request::ContainTask { flow_id } => {
let ret = self.task_states.contains_key(&flow_id);
Some((req_id, Response::ContainTask { result: ret }))
Some(Response::ContainTask { result: ret })
}
Request::Shutdown => return Err(()),
};
@@ -406,83 +388,50 @@ enum Response {

fn create_inter_thread_call() -> (InterThreadCallClient, InterThreadCallServer) {
let (arg_send, arg_recv) = mpsc::unbounded_channel();
let (ret_send, ret_recv) = mpsc::unbounded_channel();
let client = InterThreadCallClient {
call_id: AtomicUsize::new(0),
arg_sender: arg_send,
ret_recv,
};
let server = InterThreadCallServer {
arg_recv,
ret_sender: ret_send,
};
let server = InterThreadCallServer { arg_recv };
(client, server)
}

#[derive(Debug)]
struct InterThreadCallClient {
call_id: AtomicUsize,
arg_sender: mpsc::UnboundedSender<(ReqId, Request)>,
ret_recv: mpsc::UnboundedReceiver<(ReqId, Response)>,
arg_sender: mpsc::UnboundedSender<(Request, Option<oneshot::Sender<Response>>)>,
}

impl InterThreadCallClient {
/// call without expecting responses or blocking
fn call_no_resp(&self, req: Request) -> Result<(), Error> {
// TODO(discord9): relax memory order later
let call_id = self.call_id.fetch_add(1, Ordering::SeqCst);
self.arg_sender
.send((call_id, req))
.map_err(from_send_error)
self.arg_sender.send((req, None)).map_err(from_send_error)
}

/// call blocking, and return the result
async fn call_with_resp(&mut self, req: Request) -> Result<Response, Error> {
// TODO(discord9): relax memory order later
let call_id = self.call_id.fetch_add(1, Ordering::SeqCst);
async fn call_with_resp(&self, req: Request) -> Result<Response, Error> {
let (tx, rx) = oneshot::channel();
self.arg_sender
.send((call_id, req))
.send((req, Some(tx)))
.map_err(from_send_error)?;

// TODO(discord9): better inter thread call impl, i.e. support multiple client(also consider if it's necessary)
// since one node manger might manage multiple worker, but one worker should only belong to one node manager
let (ret_call_id, ret) = self
.ret_recv
.recv()
.await
.context(InternalSnafu { reason: "InterThreadCallClient call_blocking failed, ret_recv has been closed and there are no remaining messages in the channel's buffer" })?;

ensure!(
ret_call_id == call_id,
rx.await.map_err(|_| {
InternalSnafu {
reason: "call id mismatch, worker/worker handler should be in sync",
reason: "Sender is dropped",
}
);
Ok(ret)
.build()
})
}
}

#[derive(Debug)]
struct InterThreadCallServer {
pub arg_recv: mpsc::UnboundedReceiver<(ReqId, Request)>,
pub ret_sender: mpsc::UnboundedSender<(ReqId, Response)>,
pub arg_recv: mpsc::UnboundedReceiver<(Request, Option<oneshot::Sender<Response>>)>,
}

impl InterThreadCallServer {
pub async fn recv(&mut self) -> Option<(usize, Request)> {
pub async fn recv(&mut self) -> Option<(Request, Option<oneshot::Sender<Response>>)> {
self.arg_recv.recv().await
}

pub fn blocking_recv(&mut self) -> Option<(usize, Request)> {
pub fn blocking_recv(&mut self) -> Option<(Request, Option<oneshot::Sender<Response>>)> {
self.arg_recv.blocking_recv()
}

/// Send response back to the client
pub fn resp(&self, call_id: ReqId, resp: Response) -> Result<(), Error> {
self.ret_sender
.send((call_id, resp))
.map_err(from_send_error)
}
}

fn from_send_error<T>(err: mpsc::error::SendError<T>) -> Error {
@@ -546,7 +495,10 @@ mod test {
create_if_not_exists: true,
err_collector: ErrCollector::default(),
};
handle.create_flow(create_reqs).await.unwrap();
assert_eq!(
handle.create_flow(create_reqs).await.unwrap(),
Some(flow_id)
);
tx.send((Row::empty(), 0, 0)).unwrap();
handle.run_available(0).await.unwrap();
assert_eq!(sink_rx.recv().await.unwrap().0, Row::empty());
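
With responses delivered per call over oneshot channels, the handle above no longer needs a Mutex-guarded client or an AtomicUsize call-id counter, and since mpsc::UnboundedSender::send never blocks, shutdown becomes a plain synchronous method that Drop can invoke directly instead of going through futures::executor::block_on. A minimal sketch of that drop-time shutdown pattern, with a hypothetical Command type standing in for the worker's Request:

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
enum Command {
    Shutdown,
}

struct Handle {
    tx: mpsc::UnboundedSender<Command>,
}

impl Drop for Handle {
    fn drop(&mut self) {
        // UnboundedSender::send is synchronous and never blocks, so no executor
        // is needed here; a send error only means the worker already exited.
        if let Err(err) = self.tx.send(Command::Shutdown) {
            eprintln!("worker already gone: {err}");
        }
    }
}

fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel();
    let handle = Handle { tx };
    drop(handle);
    // Unbounded send/try_recv work outside an async runtime, so the shutdown
    // command is observable here without spawning a worker task.
    assert!(matches!(rx.try_recv(), Ok(Command::Shutdown)));
}
```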
2 changes: 1 addition & 1 deletion src/flow/src/expr/func.rs
@@ -43,7 +43,7 @@ use crate::repr::{self, value_to_internal_ts, Row};

/// UnmaterializableFunc is a function that can't be eval independently,
/// and require special handling
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)]
pub enum UnmaterializableFunc {
Now,
CurrentSchema,
4 changes: 2 additions & 2 deletions src/flow/src/expr/linear.rs
@@ -49,7 +49,7 @@ use crate::repr::{self, value_to_internal_ts, Diff, Row};
/// expressions in `self.expressions`, even though this is not something
/// we can directly evaluate. The plan creation methods will defensively
/// ensure that the right thing happens.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct MapFilterProject {
/// A sequence of expressions that should be appended to the row.
///
@@ -462,7 +462,7 @@ impl MapFilterProject {
}

/// A wrapper type which indicates it is safe to simply evaluate all expressions.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)]
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct SafeMfpPlan {
/// the inner `MapFilterProject` that is safe to evaluate.
pub(crate) mfp: MapFilterProject,
3 changes: 1 addition & 2 deletions src/flow/src/expr/relation.rs
@@ -23,7 +23,7 @@ mod accum;
mod func;

/// Describes an aggregation expression.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct AggregateExpr {
/// Names the aggregation function.
pub func: AggregateFunc,
@@ -32,6 +32,5 @@ pub struct AggregateExpr {
/// so it only used in generate KeyValPlan from AggregateExpr
pub expr: ScalarExpr,
/// Should the aggregation be applied only to distinct results in each group.
#[serde(default)]
pub distinct: bool,
}
