diff --git a/CHANGELOG.md b/CHANGELOG.md index fa6bcab0..8dd822b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Changed ### Added +- Implements the aggregation functions `ANY`, `SOME`, `EVERY` and their `COLL_` versions ### Fixes ## [0.4.1] - 2023-05-25 diff --git a/partiql-eval/src/eval/evaluable.rs b/partiql-eval/src/eval/evaluable.rs index f2f95aae..393fbfcf 100644 --- a/partiql-eval/src/eval/evaluable.rs +++ b/partiql-eval/src/eval/evaluable.rs @@ -317,6 +317,8 @@ pub(crate) enum AggFunc { Max(Max), Min(Min), Sum(Sum), + Any(Any), + Every(Every), } impl AggregateFunction for AggFunc { @@ -327,6 +329,8 @@ impl AggregateFunction for AggFunc { AggFunc::Max(v) => v.next_value(input_value, group), AggFunc::Min(v) => v.next_value(input_value, group), AggFunc::Sum(v) => v.next_value(input_value, group), + AggFunc::Any(v) => v.next_value(input_value, group), + AggFunc::Every(v) => v.next_value(input_value, group), } } @@ -337,6 +341,8 @@ impl AggregateFunction for AggFunc { AggFunc::Max(v) => v.compute(group), AggFunc::Min(v) => v.compute(group), AggFunc::Sum(v) => v.compute(group), + AggFunc::Any(v) => v.compute(group), + AggFunc::Every(v) => v.compute(group), } } } @@ -436,7 +442,7 @@ impl AggregateFunction for Avg { } fn compute(&self, group: &Tuple) -> Value { - match self.avgs.get(group).expect("Expect group to exist in avgs") { + match self.avgs.get(group).unwrap_or(&(0, Null)) { (0, _) => Null, (c, s) => s / &Value::Decimal(rust_decimal::Decimal::from(*c)), } @@ -483,11 +489,7 @@ impl AggregateFunction for Count { } fn compute(&self, group: &Tuple) -> Value { - Value::from( - self.counts - .get(group) - .expect("Expect group to exist in counts"), - ) + Value::from(self.counts.get(group).unwrap_or(&0)) } } @@ -531,10 +533,7 @@ impl AggregateFunction for Max { } fn compute(&self, group: &Tuple) -> Value { - self.maxes - .get(group) - .expect("Expect group to exist in sums") - .clone() + self.maxes.get(group).unwrap_or(&Null).clone() } } @@ -578,10 +577,7 @@ impl AggregateFunction for Min { } fn compute(&self, group: &Tuple) -> Value { - self.mins - .get(group) - .expect("Expect group to exist in mins") - .clone() + self.mins.get(group).unwrap_or(&Null).clone() } } @@ -625,10 +621,107 @@ impl AggregateFunction for Sum { } fn compute(&self, group: &Tuple) -> Value { - self.sums - .get(group) - .expect("Expect group to exist in sums") - .clone() + self.sums.get(group).unwrap_or(&Null).clone() + } +} + +/// Represents SQL's `ANY`/`SOME` aggregation function +#[derive(Debug)] +pub(crate) struct Any { + anys: HashMap, + aggregator: AggFilterFn, +} + +impl Any { + pub(crate) fn new_distinct() -> Self { + Any { + anys: HashMap::new(), + aggregator: AggFilterFn::Distinct(AggFilterDistinct::new()), + } + } + + pub(crate) fn new_all() -> Self { + Any { + anys: HashMap::new(), + aggregator: AggFilterFn::default(), + } + } +} + +impl AggregateFunction for Any { + fn next_value(&mut self, input_value: &Value, group: &Tuple) { + if !input_value.is_null_or_missing() + && self.aggregator.filter_value(input_value.clone(), group) + { + match self.anys.get_mut(group) { + None => { + match input_value { + Boolean(_) => self.anys.insert(group.clone(), input_value.clone()), + _ => self.anys.insert(group.clone(), Missing), + }; + } + Some(acc) => { + *acc = match (acc.clone(), input_value) { + (Boolean(l), Value::Boolean(r)) => Value::Boolean(l || *r), + (_, _) => Missing, + }; + } + } + } + } + + fn compute(&self, group: &Tuple) -> Value { + self.anys.get(group).unwrap_or(&Null).clone() + } +} + +/// Represents SQL's `EVERY` aggregation function +#[derive(Debug)] +pub(crate) struct Every { + everys: HashMap, + aggregator: AggFilterFn, +} + +impl Every { + pub(crate) fn new_distinct() -> Self { + Every { + everys: HashMap::new(), + aggregator: AggFilterFn::Distinct(AggFilterDistinct::new()), + } + } + + pub(crate) fn new_all() -> Self { + Every { + everys: HashMap::new(), + aggregator: AggFilterFn::default(), + } + } +} + +impl AggregateFunction for Every { + fn next_value(&mut self, input_value: &Value, group: &Tuple) { + if !input_value.is_null_or_missing() + && self.aggregator.filter_value(input_value.clone(), group) + { + match self.everys.get_mut(group) { + None => { + match input_value { + Boolean(_) => self.everys.insert(group.clone(), input_value.clone()), + _ => self.everys.insert(group.clone(), Missing), + }; + } + Some(acc) => { + *acc = match (acc.clone(), input_value) { + (Boolean(l), Value::Boolean(r)) => Value::Boolean(l && *r), + (_, _) => Missing, + }; + } + } + } + } + + fn compute(&self, group: &Tuple) -> Value { + self.everys.get(group).unwrap_or(&Null).clone() } } diff --git a/partiql-eval/src/eval/expr/mod.rs b/partiql-eval/src/eval/expr/mod.rs index 4210618c..8c4190b1 100644 --- a/partiql-eval/src/eval/expr/mod.rs +++ b/partiql-eval/src/eval/expr/mod.rs @@ -1434,3 +1434,117 @@ impl EvalExpr for EvalFnBaseTableExpr { Cow::Owned(result) } } + +/// Represents the `COLL_ANY`/`COLL_SOME` function, e.g. `COLL_ANY(DISTINCT [true, true, false])`. +#[derive(Debug)] +pub(crate) struct EvalFnCollAny { + pub(crate) setq: SetQuantifier, + pub(crate) elems: Box, +} + +#[inline] +#[track_caller] +fn coll_any(elems: Vec<&Value>) -> Value { + if elems.is_empty() { + Null + } else { + let mut any = false; + for e in elems { + match e { + Value::Boolean(b) => any = any || *b, + _ => return Missing, + } + } + Value::from(any) + } +} + +impl EvalExpr for EvalFnCollAny { + fn evaluate<'a>(&'a self, bindings: &'a Tuple, ctx: &'a dyn EvalContext) -> Cow<'a, Value> { + let elems = self.elems.evaluate(bindings, ctx); + let result = match elems.borrow() { + Null => Null, + Value::List(l) => { + let l_nums: Vec<&Value> = match self.setq { + SetQuantifier::All => l.iter().filter(|&e| !e.is_null_or_missing()).collect(), + SetQuantifier::Distinct => l + .iter() + .filter(|&e| !e.is_null_or_missing()) + .unique() + .collect(), + }; + coll_any(l_nums) + } + Value::Bag(b) => { + let b_nums: Vec<&Value> = match self.setq { + SetQuantifier::All => b.iter().filter(|&e| !e.is_null_or_missing()).collect(), + SetQuantifier::Distinct => b + .iter() + .filter(|&e| !e.is_null_or_missing()) + .unique() + .collect(), + }; + coll_any(b_nums) + } + _ => Missing, + }; + Cow::Owned(result) + } +} + +/// Represents the `COLL_EVERY` function, e.g. `COLL_EVERY(DISTINCT [true, true, false])`. +#[derive(Debug)] +pub(crate) struct EvalFnCollEvery { + pub(crate) setq: SetQuantifier, + pub(crate) elems: Box, +} + +#[inline] +#[track_caller] +fn coll_every(elems: Vec<&Value>) -> Value { + if elems.is_empty() { + Null + } else { + let mut every = true; + for e in elems { + match e { + Value::Boolean(b) => every = every && *b, + _ => return Missing, + } + } + Value::from(every) + } +} + +impl EvalExpr for EvalFnCollEvery { + fn evaluate<'a>(&'a self, bindings: &'a Tuple, ctx: &'a dyn EvalContext) -> Cow<'a, Value> { + let elems = self.elems.evaluate(bindings, ctx); + let result = match elems.borrow() { + Null => Null, + Value::List(l) => { + let l_nums: Vec<&Value> = match self.setq { + SetQuantifier::All => l.iter().filter(|&e| !e.is_null_or_missing()).collect(), + SetQuantifier::Distinct => l + .iter() + .filter(|&e| !e.is_null_or_missing()) + .unique() + .collect(), + }; + coll_every(l_nums) + } + Value::Bag(b) => { + let b_nums: Vec<&Value> = match self.setq { + SetQuantifier::All => b.iter().filter(|&e| !e.is_null_or_missing()).collect(), + SetQuantifier::Distinct => b + .iter() + .filter(|&e| !e.is_null_or_missing()) + .unique() + .collect(), + }; + coll_every(b_nums) + } + _ => Missing, + }; + Cow::Owned(result) + } +} diff --git a/partiql-eval/src/plan.rs b/partiql-eval/src/plan.rs index e1a8ea2f..8c479794 100644 --- a/partiql-eval/src/plan.rs +++ b/partiql-eval/src/plan.rs @@ -13,20 +13,20 @@ use partiql_logical::{ use crate::error::{ErrorNode, PlanErr, PlanningError}; use crate::eval; use crate::eval::evaluable::{ - Avg, Count, EvalGroupingStrategy, EvalJoinKind, EvalOrderBy, EvalOrderBySortCondition, - EvalOrderBySortSpec, EvalSubQueryExpr, Evaluable, Max, Min, Sum, + Any, Avg, Count, EvalGroupingStrategy, EvalJoinKind, EvalOrderBy, EvalOrderBySortCondition, + EvalOrderBySortSpec, EvalSubQueryExpr, Evaluable, Every, Max, Min, Sum, }; use crate::eval::expr::pattern_match::like_to_re_pattern; use crate::eval::expr::{ EvalBagExpr, EvalBetweenExpr, EvalBinOp, EvalBinOpExpr, EvalDynamicLookup, EvalExpr, EvalFnAbs, EvalFnBaseTableExpr, EvalFnBitLength, EvalFnBtrim, EvalFnCardinality, EvalFnCharLength, - EvalFnCollAvg, EvalFnCollCount, EvalFnCollMax, EvalFnCollMin, EvalFnCollSum, EvalFnExists, - EvalFnExtractDay, EvalFnExtractHour, EvalFnExtractMinute, EvalFnExtractMonth, - EvalFnExtractSecond, EvalFnExtractTimezoneHour, EvalFnExtractTimezoneMinute, EvalFnExtractYear, - EvalFnLower, EvalFnLtrim, EvalFnModulus, EvalFnOctetLength, EvalFnOverlay, EvalFnPosition, - EvalFnRtrim, EvalFnSubstring, EvalFnUpper, EvalIsTypeExpr, EvalLikeMatch, - EvalLikeNonStringNonLiteralMatch, EvalListExpr, EvalLitExpr, EvalPath, EvalSearchedCaseExpr, - EvalTupleExpr, EvalUnaryOp, EvalUnaryOpExpr, EvalVarRef, + EvalFnCollAny, EvalFnCollAvg, EvalFnCollCount, EvalFnCollEvery, EvalFnCollMax, EvalFnCollMin, + EvalFnCollSum, EvalFnExists, EvalFnExtractDay, EvalFnExtractHour, EvalFnExtractMinute, + EvalFnExtractMonth, EvalFnExtractSecond, EvalFnExtractTimezoneHour, + EvalFnExtractTimezoneMinute, EvalFnExtractYear, EvalFnLower, EvalFnLtrim, EvalFnModulus, + EvalFnOctetLength, EvalFnOverlay, EvalFnPosition, EvalFnRtrim, EvalFnSubstring, EvalFnUpper, + EvalIsTypeExpr, EvalLikeMatch, EvalLikeNonStringNonLiteralMatch, EvalListExpr, EvalLitExpr, + EvalPath, EvalSearchedCaseExpr, EvalTupleExpr, EvalUnaryOp, EvalUnaryOpExpr, EvalVarRef, }; use crate::eval::EvalPlan; use partiql_catalog::Catalog; @@ -216,6 +216,12 @@ impl<'c> EvaluatorPlanner<'c> { (AggFunc::AggSum, logical::SetQuantifier::All) => { eval::evaluable::AggFunc::Sum(Sum::new_all()) } + (AggFunc::AggAny, logical::SetQuantifier::All) => { + eval::evaluable::AggFunc::Any(Any::new_all()) + } + (AggFunc::AggEvery, logical::SetQuantifier::All) => { + eval::evaluable::AggFunc::Every(Every::new_all()) + } (AggFunc::AggAvg, logical::SetQuantifier::Distinct) => { eval::evaluable::AggFunc::Avg(Avg::new_distinct()) } @@ -231,6 +237,12 @@ impl<'c> EvaluatorPlanner<'c> { (AggFunc::AggSum, logical::SetQuantifier::Distinct) => { eval::evaluable::AggFunc::Sum(Sum::new_distinct()) } + (AggFunc::AggAny, logical::SetQuantifier::Distinct) => { + eval::evaluable::AggFunc::Any(Any::new_distinct()) + } + (AggFunc::AggEvery, logical::SetQuantifier::Distinct) => { + eval::evaluable::AggFunc::Every(Every::new_distinct()) + } }; eval::evaluable::AggregateExpression { name: a_e.name.to_string(), @@ -725,6 +737,20 @@ impl<'c> EvaluatorPlanner<'c> { elems: args.pop().unwrap(), }) } + CallName::CollAny(setq) => { + correct_num_args_or_err!(self, args, 1, "coll_any/coll_some"); + Box::new(EvalFnCollAny { + setq: plan_set_quantifier(setq), + elems: args.pop().unwrap(), + }) + } + CallName::CollEvery(setq) => { + correct_num_args_or_err!(self, args, 1, "coll_every"); + Box::new(EvalFnCollEvery { + setq: plan_set_quantifier(setq), + elems: args.pop().unwrap(), + }) + } CallName::ByName(name) => match self.catalog.get_function(name) { None => { self.errors.push(PlanningError::IllegalState(format!( diff --git a/partiql-logical-planner/src/builtins.rs b/partiql-logical-planner/src/builtins.rs index e9fe7289..86e83bcc 100644 --- a/partiql-logical-planner/src/builtins.rs +++ b/partiql-logical-planner/src/builtins.rs @@ -655,6 +655,76 @@ fn function_call_def_coll_sum() -> CallDef { } } +fn function_call_def_coll_any() -> CallDef { + CallDef { + names: vec!["coll_any", "coll_some"], + overloads: vec![ + CallSpec { + input: vec![CallSpecArg::Positional], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollAny(SetQuantifier::All), + arguments: args, + }) + }), + }, + CallSpec { + input: vec![CallSpecArg::Named("all".into())], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollAny(SetQuantifier::All), + arguments: args, + }) + }), + }, + CallSpec { + input: vec![CallSpecArg::Named("distinct".into())], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollAny(SetQuantifier::Distinct), + arguments: args, + }) + }), + }, + ], + } +} + +fn function_call_def_coll_every() -> CallDef { + CallDef { + names: vec!["coll_every"], + overloads: vec![ + CallSpec { + input: vec![CallSpecArg::Positional], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollEvery(SetQuantifier::All), + arguments: args, + }) + }), + }, + CallSpec { + input: vec![CallSpecArg::Named("all".into())], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollEvery(SetQuantifier::All), + arguments: args, + }) + }), + }, + CallSpec { + input: vec![CallSpecArg::Named("distinct".into())], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollEvery(SetQuantifier::Distinct), + arguments: args, + }) + }), + }, + ], + } +} + pub(crate) static FN_SYM_TAB: Lazy = Lazy::new(function_call_def); /// Function symbol table @@ -698,6 +768,8 @@ pub fn function_call_def() -> FnSymTab { function_call_def_coll_max(), function_call_def_coll_min(), function_call_def_coll_sum(), + function_call_def_coll_any(), + function_call_def_coll_every(), ] { assert!(!def.names.is_empty()); let primary = def.names[0]; diff --git a/partiql-logical-planner/src/lower.rs b/partiql-logical-planner/src/lower.rs index 2d9fda37..9b96fcc5 100644 --- a/partiql-logical-planner/src/lower.rs +++ b/partiql-logical-planner/src/lower.rs @@ -34,7 +34,7 @@ use crate::error::{LowerError, LoweringError}; use partiql_catalog::Catalog; use partiql_extension_ion::decode::{IonDecoderBuilder, IonDecoderConfig}; use partiql_extension_ion::Encoding; -use partiql_logical::AggFunc::{AggAvg, AggCount, AggMax, AggMin, AggSum}; +use partiql_logical::AggFunc::{AggAny, AggAvg, AggCount, AggEvery, AggMax, AggMin, AggSum}; use std::sync::atomic::{AtomicU32, Ordering}; type FnvIndexMap = IndexMap; @@ -1150,6 +1150,18 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { func: AggSum, setq, }, + "any" | "some" => AggregateExpression { + name: new_name, + expr: arg, + func: AggAny, + setq, + }, + "every" => AggregateExpression { + name: new_name, + expr: arg, + func: AggEvery, + setq, + }, _ => { // Include as an error but allow lowering to proceed for multiple error reporting self.errors.push(LowerError::UnsupportedFunction(name)); diff --git a/partiql-logical/src/lib.rs b/partiql-logical/src/lib.rs index 7b149a46..9d6ce1ae 100644 --- a/partiql-logical/src/lib.rs +++ b/partiql-logical/src/lib.rs @@ -339,6 +339,10 @@ pub enum AggFunc { AggMin, /// Represents SQL's `SUM` aggregation function AggSum, + /// Represents SQL's `ANY`/`SOME` aggregation function + AggAny, + /// Represents SQL's `EVERY` aggregation function + AggEvery, } /// Represents `GROUP BY` [, ] ... \[AS \] @@ -664,6 +668,8 @@ pub enum CallName { CollMax(SetQuantifier), CollMin(SetQuantifier), CollSum(SetQuantifier), + CollAny(SetQuantifier), + CollEvery(SetQuantifier), ByName(String), } diff --git a/partiql-parser/src/parse/parser_state.rs b/partiql-parser/src/parse/parser_state.rs index b47f3da8..31a778d5 100644 --- a/partiql-parser/src/parse/parser_state.rs +++ b/partiql-parser/src/parse/parser_state.rs @@ -69,7 +69,8 @@ impl<'input> Default for ParserState<'input, NodeIdGenerator> { // TODO: currently needs to be manually kept in-sync with preprocessor's `built_in_aggs` // TODO: make extensible -const KNOWN_AGGREGATES: &str = "(?i:^count$)|(?i:^avg$)|(?i:^min$)|(?i:^max$)|(?i:^sum$)"; +const KNOWN_AGGREGATES: &str = + "(?i:^count$)|(?i:^avg$)|(?i:^min$)|(?i:^max$)|(?i:^sum$)|(?i:^any$)|(?i:^some$)|(?i:^every$)"; static KNOWN_AGGREGATE_PATTERN: Lazy = Lazy::new(|| Regex::new(KNOWN_AGGREGATES).unwrap()); impl<'input, I> ParserState<'input, I> diff --git a/partiql-parser/src/preprocessor.rs b/partiql-parser/src/preprocessor.rs index a76257bb..e23784f0 100644 --- a/partiql-parser/src/preprocessor.rs +++ b/partiql-parser/src/preprocessor.rs @@ -132,7 +132,7 @@ mod built_ins { pub(crate) fn built_in_aggs() -> FnExpr<'static> { FnExpr { // TODO: currently needs to be manually kept in-sync with parsers's `KNOWN_AGGREGATES` - fn_names: vec!["count", "avg", "min", "max", "sum"], + fn_names: vec!["count", "avg", "min", "max", "sum", "any", "some", "every"], #[rustfmt::skip] patterns: vec![ // e.g., count(all x) => count("all": x)