Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(flow): impl ScalarExpr&Scalar Function #3283

Merged
merged 14 commits into from
Feb 21, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/flow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ common-time.workspace = true
datatypes.workspace = true
hydroflow = "0.5.0"
itertools.workspace = true
num-traits = "0.2"
serde.workspace = true
servers.workspace = true
session.workspace = true
Expand Down
3 changes: 3 additions & 0 deletions src/flow/src/expr/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,7 @@ pub enum EvalError {

#[snafu(display("Optimize error: {reason}"))]
Optimize { reason: String, location: Location },

#[snafu(display("Unsupported temporal filter: {reason}"))]
UnsupportedTemporalFilter { reason: String, location: Location },
}
296 changes: 289 additions & 7 deletions src/flow/src/expr/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@ use hydroflow::bincode::Error;
use serde::{Deserialize, Serialize};
use snafu::ResultExt;

use super::ScalarExpr;
use crate::expr::error::CastValueSnafu;
use crate::expr::InvalidArgumentSnafu;
// TODO(discord9): more function & eval
use crate::{
expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu},
repr::Row,
use crate::expr::error::{
CastValueSnafu, DivisionByZeroSnafu, EvalError, InternalSnafu, TryFromValueSnafu,
TypeMismatchSnafu,
};
use crate::expr::{InvalidArgumentSnafu, ScalarExpr};
use crate::repr::Row;

/// UnmaterializableFunc is a function that can't be evaluated independently,
/// and requires special handling
Expand All @@ -47,6 +45,66 @@ pub enum UnaryFunc {
StepTimestamp,
Cast(ConcreteDataType),
}

impl UnaryFunc {
    /// Evaluates `expr` against `values` and applies this unary function to the result.
    ///
    /// # Errors
    ///
    /// * Any error produced while evaluating `expr` is returned unchanged.
    /// * `Not`, `IsTrue` and `IsFalse` return a type-mismatch error unless the
    ///   argument is a `Value::Boolean`.
    /// * `StepTimestamp` returns a type-mismatch error unless the argument is a
    ///   `Value::DateTime`.
    /// * `Cast` returns a cast error (carrying the source and target types) when
    ///   the conversion fails.
    pub fn eval(&self, values: &[Value], expr: &ScalarExpr) -> Result<Value, EvalError> {
        let arg = expr.eval(values)?;

        // Shared boolean extraction for the logical variants.
        let as_bool = |value: &Value| -> Result<bool, EvalError> {
            match value {
                Value::Boolean(flag) => Ok(*flag),
                other => TypeMismatchSnafu {
                    expected: ConcreteDataType::boolean_datatype(),
                    actual: other.data_type(),
                }
                .fail(),
            }
        };

        match self {
            Self::Not => Ok(Value::from(!as_bool(&arg)?)),
            Self::IsNull => Ok(Value::from(arg.is_null())),
            Self::IsTrue => Ok(Value::from(as_bool(&arg)?)),
            Self::IsFalse => Ok(Value::from(!as_bool(&arg)?)),
            Self::StepTimestamp => match arg {
                // Advance the raw value reported by `val()` by one.
                Value::DateTime(datetime) => {
                    let stepped = DateTime::from(datetime.val() + 1);
                    Ok(Value::from(stepped))
                }
                other => TypeMismatchSnafu {
                    expected: ConcreteDataType::datetime_datatype(),
                    actual: other.data_type(),
                }
                .fail(),
            },
            Self::Cast(to) => {
                // Capture the source type first: `cast` consumes `arg`.
                let source_ty = arg.data_type();
                let converted = cast(arg, to).context(CastValueSnafu {
                    from: source_ty,
                    to: to.clone(),
                })?;
                Ok(converted)
            }
        }
    }
}

/// TODO(discord9): support more binary functions for more types
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)]
pub enum BinaryFunc {
Expand Down Expand Up @@ -96,8 +154,232 @@ pub enum BinaryFunc {
ModUInt64,
}

impl BinaryFunc {
    /// Evaluates both operand expressions against `values` (left first, then
    /// right) and applies this binary function to the two results.
    ///
    /// Comparison variants compare the two `Value`s directly and yield a
    /// boolean `Value`. Arithmetic variants convert both operands to the
    /// numeric type named by the variant and delegate to the matching helper.
    ///
    /// # Errors
    ///
    /// Returns any error raised while evaluating either operand, or by the
    /// arithmetic helper that the variant dispatches to.
    pub fn eval(
        &self,
        values: &[Value],
        expr1: &ScalarExpr,
        expr2: &ScalarExpr,
    ) -> Result<Value, EvalError> {
        // Both operands are always evaluated before dispatching.
        let lhs = expr1.eval(values)?;
        let rhs = expr2.eval(values)?;
        match self {
            // Comparisons
            Self::Eq => Ok(Value::from(lhs == rhs)),
            Self::NotEq => Ok(Value::from(lhs != rhs)),
            Self::Lt => Ok(Value::from(lhs < rhs)),
            Self::Lte => Ok(Value::from(lhs <= rhs)),
            Self::Gt => Ok(Value::from(lhs > rhs)),
            Self::Gte => Ok(Value::from(lhs >= rhs)),

            // Addition
            Self::AddInt16 => add::<i16>(lhs, rhs),
            Self::AddInt32 => add::<i32>(lhs, rhs),
            Self::AddInt64 => add::<i64>(lhs, rhs),
            Self::AddUInt16 => add::<u16>(lhs, rhs),
            Self::AddUInt32 => add::<u32>(lhs, rhs),
            Self::AddUInt64 => add::<u64>(lhs, rhs),
            Self::AddFloat32 => add::<f32>(lhs, rhs),
            Self::AddFloat64 => add::<f64>(lhs, rhs),

            // Subtraction
            Self::SubInt16 => sub::<i16>(lhs, rhs),
            Self::SubInt32 => sub::<i32>(lhs, rhs),
            Self::SubInt64 => sub::<i64>(lhs, rhs),
            Self::SubUInt16 => sub::<u16>(lhs, rhs),
            Self::SubUInt32 => sub::<u32>(lhs, rhs),
            Self::SubUInt64 => sub::<u64>(lhs, rhs),
            Self::SubFloat32 => sub::<f32>(lhs, rhs),
            Self::SubFloat64 => sub::<f64>(lhs, rhs),

            // Multiplication
            Self::MulInt16 => mul::<i16>(lhs, rhs),
            Self::MulInt32 => mul::<i32>(lhs, rhs),
            Self::MulInt64 => mul::<i64>(lhs, rhs),
            Self::MulUInt16 => mul::<u16>(lhs, rhs),
            Self::MulUInt32 => mul::<u32>(lhs, rhs),
            Self::MulUInt64 => mul::<u64>(lhs, rhs),
            Self::MulFloat32 => mul::<f32>(lhs, rhs),
            Self::MulFloat64 => mul::<f64>(lhs, rhs),

            // Division
            Self::DivInt16 => div::<i16>(lhs, rhs),
            Self::DivInt32 => div::<i32>(lhs, rhs),
            Self::DivInt64 => div::<i64>(lhs, rhs),
            Self::DivUInt16 => div::<u16>(lhs, rhs),
            Self::DivUInt32 => div::<u32>(lhs, rhs),
            Self::DivUInt64 => div::<u64>(lhs, rhs),
            Self::DivFloat32 => div::<f32>(lhs, rhs),
            Self::DivFloat64 => div::<f64>(lhs, rhs),

            // Remainder (integer variants only)
            Self::ModInt16 => rem::<i16>(lhs, rhs),
            Self::ModInt32 => rem::<i32>(lhs, rhs),
            Self::ModInt64 => rem::<i64>(lhs, rhs),
            Self::ModUInt16 => rem::<u16>(lhs, rhs),
            Self::ModUInt32 => rem::<u32>(lhs, rhs),
            Self::ModUInt64 => rem::<u64>(lhs, rhs),
        }
    }

    /// Returns the comparison operator obtained by swapping the operands,
    /// i.e. `a < b` becomes `b > a`; `Eq` and `NotEq` map to themselves.
    ///
    /// # Errors
    ///
    /// Returns an internal error when `self` is not one of the six
    /// comparison operators.
    pub fn reverse_compare(&self) -> Result<Self, EvalError> {
        match self {
            Self::Eq => Ok(Self::Eq),
            Self::NotEq => Ok(Self::NotEq),
            Self::Lt => Ok(Self::Gt),
            Self::Lte => Ok(Self::Gte),
            Self::Gt => Ok(Self::Lt),
            Self::Gte => Ok(Self::Lte),
            _ => InternalSnafu {
                reason: format!("Expect a comparison operator, found {:?}", self),
            }
            .fail(),
        }
    }
}

/// Functions that take an arbitrary number of argument expressions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)]
pub enum VariadicFunc {
/// Logical conjunction of all arguments; evaluated by the `and` helper in this file.
And,
/// Logical disjunction of all arguments; evaluated by the `or` helper in this file.
Or,
}

impl VariadicFunc {
pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
match self {
VariadicFunc::And => and(values, exprs),
VariadicFunc::Or => or(values, exprs),
}
}
}

/// Three-valued logical AND over `exprs`, each evaluated against `values`.
///
/// * Returns `false` as soon as any operand evaluates to `false`; the
///   remaining operands are not evaluated.
/// * Otherwise returns null if any operand evaluated to null.
/// * Otherwise returns `true` (including when `exprs` is empty).
///
/// # Errors
///
/// Returns the first evaluation error encountered, or an invalid-argument
/// error if an operand evaluates to something that is neither boolean nor null.
fn and(values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
    let mut saw_null = false;
    for operand in exprs {
        // `?` surfaces the first error encountered.
        match operand.eval(values)? {
            Value::Boolean(true) => {}
            // short-circuit
            Value::Boolean(false) => return Ok(Value::Boolean(false)),
            Value::Null => saw_null = true,
            other => {
                return InvalidArgumentSnafu {
                    reason: format!(
                        "`and()` only support boolean type, found value {:?} of type {:?}",
                        other,
                        other.data_type()
                    ),
                }
                .fail();
            }
        }
    }
    if saw_null {
        Ok(Value::Null)
    } else {
        Ok(Value::Boolean(true))
    }
}

/// Three-valued logical OR over `exprs`, each evaluated against `values`.
///
/// * Returns `true` as soon as any operand evaluates to `true`; the
///   remaining operands are not evaluated.
/// * Otherwise returns null if any operand evaluated to null.
/// * Otherwise returns `false` (including when `exprs` is empty).
///
/// # Errors
///
/// Returns the first evaluation error encountered, or an invalid-argument
/// error if an operand evaluates to something that is neither boolean nor null.
fn or(values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
    // If any is true, then return true. Else, if any is null, then return null. Else, return false.
    let mut saw_null = false;
    for operand in exprs {
        // `?` surfaces the first error encountered.
        match operand.eval(values)? {
            // short-circuit
            Value::Boolean(true) => return Ok(Value::Boolean(true)),
            Value::Boolean(false) => {}
            Value::Null => saw_null = true,
            other => {
                return InvalidArgumentSnafu {
                    reason: format!(
                        "`or()` only support boolean type, found value {:?} of type {:?}",
                        other,
                        other.data_type()
                    ),
                }
                .fail();
            }
        }
    }
    if saw_null {
        Ok(Value::Null)
    } else {
        Ok(Value::Boolean(false))
    }
}

fn add<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left + right))
}

fn sub<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left - right))
}

fn mul<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left * right))
}

fn div<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
<T as TryFrom<Value>>::Error: std::fmt::Debug,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
if right.is_zero() {
return Err(DivisionByZeroSnafu {}.build());
}
Ok(Value::from(left / right))
}

fn rem<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
<T as TryFrom<Value>>::Error: std::fmt::Debug,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left % right))
}

/// Smoke test for the arithmetic helpers on `i32` operands and for the
/// short-circuiting `and` / `or` helpers on a two-column boolean row.
#[test]
fn test_num_ops() {
    let lhs = Value::from(10);
    let rhs = Value::from(3);

    // 10 (op) 3 for every arithmetic helper.
    assert_eq!(
        add::<i32>(lhs.clone(), rhs.clone()).unwrap(),
        Value::from(13)
    );
    assert_eq!(
        sub::<i32>(lhs.clone(), rhs.clone()).unwrap(),
        Value::from(7)
    );
    assert_eq!(
        mul::<i32>(lhs.clone(), rhs.clone()).unwrap(),
        Value::from(30)
    );
    assert_eq!(
        div::<i32>(lhs.clone(), rhs.clone()).unwrap(),
        Value::from(3)
    );
    assert_eq!(rem::<i32>(lhs, rhs).unwrap(), Value::from(1));

    // Row is (true, false); each expression reads one column.
    let row = [Value::from(true), Value::from(false)];
    let operands = [ScalarExpr::Column(0), ScalarExpr::Column(1)];
    assert_eq!(and(&row, &operands).unwrap(), Value::from(false));
    assert_eq!(or(&row, &operands).unwrap(), Value::from(true));
}
Loading
Loading