fix(flow): fix call df func bug&sqlness test #4165

Merged: 16 commits, Jun 24, 2024
30 changes: 28 additions & 2 deletions Cargo.lock


4 changes: 2 additions & 2 deletions src/flow/src/adapter.rs
@@ -321,7 +321,7 @@ impl FlownodeManager {
schema
.get_name(*i)
.clone()
- .unwrap_or_else(|| format!("Col_{i}"))
+ .unwrap_or_else(|| format!("col_{i}"))
})
.collect_vec()
})
@@ -344,7 +344,7 @@ impl FlownodeManager {
.get(idx)
.cloned()
.flatten()
- .unwrap_or(format!("Col_{}", idx));
+ .unwrap_or(format!("col_{}", idx));
let ret = ColumnSchema::new(name, typ.scalar_type, typ.nullable);
if schema.typ().time_index == Some(idx) {
ret.with_time_index(true)
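Both hunks above change only the fallback name generated for unnamed output columns, from Col_{i} to the lowercase col_{i}, matching the usual lowercase column-name convention. A self-contained illustration of the fallback logic (the helper below is hypothetical, not code from this PR):

fn fallback_names(given: &[Option<&str>]) -> Vec<String> {
    given
        .iter()
        .enumerate()
        // keep a provided name, otherwise fall back to a generated `col_{i}`
        .map(|(i, n)| n.map(str::to_string).unwrap_or_else(|| format!("col_{i}")))
        .collect()
}

fn main() {
    assert_eq!(
        fallback_names(&[Some("host"), None]),
        vec!["host".to_string(), "col_1".to_string()]
    );
}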
2 changes: 2 additions & 0 deletions src/flow/src/adapter/table_source.rs
@@ -148,6 +148,8 @@ impl TableSource {
column_types,
keys,
time_index,
+ // by default, a table schema's columns are all non-auto
+ auto_columns: vec![],
},
names: col_names,
},
26 changes: 22 additions & 4 deletions src/flow/src/adapter/worker.rs
@@ -206,10 +206,15 @@ impl WorkerHandle {

impl Drop for WorkerHandle {
fn drop(&mut self) {
- if let Err(err) = self.shutdown_blocking() {
- common_telemetry::error!("Fail to shutdown worker: {:?}", err)
+ let ret = futures::executor::block_on(async { self.shutdown().await });
+ if let Err(ret) = ret {
+ common_telemetry::error!(
+ ret;
+ "While dropping Worker Handle, failed to shutdown worker, worker might be in inconsistent state."
+ );
+ } else {
+ info!("Flow Worker shutdown due to Worker Handle dropped.")
}
- info!("Flow Worker shutdown due to Worker Handle dropped.")
}
}
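The new Drop implementation drives the async shutdown to completion on the dropping thread via futures::executor::block_on instead of a dedicated blocking helper, and only logs the shutdown message when the shutdown actually succeeded. A minimal sketch of this pattern, assuming a hypothetical oneshot-based shutdown signal (Handle and its fields are illustrative, not the PR's actual types):

use tokio::sync::oneshot;

struct Handle {
    // consumed on shutdown; `None` once the signal has been sent
    shutdown_tx: Option<oneshot::Sender<()>>,
}

impl Handle {
    async fn shutdown(&mut self) -> Result<(), &'static str> {
        self.shutdown_tx
            .take()
            .ok_or("already shut down")?
            .send(())
            .map_err(|_| "worker already gone")
    }
}

impl Drop for Handle {
    fn drop(&mut self) {
        // `drop` is synchronous, so the async shutdown is driven to completion
        // here with a lightweight executor; note this blocks the current thread,
        // which is only safe when not running on an async executor's worker thread.
        if let Err(err) = futures::executor::block_on(self.shutdown()) {
            eprintln!("failed to shut down worker: {err}");
        }
    }
}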

@@ -496,6 +501,19 @@ mod test {
use crate::plan::Plan;
use crate::repr::{RelationType, Row};

+ #[test]
+ fn drop_handle() {
+ let (tx, rx) = oneshot::channel();
+ let worker_thread_handle = std::thread::spawn(move || {
+ let (handle, mut worker) = create_worker();
+ tx.send(handle).unwrap();
+ worker.run();
+ });
+ let handle = rx.blocking_recv().unwrap();
+ drop(handle);
+ worker_thread_handle.join().unwrap();
+ }

#[tokio::test]
pub async fn test_simple_get_with_worker_and_handle() {
let (tx, rx) = oneshot::channel();
@@ -532,7 +550,7 @@ mod test {
tx.send((Row::empty(), 0, 0)).unwrap();
handle.run_available(0).await.unwrap();
assert_eq!(sink_rx.recv().await.unwrap().0, Row::empty());
- handle.shutdown().await.unwrap();
+ drop(handle);
worker_thread_handle.join().unwrap();
}
}
76 changes: 61 additions & 15 deletions src/flow/src/expr/scalar.rs
@@ -21,10 +21,10 @@ use bytes::BytesMut;
use common_error::ext::BoxedError;
use common_recordbatch::DfRecordBatch;
use datafusion_physical_expr::PhysicalExpr;
- use datatypes::arrow_array;
use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::Value;
+ use datatypes::{arrow_array, value};
use prost::Message;
use serde::{Deserialize, Serialize};
use snafu::{ensure, ResultExt};
@@ -155,8 +155,10 @@ pub enum ScalarExpr {
exprs: Vec<ScalarExpr>,
},
CallDf {
- // TODO(discord9): support shuffle
+ /// invariant: the input args set inside this [`DfScalarFunction`] is
+ /// always col(0) to col(n-1), where n is the length of `exprs`
df_scalar_fn: DfScalarFunction,
+ exprs: Vec<ScalarExpr>,
},
/// Conditionally evaluated expressions.
///
@@ -189,8 +191,27 @@ impl DfScalarFunction {
})
}

+ pub fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result<Self, Error> {
+ Ok(Self {
+ fn_impl: raw_fn.get_fn_impl()?,
+ df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
+ raw_fn,
+ })
+ }

+ /// Evaluate a list of expressions against the input values
+ fn eval_args(values: &[Value], exprs: &[ScalarExpr]) -> Result<Vec<Value>, EvalError> {
+ exprs
+ .iter()
+ .map(|expr| expr.eval(values))
+ .collect::<Result<_, _>>()
+ }

// TODO(discord9): add RecordBatch support
- pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
+ pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
+ // first evaluate `exprs` to construct the values fed to datafusion
+ let values: Vec<_> = Self::eval_args(values, exprs)?;

if values.is_empty() {
return InvalidArgumentSnafu {
reason: "values is empty".to_string(),
@@ -259,16 +280,18 @@ impl<'de> serde::de::Deserialize<'de> for DfScalarFunction {
D: serde::de::Deserializer<'de>,
{
let raw_fn = RawDfScalarFn::deserialize(deserializer)?;
- let fn_impl = raw_fn.get_fn_impl().map_err(serde::de::Error::custom)?;
- DfScalarFunction::new(raw_fn, fn_impl).map_err(serde::de::Error::custom)
+ DfScalarFunction::try_from_raw_fn(raw_fn).map_err(serde::de::Error::custom)
}
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RawDfScalarFn {
- f: bytes::BytesMut,
- input_schema: RelationDesc,
- extensions: FunctionExtensions,
+ /// The raw bytes of the encoded datafusion scalar function
+ pub(crate) f: bytes::BytesMut,
+ /// The input schema of the function
+ pub(crate) input_schema: RelationDesc,
+ /// Extensions contain the mapping from function reference to function name
+ pub(crate) extensions: FunctionExtensions,
}

impl RawDfScalarFn {
@@ -354,7 +377,7 @@ impl ScalarExpr {
Ok(ColumnType::new_nullable(func.signature().output))
}
ScalarExpr::If { then, .. } => then.typ(context),
- ScalarExpr::CallDf { df_scalar_fn } => {
+ ScalarExpr::CallDf { df_scalar_fn, .. } => {
let arrow_typ = df_scalar_fn
.fn_impl
// TODO(discord9): get scheme from args instead?
@@ -445,7 +468,10 @@ impl ScalarExpr {
}
.fail(),
},
- ScalarExpr::CallDf { df_scalar_fn } => df_scalar_fn.eval(values),
+ ScalarExpr::CallDf {
+ df_scalar_fn,
+ exprs,
+ } => df_scalar_fn.eval(values, exprs),
}
}

@@ -614,7 +640,15 @@ impl ScalarExpr {
f(then)?;
f(els)
}
- _ => Ok(()),
+ ScalarExpr::CallDf {
+ df_scalar_fn: _,
+ exprs,
+ } => {
+ for expr in exprs {
+ f(expr)?;
+ }
+ Ok(())
+ }
}
}

@@ -650,7 +684,15 @@ impl ScalarExpr {
f(then)?;
f(els)
}
- _ => Ok(()),
+ ScalarExpr::CallDf {
+ df_scalar_fn: _,
+ exprs,
+ } => {
+ for expr in exprs {
+ f(expr)?;
+ }
+ Ok(())
+ }
}
}
}
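The two hunks above fix the same visitor bug in both expression-tree walkers: the old catch-all _ => Ok(()) arm silently skipped the children of CallDf, so passes that walk the tree never saw the DataFusion function's argument expressions. A minimal sketch of the pitfall and the fix, using a toy expression type rather than the real ScalarExpr:

enum Expr {
    Column(usize),
    CallDf { exprs: Vec<Expr> },
}

fn visit_children<F: FnMut(&Expr)>(e: &Expr, f: &mut F) {
    match e {
        Expr::Column(_) => {}
        // with a `_ => {}` catch-all instead, these children would be skipped
        Expr::CallDf { exprs } => exprs.iter().for_each(|c| f(c)),
    }
}

fn main() {
    let e = Expr::CallDf {
        exprs: vec![Expr::Column(0), Expr::Column(1)],
    };
    let mut count = 0;
    visit_children(&e, &mut |_| count += 1);
    assert_eq!(count, 2); // both argument expressions are now visited
}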
@@ -852,11 +894,15 @@ mod test {
.unwrap();
let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]);
let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap();
- let fn_impl = raw_fn.get_fn_impl().unwrap();
- let df_func = DfScalarFunction::new(raw_fn, fn_impl).unwrap();
+ let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).unwrap();
let as_str = serde_json::to_string(&df_func).unwrap();
let from_str: DfScalarFunction = serde_json::from_str(&as_str).unwrap();
assert_eq!(df_func, from_str);
- assert_eq!(df_func.eval(&[Value::Null]).unwrap(), Value::Int64(1));
+ assert_eq!(
+ df_func
+ .eval(&[Value::Null], &[ScalarExpr::Column(0)])
+ .unwrap(),
+ Value::Int64(1)
+ );
}
}
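Taken together, the scalar.rs changes make CallDf carry its own argument expressions: eval first evaluates exprs against the incoming row (via eval_args), then hands the results to the wrapped DataFusion function, which by the stated invariant only ever references col(0) to col(n-1). A toy sketch of that two-step contract, with types deliberately simplified to integers and fn pointers (the real code works over Value and ScalarExpr):

fn eval_call_df(
    row: &[i64],
    exprs: &[fn(&[i64]) -> i64],
    df_fn: impl Fn(&[i64]) -> i64,
) -> i64 {
    // step 1: evaluate each argument expression against the full input row
    let args: Vec<i64> = exprs.iter().map(|e| e(row)).collect();
    // step 2: the wrapped function indexes its input as col(0)..col(n-1),
    // which is exactly the vector produced in step 1
    df_fn(&args)
}

fn main() {
    let exprs: [fn(&[i64]) -> i64; 2] = [|r| r[0] + 1, |r| r[1] * 2];
    // stand-in "DataFusion" function: abs(col(0)) + col(1)
    let out = eval_call_df(&[-5, 3], &exprs, |cols| cols[0].abs() + cols[1]);
    assert_eq!(out, 10); // args are [-4, 6]; abs(-4) + 6 == 10
}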
28 changes: 26 additions & 2 deletions src/flow/src/repr/relation.rs
@@ -21,7 +21,9 @@ use itertools::Itertools;
use serde::{Deserialize, Serialize};
use snafu::{ensure, OptionExt, ResultExt};

- use crate::adapter::error::{DatafusionSnafu, InvalidQuerySnafu, Result, UnexpectedSnafu};
+ use crate::adapter::error::{
+ DatafusionSnafu, InternalSnafu, InvalidQuerySnafu, Result, UnexpectedSnafu,
+ };
use crate::expr::{MapFilterProject, SafeMfpPlan, ScalarExpr};

/// a set of column indices that are "keys" for the collection.
@@ -93,13 +95,19 @@ pub struct RelationType {
///
/// A collection can contain multiple sets of keys, although it is common to
/// have either zero or one set of key indices.
#[serde(default)]
pub keys: Vec<Key>,
/// optionally indicate the column that is TIME INDEX
pub time_index: Option<usize>,
+ /// marks all the columns that are added automatically by flow but are not present in the original SQL
+ pub auto_columns: Vec<usize>,
}

impl RelationType {
+ pub fn with_autos(mut self, auto_cols: &[usize]) -> Self {
+ self.auto_columns = auto_cols.to_vec();
+ self
+ }

/// Tries to apply an mfp on the current types and returns a new RelationType
/// with the new types; it also tries to preserve key & time index information
/// if the old key & time index columns are preserved in the given mfp
@@ -155,10 +163,16 @@ impl RelationType {
let time_index = self
.time_index
.and_then(|old| old_to_new_col.get(&old).cloned());
+ let auto_columns = self
+ .auto_columns
+ .iter()
+ .filter_map(|old| old_to_new_col.get(old).cloned())
+ .collect_vec();
Ok(Self {
column_types: mfp_out_types,
keys,
time_index,
+ auto_columns,
})
}
/// Constructs a `RelationType` representing the relation with no columns and
@@ -175,6 +189,7 @@ impl RelationType {
column_types,
keys: Vec::new(),
time_index: None,
+ auto_columns: vec![],
}
}

@@ -340,6 +355,15 @@ pub struct RelationDesc {
}

impl RelationDesc {
+ pub fn len(&self) -> Result<usize> {
+ ensure!(
+ self.typ.column_types.len() == self.names.len(),
+ InternalSnafu {
+ reason: "Expect typ and names fields to be of the same length"
+ }
+ );
+ Ok(self.names.len())
+ }
pub fn to_df_schema(&self) -> Result<DFSchema> {
let fields: Vec<_> = self
.iter()
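The relation.rs changes thread the new auto_columns set through RelationType: with_autos marks columns that flow added itself, the constructor defaults the set to empty, apply_mfp remaps or drops the indices exactly as it already did for keys and the time index, and RelationDesc::len now checks the typ/names length invariant. A hedged usage sketch, assuming the flow crate's existing RelationType::new, ColumnType::new, and with_time_index builders:

use datatypes::prelude::ConcreteDataType;

// a two-column relation where column 1 is a flow-generated time index
// (e.g. an auto-added `update_at`) that never appeared in the user's SQL
let typ = RelationType::new(vec![
    ColumnType::new(ConcreteDataType::int64_datatype(), false),
    ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false),
])
.with_time_index(Some(1))
.with_autos(&[1]); // mark column 1 as auto-generated

assert_eq!(typ.auto_columns, vec![1]);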