Skip to content

Commit

Permalink
feat(flow): impl ScalarExpr&Scalar Function (#3283)
Browse files Browse the repository at this point in the history
* feat: impl for ScalarExpr

* feat: plain functions

* refactor: simpler trait bound&tests

* chore: remove unused imports

* chore: fmt

* refactor: early ret on first error

* refactor: remove redundant match arm

* chore: per review

* doc: `support` fn

* chore: per review more

* chore: more per review

* fix: extract_bound

* chore: per review

* refactor: reduce nest
  • Loading branch information
discord9 authored Feb 21, 2024
1 parent 7c88d72 commit 860b1e9
Show file tree
Hide file tree
Showing 5 changed files with 632 additions and 8 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/flow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ common-time.workspace = true
datatypes.workspace = true
hydroflow = "0.5.0"
itertools.workspace = true
num-traits = "0.2"
serde.workspace = true
servers.workspace = true
session.workspace = true
Expand Down
3 changes: 3 additions & 0 deletions src/flow/src/expr/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,7 @@ pub enum EvalError {

#[snafu(display("Optimize error: {reason}"))]
Optimize { reason: String, location: Location },

#[snafu(display("Unsupported temporal filter: {reason}"))]
UnsupportedTemporalFilter { reason: String, location: Location },
}
296 changes: 289 additions & 7 deletions src/flow/src/expr/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@ use hydroflow::bincode::Error;
use serde::{Deserialize, Serialize};
use snafu::ResultExt;

use super::ScalarExpr;
use crate::expr::error::CastValueSnafu;
use crate::expr::InvalidArgumentSnafu;
// TODO(discord9): more function & eval
use crate::{
expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu},
repr::Row,
use crate::expr::error::{
CastValueSnafu, DivisionByZeroSnafu, EvalError, InternalSnafu, TryFromValueSnafu,
TypeMismatchSnafu,
};
use crate::expr::{InvalidArgumentSnafu, ScalarExpr};
use crate::repr::Row;

/// UnmaterializableFunc is a function that can't be evaluated independently,
/// and requires special handling
Expand All @@ -47,6 +45,66 @@ pub enum UnaryFunc {
StepTimestamp,
Cast(ConcreteDataType),
}

impl UnaryFunc {
    /// Evaluates `expr` against `values` and applies this unary function to the result.
    ///
    /// Per variant:
    /// - `Not`: negates a `Value::Boolean`.
    /// - `IsNull`: reports whether the argument is null (via `Value::is_null`).
    /// - `IsTrue` / `IsFalse`: returns the boolean itself / its negation.
    /// - `StepTimestamp`: adds one to the raw value of a `Value::DateTime`.
    /// - `Cast(to)`: converts the argument to the `to` data type using `cast`.
    ///
    /// # Errors
    ///
    /// - Any error produced while evaluating `expr` is returned unchanged.
    /// - A type-mismatch error is returned when `Not`, `IsTrue` or `IsFalse` receive
    ///   anything other than `Value::Boolean`, or when `StepTimestamp` receives anything
    ///   other than `Value::DateTime`.
    /// - A cast failure is wrapped with the source and target data types.
    pub fn eval(&self, values: &[Value], expr: &ScalarExpr) -> Result<Value, EvalError> {
        let arg = expr.eval(values)?;
        match self {
            Self::Not => match arg {
                Value::Boolean(flag) => Ok(Value::from(!flag)),
                other => TypeMismatchSnafu {
                    expected: ConcreteDataType::boolean_datatype(),
                    actual: other.data_type(),
                }
                .fail(),
            },
            Self::IsNull => Ok(Value::from(arg.is_null())),
            Self::IsTrue | Self::IsFalse => match arg {
                // `IsTrue` passes the flag through, `IsFalse` inverts it.
                Value::Boolean(flag) => {
                    let wants_true = matches!(self, Self::IsTrue);
                    Ok(Value::from(if wants_true { flag } else { !flag }))
                }
                other => TypeMismatchSnafu {
                    expected: ConcreteDataType::boolean_datatype(),
                    actual: other.data_type(),
                }
                .fail(),
            },
            Self::StepTimestamp => match arg {
                Value::DateTime(current) => {
                    let stepped = DateTime::from(current.val() + 1);
                    Ok(Value::from(stepped))
                }
                other => TypeMismatchSnafu {
                    expected: ConcreteDataType::datetime_datatype(),
                    actual: other.data_type(),
                }
                .fail(),
            },
            Self::Cast(to) => {
                // Capture the source type first: `cast` consumes the argument.
                let from = arg.data_type();
                cast(arg, to).context(CastValueSnafu {
                    from,
                    to: to.clone(),
                })
            }
        }
    }
}

/// TODO(discord9): support more binary functions for more types
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)]
pub enum BinaryFunc {
Expand Down Expand Up @@ -96,8 +154,232 @@ pub enum BinaryFunc {
ModUInt64,
}

impl BinaryFunc {
    /// Evaluates both operand expressions against `values` and applies this binary function.
    ///
    /// Both operands are always evaluated (left first, then right) before the operator is
    /// dispatched. Comparison variants compare the two `Value`s directly; arithmetic
    /// variants delegate to the typed helpers `add`, `sub`, `mul`, `div` and `rem`, with
    /// the numeric type selected by the variant name.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while evaluating an operand, or whatever error the
    /// selected arithmetic helper reports.
    pub fn eval(
        &self,
        values: &[Value],
        expr1: &ScalarExpr,
        expr2: &ScalarExpr,
    ) -> Result<Value, EvalError> {
        let lhs = expr1.eval(values)?;
        let rhs = expr2.eval(values)?;
        match self {
            // Comparisons
            Self::Eq => Ok(Value::from(lhs == rhs)),
            Self::NotEq => Ok(Value::from(lhs != rhs)),
            Self::Lt => Ok(Value::from(lhs < rhs)),
            Self::Lte => Ok(Value::from(lhs <= rhs)),
            Self::Gt => Ok(Value::from(lhs > rhs)),
            Self::Gte => Ok(Value::from(lhs >= rhs)),

            // Addition
            Self::AddInt16 => add::<i16>(lhs, rhs),
            Self::AddInt32 => add::<i32>(lhs, rhs),
            Self::AddInt64 => add::<i64>(lhs, rhs),
            Self::AddUInt16 => add::<u16>(lhs, rhs),
            Self::AddUInt32 => add::<u32>(lhs, rhs),
            Self::AddUInt64 => add::<u64>(lhs, rhs),
            Self::AddFloat32 => add::<f32>(lhs, rhs),
            Self::AddFloat64 => add::<f64>(lhs, rhs),

            // Subtraction
            Self::SubInt16 => sub::<i16>(lhs, rhs),
            Self::SubInt32 => sub::<i32>(lhs, rhs),
            Self::SubInt64 => sub::<i64>(lhs, rhs),
            Self::SubUInt16 => sub::<u16>(lhs, rhs),
            Self::SubUInt32 => sub::<u32>(lhs, rhs),
            Self::SubUInt64 => sub::<u64>(lhs, rhs),
            Self::SubFloat32 => sub::<f32>(lhs, rhs),
            Self::SubFloat64 => sub::<f64>(lhs, rhs),

            // Multiplication
            Self::MulInt16 => mul::<i16>(lhs, rhs),
            Self::MulInt32 => mul::<i32>(lhs, rhs),
            Self::MulInt64 => mul::<i64>(lhs, rhs),
            Self::MulUInt16 => mul::<u16>(lhs, rhs),
            Self::MulUInt32 => mul::<u32>(lhs, rhs),
            Self::MulUInt64 => mul::<u64>(lhs, rhs),
            Self::MulFloat32 => mul::<f32>(lhs, rhs),
            Self::MulFloat64 => mul::<f64>(lhs, rhs),

            // Division
            Self::DivInt16 => div::<i16>(lhs, rhs),
            Self::DivInt32 => div::<i32>(lhs, rhs),
            Self::DivInt64 => div::<i64>(lhs, rhs),
            Self::DivUInt16 => div::<u16>(lhs, rhs),
            Self::DivUInt32 => div::<u32>(lhs, rhs),
            Self::DivUInt64 => div::<u64>(lhs, rhs),
            Self::DivFloat32 => div::<f32>(lhs, rhs),
            Self::DivFloat64 => div::<f64>(lhs, rhs),

            // Remainder
            Self::ModInt16 => rem::<i16>(lhs, rhs),
            Self::ModInt32 => rem::<i32>(lhs, rhs),
            Self::ModInt64 => rem::<i64>(lhs, rhs),
            Self::ModUInt16 => rem::<u16>(lhs, rhs),
            Self::ModUInt32 => rem::<u32>(lhs, rhs),
            Self::ModUInt64 => rem::<u64>(lhs, rhs),
        }
    }

    /// Returns the comparison operator that yields the same result once the two operands
    /// are swapped: `a < b` is equivalent to `b > a`, `a <= b` to `b >= a`, and so on.
    /// `Eq` and `NotEq` are symmetric and map to themselves.
    ///
    /// # Errors
    ///
    /// Returns an internal error when `self` is not one of the six comparison operators.
    pub fn reverse_compare(&self) -> Result<Self, EvalError> {
        match &self {
            Self::Eq => Ok(Self::Eq),
            Self::NotEq => Ok(Self::NotEq),
            Self::Lt => Ok(Self::Gt),
            Self::Lte => Ok(Self::Gte),
            Self::Gt => Ok(Self::Lt),
            Self::Gte => Ok(Self::Lte),
            _ => InternalSnafu {
                reason: format!("Expect a comparison operator, found {:?}", self),
            }
            .fail(),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)]
pub enum VariadicFunc {
And,
Or,
}

impl VariadicFunc {
pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
match self {
VariadicFunc::And => and(values, exprs),
VariadicFunc::Or => or(values, exprs),
}
}
}

/// Logical AND over `exprs`, each evaluated against `values` in order.
///
/// - Returns `false` as soon as an expression evaluates to `false`; the remaining
///   expressions are not evaluated.
/// - Otherwise returns null if at least one expression evaluated to null.
/// - Otherwise (including for an empty `exprs`) returns `true`.
///
/// # Errors
///
/// Returns the first evaluation error encountered, or an invalid-argument error for the
/// first result that is neither a boolean nor null.
fn and(values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
    let mut saw_null = false;
    for expr in exprs {
        // `?` surfaces the first evaluation error immediately.
        match expr.eval(values)? {
            Value::Boolean(true) => {}
            // A single `false` decides the outcome.
            Value::Boolean(false) => return Ok(Value::Boolean(false)),
            Value::Null => saw_null = true,
            x => {
                return InvalidArgumentSnafu {
                    reason: format!(
                        "`and()` only support boolean type, found value {:?} of type {:?}",
                        x,
                        x.data_type()
                    ),
                }
                .fail();
            }
        }
    }
    if saw_null {
        Ok(Value::Null)
    } else {
        Ok(Value::Boolean(true))
    }
}

/// Logical OR over `exprs`, each evaluated against `values` in order.
///
/// - Returns `true` as soon as an expression evaluates to `true`; the remaining
///   expressions are not evaluated.
/// - Otherwise returns null if at least one expression evaluated to null.
/// - Otherwise (including for an empty `exprs`) returns `false`.
///
/// # Errors
///
/// Returns the first evaluation error encountered, or an invalid-argument error for the
/// first result that is neither a boolean nor null.
fn or(values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
    let mut saw_null = false;
    for expr in exprs {
        // `?` surfaces the first evaluation error immediately.
        match expr.eval(values)? {
            // A single `true` decides the outcome.
            Value::Boolean(true) => return Ok(Value::Boolean(true)),
            Value::Boolean(false) => {}
            Value::Null => saw_null = true,
            x => {
                return InvalidArgumentSnafu {
                    reason: format!(
                        "`or()` only support boolean type, found value {:?} of type {:?}",
                        x,
                        x.data_type()
                    ),
                }
                .fail();
            }
        }
    }
    if saw_null {
        Ok(Value::Null)
    } else {
        Ok(Value::Boolean(false))
    }
}

fn add<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left + right))
}

fn sub<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left - right))
}

fn mul<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left * right))
}

fn div<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
<T as TryFrom<Value>>::Error: std::fmt::Debug,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
if right.is_zero() {
return Err(DivisionByZeroSnafu {}.build());
}
Ok(Value::from(left / right))
}

fn rem<T>(left: Value, right: Value) -> Result<Value, EvalError>
where
T: TryFrom<Value, Error = datatypes::Error> + num_traits::Num,
<T as TryFrom<Value>>::Error: std::fmt::Debug,
Value: From<T>,
{
let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?;
Ok(Value::from(left % right))
}

/// Smoke test for the typed arithmetic helpers on `i32` operands and for the
/// short-circuiting `and` / `or` helpers on a `[true, false]` row.
#[test]
fn test_num_ops() {
    let (lhs, rhs) = (Value::from(10), Value::from(3));

    assert_eq!(
        add::<i32>(lhs.clone(), rhs.clone()).unwrap(),
        Value::from(13)
    );
    assert_eq!(
        sub::<i32>(lhs.clone(), rhs.clone()).unwrap(),
        Value::from(7)
    );
    assert_eq!(
        mul::<i32>(lhs.clone(), rhs.clone()).unwrap(),
        Value::from(30)
    );
    assert_eq!(
        div::<i32>(lhs.clone(), rhs.clone()).unwrap(),
        Value::from(3)
    );
    assert_eq!(
        rem::<i32>(lhs.clone(), rhs.clone()).unwrap(),
        Value::from(1)
    );

    // Column 0 is `true`, column 1 is `false`.
    let row = vec![Value::from(true), Value::from(false)];
    let columns = vec![ScalarExpr::Column(0), ScalarExpr::Column(1)];
    assert_eq!(and(&row, &columns).unwrap(), Value::from(false));
    assert_eq!(or(&row, &columns).unwrap(), Value::from(true));
}
Loading

0 comments on commit 860b1e9

Please sign in to comment.