diff --git a/Cargo.lock b/Cargo.lock index 3ec7ea27d9c0..4ae09096a9ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3474,10 +3474,12 @@ dependencies = [ "common-telemetry", "common-time", "datatypes", + "enum_dispatch", "hydroflow", "itertools 0.10.5", "num-traits", "serde", + "serde_json", "servers", "session", "snafu", diff --git a/src/flow/Cargo.toml b/src/flow/Cargo.toml index 0dc614a6c06c..cc67850eac86 100644 --- a/src/flow/Cargo.toml +++ b/src/flow/Cargo.toml @@ -18,6 +18,7 @@ common-query.workspace = true common-telemetry.workspace = true common-time.workspace = true datatypes.workspace = true +enum_dispatch = "0.3" hydroflow = "0.5.0" itertools.workspace = true num-traits = "0.2" @@ -27,3 +28,6 @@ session.workspace = true snafu.workspace = true tokio.workspace = true tonic.workspace = true + +[dev-dependencies] +serde_json = "1.0" diff --git a/src/flow/src/expr/error.rs b/src/flow/src/expr/error.rs index 233538fb6564..9de189231670 100644 --- a/src/flow/src/expr/error.rs +++ b/src/flow/src/expr/error.rs @@ -61,4 +61,7 @@ pub enum EvalError { #[snafu(display("Unsupported temporal filter: {reason}"))] UnsupportedTemporalFilter { reason: String, location: Location }, + + #[snafu(display("Overflowed during evaluation"))] + Overflow { location: Location }, } diff --git a/src/flow/src/expr/relation/accum.rs b/src/flow/src/expr/relation/accum.rs index e2b136e849d6..06df89eb8ee9 100644 --- a/src/flow/src/expr/relation/accum.rs +++ b/src/flow/src/expr/relation/accum.rs @@ -14,7 +14,10 @@ //! Accumulators for aggregate functions that's is accumulatable. i.e. sum/count //! -//! Currently support sum, count, any, all +//! Accumulator will only be restore from row and being updated every time dataflow need process a new batch of rows. +//! So the overhead is acceptable. +//! +//! Currently support sum, count, any, all and min/max(with one caveat that min/max can't support delete with aggregate). use std::fmt::Display; @@ -22,13 +25,506 @@ use common_decimal::Decimal128; use common_time::{Date, DateTime}; use datatypes::data_type::ConcreteDataType; use datatypes::value::{OrderedF32, OrderedF64, OrderedFloat, Value}; +use enum_dispatch::enum_dispatch; use hydroflow::futures::stream::Concat; use serde::{Deserialize, Serialize}; +use snafu::ensure; -use crate::expr::error::{InternalSnafu, TryFromValueSnafu, TypeMismatchSnafu}; +use crate::expr::error::{InternalSnafu, OverflowSnafu, TryFromValueSnafu, TypeMismatchSnafu}; +use crate::expr::relation::func::GenericFn; use crate::expr::{AggregateFunc, EvalError}; use crate::repr::Diff; +/// Accumulates values for the various types of accumulable aggregations. +#[enum_dispatch] +pub trait Accumulator: Sized { + fn into_state(self) -> Vec; + fn update( + &mut self, + aggr_fn: &AggregateFunc, + value: Value, + diff: Diff, + ) -> Result<(), EvalError>; + + fn update_batch(&mut self, aggr_fn: &AggregateFunc, value_diffs: I) -> Result<(), EvalError> + where + I: IntoIterator, + { + for (v, d) in value_diffs { + self.update(aggr_fn, v, d)?; + } + Ok(()) + } + + fn eval(&self, aggr_fn: &AggregateFunc) -> Result; +} + +/// Bool accumulator, used for `Any` `All` `Max/MinBool` +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Bool { + /// The number of `true` values observed. + trues: Diff, + /// The number of `false` values observed. + falses: Diff, +} + +impl TryFrom> for Bool { + type Error = EvalError; + + fn try_from(state: Vec) -> Result { + ensure!( + state.len() == 2, + InternalSnafu { + reason: "Bool Accumulator state should have 2 values", + } + ); + + let mut iter = state.into_iter(); + + Ok(Self { + trues: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?, + falses: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?, + }) + } +} + +impl Accumulator for Bool { + fn into_state(self) -> Vec { + vec![self.trues.into(), self.falses.into()] + } + + /// Null values are ignored + fn update( + &mut self, + aggr_fn: &AggregateFunc, + value: Value, + diff: Diff, + ) -> Result<(), EvalError> { + ensure!( + matches!( + aggr_fn, + AggregateFunc::Any + | AggregateFunc::All + | AggregateFunc::MaxBool + | AggregateFunc::MinBool + ), + InternalSnafu { + reason: format!( + "Bool Accumulator does not support this aggregation function: {:?}", + aggr_fn + ), + } + ); + + match value { + Value::Boolean(true) => self.trues += diff, + Value::Boolean(false) => self.falses += diff, + Value::Null => (), // ignore nulls + x => { + return Err(TypeMismatchSnafu { + expected: ConcreteDataType::boolean_datatype(), + actual: x.data_type(), + } + .build()); + } + }; + Ok(()) + } + + fn eval(&self, aggr_fn: &AggregateFunc) -> Result { + match aggr_fn { + AggregateFunc::Any => Ok(Value::from(self.trues > 0)), + AggregateFunc::All => Ok(Value::from(self.falses == 0)), + AggregateFunc::MaxBool => Ok(Value::from(self.trues > 0)), + AggregateFunc::MinBool => Ok(Value::from(self.falses == 0)), + _ => Err(InternalSnafu { + reason: format!( + "Bool Accumulator does not support this aggregation function: {:?}", + aggr_fn + ), + } + .build()), + } + } +} + +/// Accumulates simple numeric values for sum over integer. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct SimpleNumber { + /// The accumulation of all non-NULL values observed. + accum: i128, + /// The number of non-NULL values observed. + non_nulls: Diff, +} + +impl TryFrom> for SimpleNumber { + type Error = EvalError; + + fn try_from(state: Vec) -> Result { + ensure!( + state.len() == 2, + InternalSnafu { + reason: "Number Accumulator state should have 2 values", + } + ); + let mut iter = state.into_iter(); + + Ok(Self { + accum: Decimal128::try_from(iter.next().unwrap()) + .map_err(err_try_from_val)? + .val(), + non_nulls: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?, + }) + } +} + +impl Accumulator for SimpleNumber { + fn into_state(self) -> Vec { + vec![ + Value::Decimal128(Decimal128::new(self.accum, 38, 0)), + self.non_nulls.into(), + ] + } + + fn update( + &mut self, + aggr_fn: &AggregateFunc, + value: Value, + diff: Diff, + ) -> Result<(), EvalError> { + ensure!( + matches!( + aggr_fn, + AggregateFunc::SumInt16 + | AggregateFunc::SumInt32 + | AggregateFunc::SumInt64 + | AggregateFunc::SumUInt16 + | AggregateFunc::SumUInt32 + | AggregateFunc::SumUInt64 + ), + InternalSnafu { + reason: format!( + "SimpleNumber Accumulator does not support this aggregation function: {:?}", + aggr_fn + ), + } + ); + + let v = match (aggr_fn, value) { + (AggregateFunc::SumInt16, Value::Int16(x)) => i128::from(x), + (AggregateFunc::SumInt32, Value::Int32(x)) => i128::from(x), + (AggregateFunc::SumInt64, Value::Int64(x)) => i128::from(x), + (AggregateFunc::SumUInt16, Value::UInt16(x)) => i128::from(x), + (AggregateFunc::SumUInt32, Value::UInt32(x)) => i128::from(x), + (AggregateFunc::SumUInt64, Value::UInt64(x)) => i128::from(x), + (_f, Value::Null) => return Ok(()), // ignore null + (f, v) => { + let expected_datatype = f.signature().input; + return Err(TypeMismatchSnafu { + expected: expected_datatype, + actual: v.data_type(), + } + .build())?; + } + }; + + self.accum += v * i128::from(diff); + + self.non_nulls += diff; + Ok(()) + } + + fn eval(&self, aggr_fn: &AggregateFunc) -> Result { + match aggr_fn { + AggregateFunc::SumInt16 | AggregateFunc::SumInt32 | AggregateFunc::SumInt64 => { + i64::try_from(self.accum) + .map_err(|_e| OverflowSnafu {}.build()) + .map(Value::from) + } + AggregateFunc::SumUInt16 | AggregateFunc::SumUInt32 | AggregateFunc::SumUInt64 => { + u64::try_from(self.accum) + .map_err(|_e| OverflowSnafu {}.build()) + .map(Value::from) + } + _ => Err(InternalSnafu { + reason: format!( + "SimpleNumber Accumulator does not support this aggregation function: {:?}", + aggr_fn + ), + } + .build()), + } + } +} +/// Accumulates float values for sum over floating numbers. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] + +pub struct Float { + /// Accumulates non-special float values, i.e. not NaN, +inf, -inf. + /// accum will be set to zero if `non_nulls` is zero. + accum: OrderedF64, + /// Counts +inf + pos_infs: Diff, + /// Counts -inf + neg_infs: Diff, + /// Counts NaNs + nans: Diff, + /// Counts non-NULL values + non_nulls: Diff, +} + +impl TryFrom> for Float { + type Error = EvalError; + + fn try_from(state: Vec) -> Result { + ensure!( + state.len() == 5, + InternalSnafu { + reason: "Float Accumulator state should have 5 values", + } + ); + + let mut iter = state.into_iter(); + + let mut ret = Self { + accum: OrderedF64::try_from(iter.next().unwrap()).map_err(err_try_from_val)?, + pos_infs: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?, + neg_infs: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?, + nans: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?, + non_nulls: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?, + }; + + // This prevent counter-intuitive behavior of summing over no values + if ret.non_nulls == 0 { + ret.accum = OrderedFloat::from(0.0); + } + + Ok(ret) + } +} + +impl Accumulator for Float { + fn into_state(self) -> Vec { + vec![ + self.accum.into(), + self.pos_infs.into(), + self.neg_infs.into(), + self.nans.into(), + self.non_nulls.into(), + ] + } + + /// sum ignore null + fn update( + &mut self, + aggr_fn: &AggregateFunc, + value: Value, + diff: Diff, + ) -> Result<(), EvalError> { + ensure!( + matches!( + aggr_fn, + AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64 + ), + InternalSnafu { + reason: format!( + "Float Accumulator does not support this aggregation function: {:?}", + aggr_fn + ), + } + ); + + let x = match (aggr_fn, value) { + (AggregateFunc::SumFloat32, Value::Float32(x)) => OrderedF64::from(*x as f64), + (AggregateFunc::SumFloat64, Value::Float64(x)) => OrderedF64::from(x), + (_f, Value::Null) => return Ok(()), // ignore null + (f, v) => { + let expected_datatype = f.signature().input; + return Err(TypeMismatchSnafu { + expected: expected_datatype, + actual: v.data_type(), + } + .build())?; + } + }; + + if x.is_nan() { + self.nans += diff; + } else if x.is_infinite() { + if x.is_sign_positive() { + self.pos_infs += diff; + } else { + self.neg_infs += diff; + } + } else { + self.accum += *(x * OrderedF64::from(diff as f64)); + } + + self.non_nulls += diff; + Ok(()) + } + + fn eval(&self, aggr_fn: &AggregateFunc) -> Result { + match aggr_fn { + AggregateFunc::SumFloat32 => Ok(Value::Float32(OrderedF32::from(self.accum.0 as f32))), + AggregateFunc::SumFloat64 => Ok(Value::Float64(self.accum)), + _ => Err(InternalSnafu { + reason: format!( + "Float Accumulator does not support this aggregation function: {:?}", + aggr_fn + ), + } + .build()), + } + } +} + +/// Accumulates a single `Ord`ed `Value`, useful for min/max aggregations. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct OrdValue { + val: Option, + non_nulls: Diff, +} + +impl TryFrom> for OrdValue { + type Error = EvalError; + + fn try_from(state: Vec) -> Result { + ensure!( + state.len() == 2, + InternalSnafu { + reason: "OrdValue Accumulator state should have 2 values", + } + ); + + let mut iter = state.into_iter(); + + Ok(Self { + val: { + let v = iter.next().unwrap(); + if v == Value::Null { + None + } else { + Some(v) + } + }, + non_nulls: Diff::try_from(iter.next().unwrap()).map_err(err_try_from_val)?, + }) + } +} + +impl Accumulator for OrdValue { + fn into_state(self) -> Vec { + vec![self.val.unwrap_or(Value::Null), self.non_nulls.into()] + } + + /// min/max try to find results in all non-null values, if all values are null, the result is null. + /// count(col_name) gives the number of non-null values, count(*) gives the number of rows including nulls. + /// TODO(discord9): add count(*) as a aggr function + fn update( + &mut self, + aggr_fn: &AggregateFunc, + value: Value, + diff: Diff, + ) -> Result<(), EvalError> { + ensure!( + aggr_fn.is_max() || aggr_fn.is_min() || matches!(aggr_fn, AggregateFunc::Count), + InternalSnafu { + reason: format!( + "OrdValue Accumulator does not support this aggregation function: {:?}", + aggr_fn + ), + } + ); + if diff <= 0 && (aggr_fn.is_max() || aggr_fn.is_min()) { + return Err(InternalSnafu { + reason: "OrdValue Accumulator does not support non-monotonic input for min/max aggregation".to_string(), + }.build()); + } + + // if aggr_fn is count, the incoming value type doesn't matter in type checking + // otherwise, type need to be the same or value can be null + let check_type_aggr_fn_and_arg_value = + ty_eq_without_precision(value.data_type(), aggr_fn.signature().input) + || matches!(aggr_fn, AggregateFunc::Count) + || value.is_null(); + let check_type_aggr_fn_and_self_val = self + .val + .as_ref() + .map(|zelf| ty_eq_without_precision(zelf.data_type(), aggr_fn.signature().input)) + .unwrap_or(true) + || matches!(aggr_fn, AggregateFunc::Count); + + if !check_type_aggr_fn_and_arg_value { + return Err(TypeMismatchSnafu { + expected: aggr_fn.signature().input, + actual: value.data_type(), + } + .build()); + } else if !check_type_aggr_fn_and_self_val { + return Err(TypeMismatchSnafu { + expected: aggr_fn.signature().input, + actual: self + .val + .as_ref() + .map(|v| v.data_type()) + .unwrap_or(ConcreteDataType::null_datatype()), + } + .build()); + } + + let is_null = value.is_null(); + if is_null { + return Ok(()); + } + + if !is_null { + // compile count(*) to count(true) to include null/non-nulls + // And the counts of non-null values are updated here + self.non_nulls += diff; + + match aggr_fn.signature().generic_fn { + GenericFn::Max => { + self.val = self + .val + .clone() + .map(|v| v.max(value.clone())) + .or_else(|| Some(value)) + } + GenericFn::Min => { + self.val = self + .val + .clone() + .map(|v| v.min(value.clone())) + .or_else(|| Some(value)) + } + + GenericFn::Count => (), + _ => unreachable!("already checked by ensure!"), + } + }; + // min/max ignore nulls + + Ok(()) + } + + fn eval(&self, aggr_fn: &AggregateFunc) -> Result { + if aggr_fn.is_max() || aggr_fn.is_min() { + Ok(self.val.clone().unwrap_or(Value::Null)) + } else if matches!(aggr_fn, AggregateFunc::Count) { + Ok(self.non_nulls.into()) + } else { + Err(InternalSnafu { + reason: format!( + "OrdValue Accumulator does not support this aggregation function: {:?}", + aggr_fn + ), + } + .build()) + } + } +} + /// Accumulates values for the various types of accumulable aggregations. /// /// We assume that there are not more than 2^32 elements for the aggregation. @@ -38,34 +534,407 @@ use crate::repr::Diff; /// The float accumulator performs accumulation with tolerance for floating point error. /// /// TODO(discord9): check for overflowing +#[enum_dispatch(Accumulator)] #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum Accum { /// Accumulates boolean values. - Bool { - /// The number of `true` values observed. - trues: Diff, - /// The number of `false` values observed. - falses: Diff, - }, + Bool(Bool), /// Accumulates simple numeric values. - SimpleNumber { - /// The accumulation of all non-NULL values observed. - accum: i128, - /// The number of non-NULL values observed. - non_nulls: Diff, - }, + SimpleNumber(SimpleNumber), /// Accumulates float values. - Float { - /// Accumulates non-special float values, i.e. not NaN, +inf, -inf. - /// accum will be set to zero if `non_nulls` is zero. - accum: OrderedF64, - /// Counts +inf - pos_infs: Diff, - /// Counts -inf - neg_infs: Diff, - /// Counts NaNs - nans: Diff, - /// Counts non-NULL values - non_nulls: Diff, - }, + Float(Float), + /// Accumulate Values that impl `Ord` + OrdValue(OrdValue), +} + +impl Accum { + pub fn new_accum(aggr_fn: &AggregateFunc) -> Result { + Ok(match aggr_fn { + AggregateFunc::Any + | AggregateFunc::All + | AggregateFunc::MaxBool + | AggregateFunc::MinBool => Self::from(Bool { + trues: 0, + falses: 0, + }), + AggregateFunc::SumInt16 + | AggregateFunc::SumInt32 + | AggregateFunc::SumInt64 + | AggregateFunc::SumUInt16 + | AggregateFunc::SumUInt32 + | AggregateFunc::SumUInt64 => Self::from(SimpleNumber { + accum: 0, + non_nulls: 0, + }), + AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64 => Self::from(Float { + accum: OrderedF64::from(0.0), + pos_infs: 0, + neg_infs: 0, + nans: 0, + non_nulls: 0, + }), + f if f.is_max() || f.is_min() || matches!(f, AggregateFunc::Count) => { + Self::from(OrdValue { + val: None, + non_nulls: 0, + }) + } + f => { + return Err(InternalSnafu { + reason: format!( + "Accumulator does not support this aggregation function: {:?}", + f + ), + } + .build()); + } + }) + } + pub fn try_into_accum(aggr_fn: &AggregateFunc, state: Vec) -> Result { + match aggr_fn { + AggregateFunc::Any + | AggregateFunc::All + | AggregateFunc::MaxBool + | AggregateFunc::MinBool => Ok(Self::from(Bool::try_from(state)?)), + AggregateFunc::SumInt16 + | AggregateFunc::SumInt32 + | AggregateFunc::SumInt64 + | AggregateFunc::SumUInt16 + | AggregateFunc::SumUInt32 + | AggregateFunc::SumUInt64 => Ok(Self::from(SimpleNumber::try_from(state)?)), + AggregateFunc::SumFloat32 | AggregateFunc::SumFloat64 => { + Ok(Self::from(Float::try_from(state)?)) + } + f if f.is_max() || f.is_min() || matches!(f, AggregateFunc::Count) => { + Ok(Self::from(OrdValue::try_from(state)?)) + } + f => Err(InternalSnafu { + reason: format!( + "Accumulator does not support this aggregation function: {:?}", + f + ), + } + .build()), + } + } +} + +fn err_try_from_val(reason: T) -> EvalError { + TryFromValueSnafu { + msg: reason.to_string(), + } + .build() +} + +/// compare type while ignore their precision, including `TimeStamp`, `Time`, +/// `Duration`, `Interval` +fn ty_eq_without_precision(left: ConcreteDataType, right: ConcreteDataType) -> bool { + left == right + || matches!(left, ConcreteDataType::Timestamp(..)) + && matches!(right, ConcreteDataType::Timestamp(..)) + || matches!(left, ConcreteDataType::Time(..)) && matches!(right, ConcreteDataType::Time(..)) + || matches!(left, ConcreteDataType::Duration(..)) + && matches!(right, ConcreteDataType::Duration(..)) + || matches!(left, ConcreteDataType::Interval(..)) + && matches!(right, ConcreteDataType::Interval(..)) +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_accum() { + let testcases = vec![ + ( + AggregateFunc::SumInt32, + vec![(Value::Int32(1), 1), (Value::Null, 1)], + ( + Value::Int64(1), + vec![Value::Decimal128(Decimal128::new(1, 38, 0)), 1i64.into()], + ), + ), + ( + AggregateFunc::SumFloat32, + vec![(Value::Float32(OrderedF32::from(1.0)), 1), (Value::Null, 1)], + ( + Value::Float32(OrderedF32::from(1.0)), + vec![ + Value::Float64(OrderedF64::from(1.0)), + 0i64.into(), + 0i64.into(), + 0i64.into(), + 1i64.into(), + ], + ), + ), + ( + AggregateFunc::MaxInt32, + vec![(Value::Int32(1), 1), (Value::Int32(2), 1), (Value::Null, 1)], + (Value::Int32(2), vec![Value::Int32(2), 2i64.into()]), + ), + ( + AggregateFunc::MinInt32, + vec![(Value::Int32(2), 1), (Value::Int32(1), 1), (Value::Null, 1)], + (Value::Int32(1), vec![Value::Int32(1), 2i64.into()]), + ), + ( + AggregateFunc::MaxFloat32, + vec![ + (Value::Float32(OrderedF32::from(1.0)), 1), + (Value::Float32(OrderedF32::from(2.0)), 1), + (Value::Null, 1), + ], + ( + Value::Float32(OrderedF32::from(2.0)), + vec![Value::Float32(OrderedF32::from(2.0)), 2i64.into()], + ), + ), + ( + AggregateFunc::MaxDateTime, + vec![ + (Value::DateTime(DateTime::from(0)), 1), + (Value::DateTime(DateTime::from(1)), 1), + (Value::Null, 1), + ], + ( + Value::DateTime(DateTime::from(1)), + vec![Value::DateTime(DateTime::from(1)), 2i64.into()], + ), + ), + ( + AggregateFunc::Count, + vec![ + (Value::Int32(1), 1), + (Value::Int32(2), 1), + (Value::Null, 1), + (Value::Null, 1), + ], + (2i64.into(), vec![Value::Null, 2i64.into()]), + ), + ( + AggregateFunc::Any, + vec![ + (Value::Boolean(false), 1), + (Value::Boolean(false), 1), + (Value::Boolean(true), 1), + (Value::Null, 1), + ], + ( + Value::Boolean(true), + vec![Value::from(1i64), Value::from(2i64)], + ), + ), + ( + AggregateFunc::All, + vec![ + (Value::Boolean(false), 1), + (Value::Boolean(false), 1), + (Value::Boolean(true), 1), + (Value::Null, 1), + ], + ( + Value::Boolean(false), + vec![Value::from(1i64), Value::from(2i64)], + ), + ), + ( + AggregateFunc::MaxBool, + vec![ + (Value::Boolean(false), 1), + (Value::Boolean(false), 1), + (Value::Boolean(true), 1), + (Value::Null, 1), + ], + ( + Value::Boolean(true), + vec![Value::from(1i64), Value::from(2i64)], + ), + ), + ( + AggregateFunc::MinBool, + vec![ + (Value::Boolean(false), 1), + (Value::Boolean(false), 1), + (Value::Boolean(true), 1), + (Value::Null, 1), + ], + ( + Value::Boolean(false), + vec![Value::from(1i64), Value::from(2i64)], + ), + ), + ]; + + for (aggr_fn, input, (eval_res, state)) in testcases { + let create_and_insert = || -> Result { + let mut acc = Accum::new_accum(&aggr_fn)?; + acc.update_batch(&aggr_fn, input.clone())?; + let row = acc.into_state(); + let acc = Accum::try_into_accum(&aggr_fn, row)?; + Ok(acc) + }; + let acc = match create_and_insert() { + Ok(acc) => acc, + Err(err) => panic!( + "Failed to create accum for {:?} with input {:?} with error: {:?}", + aggr_fn, input, err + ), + }; + + if acc.eval(&aggr_fn).unwrap() != eval_res { + panic!( + "Failed to eval accum for {:?} with input {:?}, expect {:?}, got {:?}", + aggr_fn, + input, + eval_res, + acc.eval(&aggr_fn).unwrap() + ); + } + let actual_state = acc.into_state(); + if actual_state != state { + panic!( + "Failed to cast into state from accum for {:?} with input {:?}, expect state {:?}, got state {:?}", + aggr_fn, + input, + state, + actual_state + ); + } + } + } + #[test] + fn test_fail_path_accum() { + { + let bool_accum = Bool::try_from(vec![Value::Null]); + assert!(matches!(bool_accum, Err(EvalError::Internal { .. }))); + } + + { + let mut bool_accum = Bool::try_from(vec![1i64.into(), 1i64.into()]).unwrap(); + // serde + let bool_accum_serde = serde_json::to_string(&bool_accum).unwrap(); + let bool_accum_de = serde_json::from_str::(&bool_accum_serde).unwrap(); + assert_eq!(bool_accum, bool_accum_de); + assert!(matches!( + bool_accum.update(&AggregateFunc::MaxDate, 1.into(), 1), + Err(EvalError::Internal { .. }) + )); + assert!(matches!( + bool_accum.update(&AggregateFunc::Any, 1.into(), 1), + Err(EvalError::TypeMismatch { .. }) + )); + assert!(matches!( + bool_accum.eval(&AggregateFunc::MaxDate), + Err(EvalError::Internal { .. }) + )); + } + + { + let ret = SimpleNumber::try_from(vec![Value::Null]); + assert!(matches!(ret, Err(EvalError::Internal { .. }))); + let mut accum = + SimpleNumber::try_from(vec![Decimal128::new(0, 38, 0).into(), 0i64.into()]) + .unwrap(); + + assert!(matches!( + accum.update(&AggregateFunc::All, 0.into(), 1), + Err(EvalError::Internal { .. }) + )); + assert!(matches!( + accum.update(&AggregateFunc::SumInt64, 0i32.into(), 1), + Err(EvalError::TypeMismatch { .. }) + )); + assert!(matches!( + accum.eval(&AggregateFunc::All), + Err(EvalError::Internal { .. }) + )); + accum + .update(&AggregateFunc::SumInt64, 1i64.into(), 1) + .unwrap(); + accum + .update(&AggregateFunc::SumInt64, i64::MAX.into(), 1) + .unwrap(); + assert!(matches!( + accum.eval(&AggregateFunc::SumInt64), + Err(EvalError::Overflow { .. }) + )); + } + + { + let ret = Float::try_from(vec![2f64.into(), 0i64.into(), 0i64.into(), 0i64.into()]); + assert!(matches!(ret, Err(EvalError::Internal { .. }))); + let mut accum = Float::try_from(vec![ + 2f64.into(), + 0i64.into(), + 0i64.into(), + 0i64.into(), + 1i64.into(), + ]) + .unwrap(); + accum + .update(&AggregateFunc::SumFloat64, 2f64.into(), -1) + .unwrap(); + assert!(matches!( + accum.update(&AggregateFunc::All, 0.into(), 1), + Err(EvalError::Internal { .. }) + )); + assert!(matches!( + accum.update(&AggregateFunc::SumFloat64, 0.0f32.into(), 1), + Err(EvalError::TypeMismatch { .. }) + )); + // no record, no accum + assert_eq!( + accum.eval(&AggregateFunc::SumFloat64).unwrap(), + 0.0f64.into() + ); + + assert!(matches!( + accum.eval(&AggregateFunc::All), + Err(EvalError::Internal { .. }) + )); + + accum + .update(&AggregateFunc::SumFloat64, f64::INFINITY.into(), 1) + .unwrap(); + accum + .update(&AggregateFunc::SumFloat64, (-f64::INFINITY).into(), 1) + .unwrap(); + accum + .update(&AggregateFunc::SumFloat64, f64::NAN.into(), 1) + .unwrap(); + } + + { + let ret = OrdValue::try_from(vec![Value::Null]); + assert!(matches!(ret, Err(EvalError::Internal { .. }))); + let mut accum = OrdValue::try_from(vec![Value::Null, 0i64.into()]).unwrap(); + assert!(matches!( + accum.update(&AggregateFunc::All, 0.into(), 1), + Err(EvalError::Internal { .. }) + )); + accum + .update(&AggregateFunc::MaxInt16, 1i16.into(), 1) + .unwrap(); + assert!(matches!( + accum.update(&AggregateFunc::MaxInt16, 0i32.into(), 1), + Err(EvalError::TypeMismatch { .. }) + )); + assert!(matches!( + accum.update(&AggregateFunc::MaxInt16, 0i16.into(), -1), + Err(EvalError::Internal { .. }) + )); + accum + .update(&AggregateFunc::MaxInt16, Value::Null, 1) + .unwrap(); + } + + // insert uint64 into max_int64 should fail + { + let mut accum = OrdValue::try_from(vec![Value::Null, 0i64.into()]).unwrap(); + assert!(matches!( + accum.update(&AggregateFunc::MaxInt64, 0u64.into(), 1), + Err(EvalError::TypeMismatch { .. }) + )); + } + } } diff --git a/src/flow/src/expr/relation/func.rs b/src/flow/src/expr/relation/func.rs index 4c82281736bf..8d28533d8c96 100644 --- a/src/flow/src/expr/relation/func.rs +++ b/src/flow/src/expr/relation/func.rs @@ -12,15 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::any::type_name; - use common_time::{Date, DateTime}; use datatypes::prelude::ConcreteDataType; use datatypes::value::{OrderedF32, OrderedF64, Value}; use serde::{Deserialize, Serialize}; use crate::expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu}; -use crate::expr::relation::accum::Accum; +use crate::expr::relation::accum::{Accum, Accumulator}; use crate::repr::Diff; /// Aggregate functions that can be applied to a group of rows. @@ -83,3 +81,280 @@ pub enum AggregateFunc { Any, All, } + +impl AggregateFunc { + pub fn is_max(&self) -> bool { + self.signature().generic_fn == GenericFn::Max + } + + pub fn is_min(&self) -> bool { + self.signature().generic_fn == GenericFn::Min + } + + pub fn is_sum(&self) -> bool { + self.signature().generic_fn == GenericFn::Sum + } + + /// Eval value, diff with accumulator + /// + /// Expect self to be accumulable aggregate functio, i.e. sum/count + /// + /// TODO(discord9): deal with overflow&better accumulator + pub fn eval_diff_accumulable( + &self, + accum: Vec, + value_diffs: I, + ) -> Result<(Value, Vec), EvalError> + where + I: IntoIterator, + { + let mut accum = if accum.is_empty() { + Accum::new_accum(self)? + } else { + Accum::try_into_accum(self, accum)? + }; + accum.update_batch(self, value_diffs)?; + let res = accum.eval(self)?; + Ok((res, accum.into_state())) + } +} + +pub struct Signature { + pub input: ConcreteDataType, + pub output: ConcreteDataType, + pub generic_fn: GenericFn, +} + +#[derive(Debug, PartialEq, Eq)] +pub enum GenericFn { + Max, + Min, + Sum, + Count, + Any, + All, +} + +impl AggregateFunc { + /// all concrete datatypes with precision types will be returned with largest possible variant + /// as a exception, count have a signature of `null -> i64`, but it's actually `anytype -> i64` + pub fn signature(&self) -> Signature { + match self { + AggregateFunc::MaxInt16 => Signature { + input: ConcreteDataType::int16_datatype(), + output: ConcreteDataType::int16_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxInt32 => Signature { + input: ConcreteDataType::int32_datatype(), + output: ConcreteDataType::int32_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxInt64 => Signature { + input: ConcreteDataType::int64_datatype(), + output: ConcreteDataType::int64_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxUInt16 => Signature { + input: ConcreteDataType::uint16_datatype(), + output: ConcreteDataType::uint16_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxUInt32 => Signature { + input: ConcreteDataType::uint32_datatype(), + output: ConcreteDataType::uint32_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxUInt64 => Signature { + input: ConcreteDataType::uint64_datatype(), + output: ConcreteDataType::uint64_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxFloat32 => Signature { + input: ConcreteDataType::float32_datatype(), + output: ConcreteDataType::float32_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxFloat64 => Signature { + input: ConcreteDataType::float64_datatype(), + output: ConcreteDataType::float64_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxBool => Signature { + input: ConcreteDataType::boolean_datatype(), + output: ConcreteDataType::boolean_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxString => Signature { + input: ConcreteDataType::string_datatype(), + output: ConcreteDataType::string_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxDate => Signature { + input: ConcreteDataType::date_datatype(), + output: ConcreteDataType::date_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxDateTime => Signature { + input: ConcreteDataType::datetime_datatype(), + output: ConcreteDataType::datetime_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxTimestamp => Signature { + input: ConcreteDataType::timestamp_second_datatype(), + output: ConcreteDataType::timestamp_second_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxTime => Signature { + input: ConcreteDataType::time_second_datatype(), + output: ConcreteDataType::time_second_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxDuration => Signature { + input: ConcreteDataType::duration_second_datatype(), + output: ConcreteDataType::duration_second_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MaxInterval => Signature { + input: ConcreteDataType::interval_year_month_datatype(), + output: ConcreteDataType::interval_year_month_datatype(), + generic_fn: GenericFn::Max, + }, + AggregateFunc::MinInt16 => Signature { + input: ConcreteDataType::int16_datatype(), + output: ConcreteDataType::int16_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinInt32 => Signature { + input: ConcreteDataType::int32_datatype(), + output: ConcreteDataType::int32_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinInt64 => Signature { + input: ConcreteDataType::int64_datatype(), + output: ConcreteDataType::int64_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinUInt16 => Signature { + input: ConcreteDataType::uint16_datatype(), + output: ConcreteDataType::uint16_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinUInt32 => Signature { + input: ConcreteDataType::uint32_datatype(), + output: ConcreteDataType::uint32_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinUInt64 => Signature { + input: ConcreteDataType::uint64_datatype(), + output: ConcreteDataType::uint64_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinFloat32 => Signature { + input: ConcreteDataType::float32_datatype(), + output: ConcreteDataType::float32_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinFloat64 => Signature { + input: ConcreteDataType::float64_datatype(), + output: ConcreteDataType::float64_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinBool => Signature { + input: ConcreteDataType::boolean_datatype(), + output: ConcreteDataType::boolean_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinString => Signature { + input: ConcreteDataType::string_datatype(), + output: ConcreteDataType::string_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinDate => Signature { + input: ConcreteDataType::date_datatype(), + output: ConcreteDataType::date_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinDateTime => Signature { + input: ConcreteDataType::datetime_datatype(), + output: ConcreteDataType::datetime_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinTimestamp => Signature { + input: ConcreteDataType::timestamp_second_datatype(), + output: ConcreteDataType::timestamp_second_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinTime => Signature { + input: ConcreteDataType::time_second_datatype(), + output: ConcreteDataType::time_second_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinDuration => Signature { + input: ConcreteDataType::duration_second_datatype(), + output: ConcreteDataType::duration_second_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::MinInterval => Signature { + input: ConcreteDataType::interval_year_month_datatype(), + output: ConcreteDataType::interval_year_month_datatype(), + generic_fn: GenericFn::Min, + }, + AggregateFunc::SumInt16 => Signature { + input: ConcreteDataType::int16_datatype(), + output: ConcreteDataType::int16_datatype(), + generic_fn: GenericFn::Sum, + }, + AggregateFunc::SumInt32 => Signature { + input: ConcreteDataType::int32_datatype(), + output: ConcreteDataType::int32_datatype(), + generic_fn: GenericFn::Sum, + }, + AggregateFunc::SumInt64 => Signature { + input: ConcreteDataType::int64_datatype(), + output: ConcreteDataType::int64_datatype(), + generic_fn: GenericFn::Sum, + }, + AggregateFunc::SumUInt16 => Signature { + input: ConcreteDataType::uint16_datatype(), + output: ConcreteDataType::uint16_datatype(), + generic_fn: GenericFn::Sum, + }, + AggregateFunc::SumUInt32 => Signature { + input: ConcreteDataType::uint32_datatype(), + output: ConcreteDataType::uint32_datatype(), + generic_fn: GenericFn::Sum, + }, + AggregateFunc::SumUInt64 => Signature { + input: ConcreteDataType::uint64_datatype(), + output: ConcreteDataType::uint64_datatype(), + generic_fn: GenericFn::Sum, + }, + AggregateFunc::SumFloat32 => Signature { + input: ConcreteDataType::float32_datatype(), + output: ConcreteDataType::float32_datatype(), + generic_fn: GenericFn::Sum, + }, + AggregateFunc::SumFloat64 => Signature { + input: ConcreteDataType::float64_datatype(), + output: ConcreteDataType::float64_datatype(), + generic_fn: GenericFn::Sum, + }, + AggregateFunc::Count => Signature { + input: ConcreteDataType::null_datatype(), + output: ConcreteDataType::int64_datatype(), + generic_fn: GenericFn::Count, + }, + AggregateFunc::Any => Signature { + input: ConcreteDataType::boolean_datatype(), + output: ConcreteDataType::boolean_datatype(), + generic_fn: GenericFn::Any, + }, + AggregateFunc::All => Signature { + input: ConcreteDataType::boolean_datatype(), + output: ConcreteDataType::boolean_datatype(), + generic_fn: GenericFn::All, + }, + } + } +} diff --git a/src/flow/src/repr.rs b/src/flow/src/repr.rs index 8c15483666e8..1a4547b87174 100644 --- a/src/flow/src/repr.rs +++ b/src/flow/src/repr.rs @@ -33,7 +33,10 @@ use snafu::ResultExt; use crate::expr::error::{CastValueSnafu, EvalError}; -/// System-wide Record count difference type. +/// System-wide Record count difference type. Useful for capture data change +/// +/// i.e. +1 means insert one record, -1 means remove, +/// and +/-n means insert/remove multiple duplicate records. pub type Diff = i64; /// System-wide default timestamp type