diff --git a/Cargo.lock b/Cargo.lock index 450afc7dc02e..d3f3e2a38513 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3406,6 +3406,7 @@ dependencies = [ "datatypes", "hydroflow", "itertools 0.10.5", + "num-traits", "serde", "servers", "session", diff --git a/src/flow/Cargo.toml b/src/flow/Cargo.toml index f20aa5d07e4c..2b244314cb02 100644 --- a/src/flow/Cargo.toml +++ b/src/flow/Cargo.toml @@ -17,6 +17,7 @@ common-time.workspace = true datatypes.workspace = true hydroflow = "0.5.0" itertools.workspace = true +num-traits = "0.2" serde.workspace = true servers.workspace = true session.workspace = true diff --git a/src/flow/src/expr/error.rs b/src/flow/src/expr/error.rs index 0fd58ba1cf8f..233538fb6564 100644 --- a/src/flow/src/expr/error.rs +++ b/src/flow/src/expr/error.rs @@ -58,4 +58,7 @@ pub enum EvalError { #[snafu(display("Optimize error: {reason}"))] Optimize { reason: String, location: Location }, + + #[snafu(display("Unsupported temporal filter: {reason}"))] + UnsupportedTemporalFilter { reason: String, location: Location }, } diff --git a/src/flow/src/expr/func.rs b/src/flow/src/expr/func.rs index eed43f65a759..85a127f09a4d 100644 --- a/src/flow/src/expr/func.rs +++ b/src/flow/src/expr/func.rs @@ -21,14 +21,12 @@ use hydroflow::bincode::Error; use serde::{Deserialize, Serialize}; use snafu::ResultExt; -use super::ScalarExpr; -use crate::expr::error::CastValueSnafu; -use crate::expr::InvalidArgumentSnafu; -// TODO(discord9): more function & eval -use crate::{ - expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu}, - repr::Row, +use crate::expr::error::{ + CastValueSnafu, DivisionByZeroSnafu, EvalError, InternalSnafu, TryFromValueSnafu, + TypeMismatchSnafu, }; +use crate::expr::{InvalidArgumentSnafu, ScalarExpr}; +use crate::repr::Row; /// UnmaterializableFunc is a function that can't be eval independently, /// and require special handling @@ -47,6 +45,66 @@ pub enum UnaryFunc { StepTimestamp, Cast(ConcreteDataType), } + +impl UnaryFunc { + pub fn eval(&self, values: &[Value], expr: &ScalarExpr) -> Result { + let arg = expr.eval(values)?; + match self { + Self::Not => { + let bool = if let Value::Boolean(bool) = arg { + Ok(bool) + } else { + TypeMismatchSnafu { + expected: ConcreteDataType::boolean_datatype(), + actual: arg.data_type(), + } + .fail()? + }?; + Ok(Value::from(!bool)) + } + Self::IsNull => Ok(Value::from(arg.is_null())), + Self::IsTrue | Self::IsFalse => { + let bool = if let Value::Boolean(bool) = arg { + Ok(bool) + } else { + TypeMismatchSnafu { + expected: ConcreteDataType::boolean_datatype(), + actual: arg.data_type(), + } + .fail()? + }?; + if matches!(self, Self::IsTrue) { + Ok(Value::from(bool)) + } else { + Ok(Value::from(!bool)) + } + } + Self::StepTimestamp => { + if let Value::DateTime(datetime) = arg { + let datetime = DateTime::from(datetime.val() + 1); + Ok(Value::from(datetime)) + } else { + TypeMismatchSnafu { + expected: ConcreteDataType::datetime_datatype(), + actual: arg.data_type(), + } + .fail()? + } + } + Self::Cast(to) => { + let arg_ty = arg.data_type(); + let res = cast(arg, to).context({ + CastValueSnafu { + from: arg_ty, + to: to.clone(), + } + })?; + Ok(res) + } + } + } +} + /// TODO(discord9): support more binary functions for more types #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)] pub enum BinaryFunc { @@ -96,8 +154,232 @@ pub enum BinaryFunc { ModUInt64, } +impl BinaryFunc { + pub fn eval( + &self, + values: &[Value], + expr1: &ScalarExpr, + expr2: &ScalarExpr, + ) -> Result { + let left = expr1.eval(values)?; + let right = expr2.eval(values)?; + match self { + Self::Eq => Ok(Value::from(left == right)), + Self::NotEq => Ok(Value::from(left != right)), + Self::Lt => Ok(Value::from(left < right)), + Self::Lte => Ok(Value::from(left <= right)), + Self::Gt => Ok(Value::from(left > right)), + Self::Gte => Ok(Value::from(left >= right)), + + Self::AddInt16 => Ok(add::(left, right)?), + Self::AddInt32 => Ok(add::(left, right)?), + Self::AddInt64 => Ok(add::(left, right)?), + Self::AddUInt16 => Ok(add::(left, right)?), + Self::AddUInt32 => Ok(add::(left, right)?), + Self::AddUInt64 => Ok(add::(left, right)?), + Self::AddFloat32 => Ok(add::(left, right)?), + Self::AddFloat64 => Ok(add::(left, right)?), + + Self::SubInt16 => Ok(sub::(left, right)?), + Self::SubInt32 => Ok(sub::(left, right)?), + Self::SubInt64 => Ok(sub::(left, right)?), + Self::SubUInt16 => Ok(sub::(left, right)?), + Self::SubUInt32 => Ok(sub::(left, right)?), + Self::SubUInt64 => Ok(sub::(left, right)?), + Self::SubFloat32 => Ok(sub::(left, right)?), + Self::SubFloat64 => Ok(sub::(left, right)?), + + Self::MulInt16 => Ok(mul::(left, right)?), + Self::MulInt32 => Ok(mul::(left, right)?), + Self::MulInt64 => Ok(mul::(left, right)?), + Self::MulUInt16 => Ok(mul::(left, right)?), + Self::MulUInt32 => Ok(mul::(left, right)?), + Self::MulUInt64 => Ok(mul::(left, right)?), + Self::MulFloat32 => Ok(mul::(left, right)?), + Self::MulFloat64 => Ok(mul::(left, right)?), + + Self::DivInt16 => Ok(div::(left, right)?), + Self::DivInt32 => Ok(div::(left, right)?), + Self::DivInt64 => Ok(div::(left, right)?), + Self::DivUInt16 => Ok(div::(left, right)?), + Self::DivUInt32 => Ok(div::(left, right)?), + Self::DivUInt64 => Ok(div::(left, right)?), + Self::DivFloat32 => Ok(div::(left, right)?), + Self::DivFloat64 => Ok(div::(left, right)?), + + Self::ModInt16 => Ok(rem::(left, right)?), + Self::ModInt32 => Ok(rem::(left, right)?), + Self::ModInt64 => Ok(rem::(left, right)?), + Self::ModUInt16 => Ok(rem::(left, right)?), + Self::ModUInt32 => Ok(rem::(left, right)?), + Self::ModUInt64 => Ok(rem::(left, right)?), + } + } + + /// Reverse the comparison operator, i.e. `a < b` becomes `b > a`, + /// equal and not equal are unchanged. + pub fn reverse_compare(&self) -> Result { + let ret = match &self { + BinaryFunc::Eq => BinaryFunc::Eq, + BinaryFunc::NotEq => BinaryFunc::NotEq, + BinaryFunc::Lt => BinaryFunc::Gt, + BinaryFunc::Lte => BinaryFunc::Gte, + BinaryFunc::Gt => BinaryFunc::Lt, + BinaryFunc::Gte => BinaryFunc::Lte, + _ => { + return InternalSnafu { + reason: format!("Expect a comparison operator, found {:?}", self), + } + .fail(); + } + }; + Ok(ret) + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)] pub enum VariadicFunc { And, Or, } + +impl VariadicFunc { + pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result { + match self { + VariadicFunc::And => and(values, exprs), + VariadicFunc::Or => or(values, exprs), + } + } +} + +fn and(values: &[Value], exprs: &[ScalarExpr]) -> Result { + // If any is false, then return false. Else, if any is null, then return null. Else, return true. + let mut null = false; + for expr in exprs { + match expr.eval(values) { + Ok(Value::Boolean(true)) => {} + Ok(Value::Boolean(false)) => return Ok(Value::Boolean(false)), // short-circuit + Ok(Value::Null) => null = true, + Err(this_err) => { + return Err(this_err); + } // retain first error encountered + Ok(x) => InvalidArgumentSnafu { + reason: format!( + "`and()` only support boolean type, found value {:?} of type {:?}", + x, + x.data_type() + ), + } + .fail()?, + } + } + match null { + true => Ok(Value::Null), + false => Ok(Value::Boolean(true)), + } +} + +fn or(values: &[Value], exprs: &[ScalarExpr]) -> Result { + // If any is false, then return false. Else, if any is null, then return null. Else, return true. + let mut null = false; + for expr in exprs { + match expr.eval(values) { + Ok(Value::Boolean(true)) => return Ok(Value::Boolean(true)), // short-circuit + Ok(Value::Boolean(false)) => {} + Ok(Value::Null) => null = true, + Err(this_err) => { + return Err(this_err); + } // retain first error encountered + Ok(x) => InvalidArgumentSnafu { + reason: format!( + "`or()` only support boolean type, found value {:?} of type {:?}", + x, + x.data_type() + ), + } + .fail()?, + } + } + match null { + true => Ok(Value::Null), + false => Ok(Value::Boolean(false)), + } +} + +fn add(left: Value, right: Value) -> Result +where + T: TryFrom + num_traits::Num, + Value: From, +{ + let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + Ok(Value::from(left + right)) +} + +fn sub(left: Value, right: Value) -> Result +where + T: TryFrom + num_traits::Num, + Value: From, +{ + let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + Ok(Value::from(left - right)) +} + +fn mul(left: Value, right: Value) -> Result +where + T: TryFrom + num_traits::Num, + Value: From, +{ + let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + Ok(Value::from(left * right)) +} + +fn div(left: Value, right: Value) -> Result +where + T: TryFrom + num_traits::Num, + >::Error: std::fmt::Debug, + Value: From, +{ + let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + if right.is_zero() { + return Err(DivisionByZeroSnafu {}.build()); + } + Ok(Value::from(left / right)) +} + +fn rem(left: Value, right: Value) -> Result +where + T: TryFrom + num_traits::Num, + >::Error: std::fmt::Debug, + Value: From, +{ + let left = T::try_from(left).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + let right = T::try_from(right).map_err(|e| TryFromValueSnafu { msg: e.to_string() }.build())?; + Ok(Value::from(left % right)) +} + +#[test] +fn test_num_ops() { + let left = Value::from(10); + let right = Value::from(3); + let res = add::(left.clone(), right.clone()).unwrap(); + assert_eq!(res, Value::from(13)); + let res = sub::(left.clone(), right.clone()).unwrap(); + assert_eq!(res, Value::from(7)); + let res = mul::(left.clone(), right.clone()).unwrap(); + assert_eq!(res, Value::from(30)); + let res = div::(left.clone(), right.clone()).unwrap(); + assert_eq!(res, Value::from(3)); + let res = rem::(left.clone(), right.clone()).unwrap(); + assert_eq!(res, Value::from(1)); + + let values = vec![Value::from(true), Value::from(false)]; + let exprs = vec![ScalarExpr::Column(0), ScalarExpr::Column(1)]; + let res = and(&values, &exprs).unwrap(); + assert_eq!(res, Value::from(false)); + let res = or(&values, &exprs).unwrap(); + assert_eq!(res, Value::from(true)); +} diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 3c1d745a8616..fa03bb9f1912 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -18,7 +18,9 @@ use datatypes::prelude::ConcreteDataType; use datatypes::value::Value; use serde::{Deserialize, Serialize}; -use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu}; +use crate::expr::error::{ + EvalError, InvalidArgumentSnafu, OptimizeSnafu, UnsupportedTemporalFilterSnafu, +}; use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}; /// A scalar expression, which can be evaluated to a value. @@ -59,3 +61,338 @@ pub enum ScalarExpr { els: Box, }, } + +impl ScalarExpr { + pub fn call_unary(self, func: UnaryFunc) -> Self { + ScalarExpr::CallUnary { + func, + expr: Box::new(self), + } + } + + pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self { + ScalarExpr::CallBinary { + func, + expr1: Box::new(self), + expr2: Box::new(other), + } + } + + pub fn eval(&self, values: &[Value]) -> Result { + match self { + ScalarExpr::Column(index) => Ok(values[*index].clone()), + ScalarExpr::Literal(row_res, _ty) => Ok(row_res.clone()), + ScalarExpr::CallUnmaterializable(f) => OptimizeSnafu { + reason: "Can't eval unmaterializable function".to_string(), + } + .fail(), + ScalarExpr::CallUnary { func, expr } => func.eval(values, expr), + ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval(values, expr1, expr2), + ScalarExpr::CallVariadic { func, exprs } => func.eval(values, exprs), + ScalarExpr::If { cond, then, els } => match cond.eval(values) { + Ok(Value::Boolean(true)) => then.eval(values), + Ok(Value::Boolean(false)) => els.eval(values), + _ => InvalidArgumentSnafu { + reason: "if condition must be boolean".to_string(), + } + .fail(), + }, + } + } + + /// Rewrites column indices with their value in `permutation`. + /// + /// This method is applicable even when `permutation` is not a + /// strict permutation, and it only needs to have entries for + /// each column referenced in `self`. + pub fn permute(&mut self, permutation: &[usize]) { + self.visit_mut_post_nolimit(&mut |e| { + if let ScalarExpr::Column(old_i) = e { + *old_i = permutation[*old_i]; + } + }); + } + + /// Rewrites column indices with their value in `permutation`. + /// + /// This method is applicable even when `permutation` is not a + /// strict permutation, and it only needs to have entries for + /// each column referenced in `self`. + pub fn permute_map(&mut self, permutation: &BTreeMap) { + self.visit_mut_post_nolimit(&mut |e| { + if let ScalarExpr::Column(old_i) = e { + *old_i = permutation[old_i]; + } + }); + } + + /// Returns the set of columns that are referenced by `self`. + pub fn get_all_ref_columns(&self) -> BTreeSet { + let mut support = BTreeSet::new(); + self.visit_post_nolimit(&mut |e| { + if let ScalarExpr::Column(i) = e { + support.insert(*i); + } + }); + support + } + + pub fn as_literal(&self) -> Option { + if let ScalarExpr::Literal(lit, _column_type) = self { + Some(lit.clone()) + } else { + None + } + } + + pub fn is_literal(&self) -> bool { + matches!(self, ScalarExpr::Literal(..)) + } + + pub fn is_literal_true(&self) -> bool { + Some(Value::Boolean(true)) == self.as_literal() + } + + pub fn is_literal_false(&self) -> bool { + Some(Value::Boolean(false)) == self.as_literal() + } + + pub fn is_literal_null(&self) -> bool { + Some(Value::Null) == self.as_literal() + } + + pub fn literal_null() -> Self { + ScalarExpr::Literal(Value::Null, ConcreteDataType::null_datatype()) + } + + pub fn literal(res: Value, typ: ConcreteDataType) -> Self { + ScalarExpr::Literal(res, typ) + } + + pub fn literal_false() -> Self { + ScalarExpr::Literal(Value::Boolean(false), ConcreteDataType::boolean_datatype()) + } + + pub fn literal_true() -> Self { + ScalarExpr::Literal(Value::Boolean(true), ConcreteDataType::boolean_datatype()) + } +} + +impl ScalarExpr { + /// visit post-order without stack call limit, but may cause stack overflow + fn visit_post_nolimit(&self, f: &mut F) + where + F: FnMut(&Self), + { + self.visit_children(|e| e.visit_post_nolimit(f)); + f(self); + } + + fn visit_children(&self, mut f: F) + where + F: FnMut(&Self), + { + match self { + ScalarExpr::Column(_) + | ScalarExpr::Literal(_, _) + | ScalarExpr::CallUnmaterializable(_) => (), + ScalarExpr::CallUnary { expr, .. } => f(expr), + ScalarExpr::CallBinary { expr1, expr2, .. } => { + f(expr1); + f(expr2); + } + ScalarExpr::CallVariadic { exprs, .. } => { + for expr in exprs { + f(expr); + } + } + ScalarExpr::If { cond, then, els } => { + f(cond); + f(then); + f(els); + } + } + } + + fn visit_mut_post_nolimit(&mut self, f: &mut F) + where + F: FnMut(&mut Self), + { + self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f)); + f(self); + } + + fn visit_mut_children(&mut self, mut f: F) + where + F: FnMut(&mut Self), + { + match self { + ScalarExpr::Column(_) + | ScalarExpr::Literal(_, _) + | ScalarExpr::CallUnmaterializable(_) => (), + ScalarExpr::CallUnary { expr, .. } => f(expr), + ScalarExpr::CallBinary { expr1, expr2, .. } => { + f(expr1); + f(expr2); + } + ScalarExpr::CallVariadic { exprs, .. } => { + for expr in exprs { + f(expr); + } + } + ScalarExpr::If { cond, then, els } => { + f(cond); + f(then); + f(els); + } + } + } +} + +impl ScalarExpr { + /// if expr contains function `Now` + pub fn contains_temporal(&self) -> bool { + let mut contains = false; + self.visit_post_nolimit(&mut |e| { + if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now) = e { + contains = true; + } + }); + contains + } + + /// extract lower or upper bound of `Now` for expr, where `lower bound <= expr < upper bound` + /// + /// returned bool indicates whether the bound is upper bound: + /// + /// false for lower bound, true for upper bound + /// TODO(discord9): allow simple transform like `now() + a < b` to `now() < b - a` + pub fn extract_bound(&self) -> Result<(Option, Option), EvalError> { + let unsupported_err = |msg: &str| { + UnsupportedTemporalFilterSnafu { + reason: msg.to_string(), + } + .fail() + }; + + let Self::CallBinary { + mut func, + mut expr1, + mut expr2, + } = self.clone() + else { + return unsupported_err("Not a binary expression"); + }; + + // TODO: support simple transform like `now() + a < b` to `now() < b - a` + + let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now); + let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now); + + if !(expr1_is_now ^ expr2_is_now) { + return unsupported_err("None of the sides of the comparison is `now()`"); + } + + if expr2_is_now { + std::mem::swap(&mut expr1, &mut expr2); + func = BinaryFunc::reverse_compare(&func)?; + } + + let step = |expr: ScalarExpr| expr.call_unary(UnaryFunc::StepTimestamp); + match func { + // now == expr2 -> now <= expr2 && now < expr2 + 1 + BinaryFunc::Eq => Ok((Some(*expr2.clone()), Some(step(*expr2)))), + // now < expr2 -> now < expr2 + BinaryFunc::Lt => Ok((None, Some(*expr2))), + // now <= expr2 -> now < expr2 + 1 + BinaryFunc::Lte => Ok((None, Some(step(*expr2)))), + // now > expr2 -> now >= expr2 + 1 + BinaryFunc::Gt => Ok((Some(step(*expr2)), None)), + // now >= expr2 -> now >= expr2 + BinaryFunc::Gte => Ok((Some(*expr2), None)), + _ => unreachable!("Already checked"), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_extract_bound() { + let test_list: [(ScalarExpr, Result<_, EvalError>); 5] = [ + // col(0) == now + ( + ScalarExpr::CallBinary { + func: BinaryFunc::Eq, + expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)), + expr2: Box::new(ScalarExpr::Column(0)), + }, + Ok(( + Some(ScalarExpr::Column(0)), + Some(ScalarExpr::CallUnary { + func: UnaryFunc::StepTimestamp, + expr: Box::new(ScalarExpr::Column(0)), + }), + )), + ), + // now < col(0) + ( + ScalarExpr::CallBinary { + func: BinaryFunc::Lt, + expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)), + expr2: Box::new(ScalarExpr::Column(0)), + }, + Ok((None, Some(ScalarExpr::Column(0)))), + ), + // now <= col(0) + ( + ScalarExpr::CallBinary { + func: BinaryFunc::Lte, + expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)), + expr2: Box::new(ScalarExpr::Column(0)), + }, + Ok(( + None, + Some(ScalarExpr::CallUnary { + func: UnaryFunc::StepTimestamp, + expr: Box::new(ScalarExpr::Column(0)), + }), + )), + ), + // now > col(0) -> now >= col(0) + 1 + ( + ScalarExpr::CallBinary { + func: BinaryFunc::Gt, + expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)), + expr2: Box::new(ScalarExpr::Column(0)), + }, + Ok(( + Some(ScalarExpr::CallUnary { + func: UnaryFunc::StepTimestamp, + expr: Box::new(ScalarExpr::Column(0)), + }), + None, + )), + ), + // now >= col(0) + ( + ScalarExpr::CallBinary { + func: BinaryFunc::Gte, + expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)), + expr2: Box::new(ScalarExpr::Column(0)), + }, + Ok((Some(ScalarExpr::Column(0)), None)), + ), + ]; + for (expr, expected) in test_list.into_iter() { + let actual = expr.extract_bound(); + // EvalError is not Eq, so we need to compare the error message + match (actual, expected) { + (Ok(l), Ok(r)) => assert_eq!(l, r), + (Err(l), Err(r)) => assert!(matches!(l, r)), + (l, r) => panic!("expected: {:?}, actual: {:?}", r, l), + } + } + } +}