From d0beb1d5a2e8b8f6e5e3c79f288754bb308f074b Mon Sep 17 00:00:00 2001
From: Eric Kidd
Date: Tue, 7 Nov 2023 12:51:15 -0500
Subject: [PATCH] Fix: Don't trigger implicit GROUP BY because of subqueries

The old tree-walker for finding aggregate functions was wrong. We need
to do this the hard way.
---
 src/infer/contains_aggregate.rs               | 255 ++++++++++++++++++
 src/{infer.rs => infer/mod.rs}                |  70 +----
 .../queries/group_by/implicit_aggregates.sql  |  26 ++
 3 files changed, 286 insertions(+), 65 deletions(-)
 create mode 100644 src/infer/contains_aggregate.rs
 rename src/{infer.rs => infer/mod.rs} (95%)
 create mode 100644 tests/sql/queries/group_by/implicit_aggregates.sql

diff --git a/src/infer/contains_aggregate.rs b/src/infer/contains_aggregate.rs
new file mode 100644
index 0000000..3694acb
--- /dev/null
+++ b/src/infer/contains_aggregate.rs
@@ -0,0 +1,255 @@
+use crate::{
+    ast,
+    scope::{ColumnSetScope, ScopeGet},
+};
+
+/// Interface used to search an AST node for aggregate functions.
+pub trait ContainsAggregate {
+    /// Does this AST node contain an aggregate function?
+    ///
+    /// Does not recurse into sub-queries.
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool;
+}
+
+impl<T: ContainsAggregate> ContainsAggregate for Box<T> {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.as_ref().contains_aggregate(scope)
+    }
+}
+
+impl<T: ContainsAggregate> ContainsAggregate for Vec<T> {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.iter().any(|item| item.contains_aggregate(scope))
+    }
+}
+
+impl<T: ContainsAggregate> ContainsAggregate for Option<T> {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.as_ref()
+            .map_or(false, |item| item.contains_aggregate(scope))
+    }
+}
+
+impl<T: ast::Node + ContainsAggregate> ContainsAggregate for ast::NodeVec<T> {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.node_iter().any(|item| item.contains_aggregate(scope))
+    }
+}
+
+impl ContainsAggregate for ast::SelectList {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.items.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::SelectListItem {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        match self {
+            ast::SelectListItem::Expression {
+                expression,
+                alias: _,
+            } => expression.contains_aggregate(scope),
+            ast::SelectListItem::Wildcard { .. }
+            | ast::SelectListItem::TableNameWildcard { .. } => false,
+        }
+    }
+}
+
+impl ContainsAggregate for ast::Expression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        match self {
+            ast::Expression::Literal(_) => false,
+            ast::Expression::BoolValue(_) => false,
+            ast::Expression::Null(_) => false,
+            ast::Expression::Interval(interval) => interval.contains_aggregate(scope),
+            ast::Expression::ColumnName(_) => false,
+            ast::Expression::Cast(cast) => cast.contains_aggregate(scope),
+            ast::Expression::Is(is) => is.contains_aggregate(scope),
+            ast::Expression::In(in_expr) => in_expr.contains_aggregate(scope),
+            ast::Expression::Between(between) => between.contains_aggregate(scope),
+            ast::Expression::KeywordBinop(binop) => binop.contains_aggregate(scope),
+            ast::Expression::Not(not) => not.contains_aggregate(scope),
+            ast::Expression::If(if_expr) => if_expr.contains_aggregate(scope),
+            ast::Expression::Case(case) => case.contains_aggregate(scope),
+            ast::Expression::Binop(binop) => binop.contains_aggregate(scope),
+            // We never look into sub-queries.
+            ast::Expression::Query { .. } => false,
+            ast::Expression::Parens { expression, .. } => expression.contains_aggregate(scope),
+            ast::Expression::Array(arr) => arr.contains_aggregate(scope),
+            ast::Expression::Struct(st) => st.contains_aggregate(scope),
+            // TODO: false if we add `OVER`.
+            ast::Expression::Count(_) => true,
+            ast::Expression::CurrentDate(_) => false,
+            // TODO: false if we add `OVER`.
+            ast::Expression::ArrayAgg(_) => true,
+            ast::Expression::SpecialDateFunctionCall(fcall) => fcall.contains_aggregate(scope),
+            ast::Expression::FunctionCall(fcall) => fcall.contains_aggregate(scope),
+            ast::Expression::Index(idx) => idx.contains_aggregate(scope),
+        }
+    }
+}
+
+impl ContainsAggregate for ast::IntervalExpression {
+    fn contains_aggregate(&self, _scope: &ColumnSetScope) -> bool {
+        // These are currently all literals, so they can't contain aggregates.
+        false
+    }
+}
+
+impl ContainsAggregate for ast::Cast {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.expression.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::IsExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.left.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::InExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.left.contains_aggregate(scope) || self.value_set.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::InValueSet {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        match self {
+            // We never look into sub-queries.
+            ast::InValueSet::QueryExpression { .. } => false,
+            ast::InValueSet::ExpressionList { expressions, .. } => {
+                expressions.contains_aggregate(scope)
+            }
+            // I am doubtful that anything good is happening if we hit this
+            // case.
+            ast::InValueSet::Unnest { expression, .. } => expression.contains_aggregate(scope),
+        }
+    }
+}
+
+impl ContainsAggregate for ast::BetweenExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.left.contains_aggregate(scope)
+            || self.middle.contains_aggregate(scope)
+            || self.right.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::KeywordBinopExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.left.contains_aggregate(scope) || self.right.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::NotExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.expression.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::IfExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.condition.contains_aggregate(scope)
+            || self.then_expression.contains_aggregate(scope)
+            || self.else_expression.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::CaseExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.case_expr.contains_aggregate(scope)
+            || self.when_clauses.contains_aggregate(scope)
+            || self.else_clause.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::CaseWhenClause {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.condition.contains_aggregate(scope) || self.result.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::CaseElseClause {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.result.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::BinopExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.left.contains_aggregate(scope) || self.right.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::ArrayExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.definition.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::ArrayDefinition {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        match self {
+            // We never look into sub-queries.
+            ast::ArrayDefinition::Query { .. } => false,
+            ast::ArrayDefinition::Elements(expressions) => expressions.contains_aggregate(scope),
+        }
+    }
+}
+
+impl ContainsAggregate for ast::StructExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.fields.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::SpecialDateFunctionCall {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        // We don't need to check the function name here.
+        self.args.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::ExpressionOrDatePart {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        match self {
+            ast::ExpressionOrDatePart::Expression(expr) => expr.contains_aggregate(scope),
+            ast::ExpressionOrDatePart::DatePart(_) => false,
+        }
+    }
+}
+
+impl ContainsAggregate for ast::FunctionCall {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        // Check the function we're calling.
+        if scope
+            .get_function_type(&self.name)
+            .map_or(false, |func_ty| func_ty.is_aggregate())
+            // If we have an OVER clause, we're not a normal aggregate.
+            && self.over_clause.is_none()
+        {
+            return true;
+        }
+
+        // Check our arguments.
+        self.args.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::IndexExpression {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        self.expression.contains_aggregate(scope) || self.index.contains_aggregate(scope)
+    }
+}
+
+impl ContainsAggregate for ast::IndexOffset {
+    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
+        match self {
+            ast::IndexOffset::Simple(expression) => expression.contains_aggregate(scope),
+            ast::IndexOffset::Offset { expression, .. } => expression.contains_aggregate(scope),
+            ast::IndexOffset::Ordinal { expression, .. } => expression.contains_aggregate(scope),
+        }
+    }
+}
diff --git a/src/infer.rs b/src/infer/mod.rs
similarity index 95%
rename from src/infer.rs
rename to src/infer/mod.rs
index efecf12..a851d6e 100644
--- a/src/infer.rs
+++ b/src/infer/mod.rs
@@ -2,7 +2,6 @@
 
 use std::collections::HashSet;
 
-use derive_visitor::{Drive, Visitor};
 use tracing::trace;
 
 use crate::{
@@ -14,6 +13,10 @@ use crate::{
     unification::{UnificationTable, Unify},
 };
 
+use self::contains_aggregate::ContainsAggregate;
+
+mod contains_aggregate;
+
 // TODO: Remember this rather scary example. Verify BigQuery supports it
 // and that we need it.
 //
@@ -319,7 +322,7 @@ impl InferTypes for ast::SelectExpression {
             column_set_scope = column_set_scope
                 .try_transform(|column_set| column_set.group_by(&group_by_names))?;
             trace!(columns = %column_set_scope.column_set(), "columns after GROUP BY");
-        } else if contains_aggregate(&column_set_scope, &*select_list) {
+        } else if select_list.contains_aggregate(&column_set_scope) {
             // If we have aggregates but no GROUP BY, we need to add a synthetic
             // GROUP BY of the empty set of columns.
             column_set_scope =
@@ -1138,69 +1141,6 @@ fn nyi(spanned: &dyn Spanned, name: &str) -> Error {
     )
 }
 
-/// Check an [`ast::Expression`] for aggregate functions.
-///
-/// This is used to detect implicit `GROUP BY` clauses, as in `SELECT SUM(x)
-/// FROM t`.
-fn contains_aggregate(scope: &ColumnSetScope, node: &ast::NodeVec<ast::SelectListItem>) -> bool {
-    let mut visitor = ContainsAggregate::new(scope);
-    node.drive(&mut visitor);
-    visitor.contains_aggregate
-}
-
-/// A struct that we use to walk an AST looking for aggregate functions.
-///
-/// TODO: We probably need to be a lot more careful about sub-queries here.
-///
-/// TODO: We may want to have `trait ContainsAggregate` at some point.
-#[derive(Debug, Visitor)]
-#[visitor(
-    ast::ArrayAggExpression(enter),
-    ast::CountExpression(enter),
-    ast::FunctionCall(enter)
-)]
-struct ContainsAggregate<'scope> {
-    scope: &'scope ColumnSetScope,
-    contains_aggregate: bool,
-}
-
-impl<'scope> ContainsAggregate<'scope> {
-    fn new(scope: &'scope ColumnSetScope) -> Self {
-        Self {
-            scope,
-            contains_aggregate: false,
-        }
-    }
-
-    fn enter_array_agg_expression(&mut self, _array_agg: &ast::ArrayAggExpression) {
-        self.contains_aggregate = true;
-    }
-
-    fn enter_count_expression(&mut self, _count: &ast::CountExpression) {
-        self.contains_aggregate = true;
-    }
-
-    fn enter_function_call(&mut self, fcall: &ast::FunctionCall) {
-        if self.contains_aggregate {
-            return;
-        }
-        if fcall.over_clause.is_some() {
-            // OVER clauses may contain aggregate functions, but they don't
-            // normally trigger an implicit GROUP BY.
-            return;
-        }
-        match self.scope.get_function_type(&fcall.name) {
-            Ok(func_ty) if func_ty.is_aggregate() => {
-                self.contains_aggregate = true;
-            }
-            _ => {
-                // If we're not sure what this function is, assume it's not an
-                // aggregate. We'll call `infer_types` later to check.
-            }
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use pretty_assertions::assert_eq;
diff --git a/tests/sql/queries/group_by/implicit_aggregates.sql b/tests/sql/queries/group_by/implicit_aggregates.sql
new file mode 100644
index 0000000..142d5a7
--- /dev/null
+++ b/tests/sql/queries/group_by/implicit_aggregates.sql
@@ -0,0 +1,26 @@
+-- Implicit aggregates.
+
+CREATE TEMP TABLE implicit_agg (val INT64);
+INSERT INTO implicit_agg VALUES (1), (2), (3);
+
+CREATE OR REPLACE TABLE __result1 AS
+SELECT SUM(val) AS sum_val
+FROM implicit_agg;
+
+CREATE OR REPLACE TABLE __expected1 (
+    sum_val INT64,
+);
+INSERT INTO __expected1 VALUES
+    (6);
+
+-- Now make sure we _don't_ trigger if the aggregate appears in a subquery.
+CREATE OR REPLACE TABLE __result2 AS
+SELECT x, (SELECT SUM(val) FROM implicit_agg) AS sum_val
+FROM (SELECT 1 AS x);
+
+CREATE OR REPLACE TABLE __expected2 (
+    x INT64,
+    sum_val INT64,
+);
+INSERT INTO __expected2 VALUES
+    (1, 6);