Skip to content

Commit

Permalink
Fix: Don't trigger implicit GROUP BY because of subqueries
Browse files Browse the repository at this point in the history
The old tree-walker for finding aggregate functions was wrong. We need
to do this the hard way.
  • Loading branch information
emk committed Nov 7, 2023
1 parent 9c14b66 commit d0beb1d
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 65 deletions.
255 changes: 255 additions & 0 deletions src/infer/contains_aggregate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
use crate::{
ast,
scope::{ColumnSetScope, ScopeGet},
};

/// Search interface implemented by AST nodes that may contain aggregate
/// function calls.
pub trait ContainsAggregate {
    /// Returns `true` if this AST node contains an aggregate function.
    ///
    /// The search stops at sub-query boundaries: aggregates that appear
    /// inside a sub-query are not reported.
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool;
}

/// A boxed node contains an aggregate exactly when the boxed value does.
impl<T> ContainsAggregate for Box<T>
where
    T: ContainsAggregate + ast::Node,
{
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        // Look through the `Box` and ask the node itself.
        (**self).contains_aggregate(scope)
    }
}

/// A vector of nodes contains an aggregate if at least one element does.
impl<T> ContainsAggregate for Vec<T>
where
    T: ContainsAggregate + ast::Node,
{
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        for element in self {
            // Stop at the first element that reports an aggregate.
            if element.contains_aggregate(scope) {
                return true;
            }
        }
        false
    }
}

/// An optional node contains an aggregate only if it is present and the
/// wrapped node does.
impl<T> ContainsAggregate for Option<T>
where
    T: ContainsAggregate + ast::Node,
{
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        match self {
            Some(node) => node.contains_aggregate(scope),
            // Nothing here, so there is nothing to find.
            None => false,
        }
    }
}

/// An [`ast::NodeVec`] contains an aggregate if any node yielded by
/// `node_iter` does.
impl<T> ContainsAggregate for ast::NodeVec<T>
where
    T: ContainsAggregate + ast::Node,
{
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let mut nodes = self.node_iter();
        nodes.any(|node| node.contains_aggregate(scope))
    }
}

/// A `SELECT` list contains an aggregate if any of its items does.
impl ContainsAggregate for ast::SelectList {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self { items, .. } = self;
        items.contains_aggregate(scope)
    }
}

/// Only expression items can contain an aggregate; both wildcard forms
/// always report `false`.
impl ContainsAggregate for ast::SelectListItem {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        match self {
            Self::Wildcard { .. } | Self::TableNameWildcard { .. } => false,
            Self::Expression {
                expression,
                alias: _,
            } => expression.contains_aggregate(scope),
        }
    }
}

/// Dispatches the aggregate search over every kind of expression.
///
/// Leaf expressions and sub-queries always report `false`, `COUNT` and
/// `ARRAY_AGG` always report `true`, and every other variant defers to the
/// implementation for the node it wraps.
impl ContainsAggregate for ast::Expression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        match self {
            // Leaf expressions: there is nothing to search inside these.
            Self::Literal(_)
            | Self::BoolValue(_)
            | Self::Null(_)
            | Self::ColumnName(_)
            | Self::CurrentDate(_) => false,

            // We never look into sub-queries.
            Self::Query { .. } => false,

            // These are always treated as aggregates.
            //
            // TODO: false if we add `OVER`.
            Self::Count(_) | Self::ArrayAgg(_) => true,

            // Everything else delegates to the wrapped node.
            Self::Parens { expression, .. } => expression.contains_aggregate(scope),
            Self::Interval(node) => node.contains_aggregate(scope),
            Self::Cast(node) => node.contains_aggregate(scope),
            Self::Is(node) => node.contains_aggregate(scope),
            Self::In(node) => node.contains_aggregate(scope),
            Self::Between(node) => node.contains_aggregate(scope),
            Self::KeywordBinop(node) => node.contains_aggregate(scope),
            Self::Not(node) => node.contains_aggregate(scope),
            Self::If(node) => node.contains_aggregate(scope),
            Self::Case(node) => node.contains_aggregate(scope),
            Self::Binop(node) => node.contains_aggregate(scope),
            Self::Array(node) => node.contains_aggregate(scope),
            Self::Struct(node) => node.contains_aggregate(scope),
            Self::SpecialDateFunctionCall(node) => node.contains_aggregate(scope),
            Self::FunctionCall(node) => node.contains_aggregate(scope),
            Self::Index(node) => node.contains_aggregate(scope),
        }
    }
}

/// Interval expressions never report an aggregate.
impl ContainsAggregate for ast::IntervalExpression {
    fn contains_aggregate(&self, _scope: &ColumnSetScope) -> bool {
        // Intervals are currently always built from literals, which leaves
        // no place for an aggregate to appear.
        false
    }
}

/// A cast contains an aggregate if the expression being cast does.
impl ContainsAggregate for ast::Cast {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self { expression, .. } = self;
        expression.contains_aggregate(scope)
    }
}

/// An `IS` expression is searched through its left-hand operand only.
impl ContainsAggregate for ast::IsExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self { left, .. } = self;
        left.contains_aggregate(scope)
    }
}

/// An `IN` expression contains an aggregate if either the tested value or
/// the value set does.
impl ContainsAggregate for ast::InExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        if self.left.contains_aggregate(scope) {
            return true;
        }
        self.value_set.contains_aggregate(scope)
    }
}

/// Searches the right-hand side of an `IN` expression.
impl ContainsAggregate for ast::InValueSet {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        match self {
            Self::ExpressionList { expressions, .. } => expressions.contains_aggregate(scope),
            // I am doubtful that anything good is happening if we hit this
            // case.
            Self::Unnest { expression, .. } => expression.contains_aggregate(scope),
            // We never look into sub-queries.
            Self::QueryExpression { .. } => false,
        }
    }
}

/// A `BETWEEN` expression contains an aggregate if any of its three
/// operands does.
impl ContainsAggregate for ast::BetweenExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self {
            left,
            middle,
            right,
            ..
        } = self;
        left.contains_aggregate(scope)
            || middle.contains_aggregate(scope)
            || right.contains_aggregate(scope)
    }
}

/// A keyword binary operator contains an aggregate if either operand does.
impl ContainsAggregate for ast::KeywordBinopExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        if self.left.contains_aggregate(scope) {
            return true;
        }
        self.right.contains_aggregate(scope)
    }
}

/// A `NOT` expression contains an aggregate if the negated expression does.
impl ContainsAggregate for ast::NotExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self { expression, .. } = self;
        expression.contains_aggregate(scope)
    }
}

/// An `IF` expression contains an aggregate if its condition or either of
/// its branches does.
impl ContainsAggregate for ast::IfExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self {
            condition,
            then_expression,
            else_expression,
            ..
        } = self;
        condition.contains_aggregate(scope)
            || then_expression.contains_aggregate(scope)
            || else_expression.contains_aggregate(scope)
    }
}

/// A `CASE` expression contains an aggregate if the case operand, any
/// `WHEN` clause, or the `ELSE` clause does.
impl ContainsAggregate for ast::CaseExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self {
            case_expr,
            when_clauses,
            else_clause,
            ..
        } = self;
        case_expr.contains_aggregate(scope)
            || when_clauses.contains_aggregate(scope)
            || else_clause.contains_aggregate(scope)
    }
}

/// A `WHEN` clause contains an aggregate if its condition or its result
/// does.
impl ContainsAggregate for ast::CaseWhenClause {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        if self.condition.contains_aggregate(scope) {
            return true;
        }
        self.result.contains_aggregate(scope)
    }
}

/// An `ELSE` clause contains an aggregate if its result expression does.
impl ContainsAggregate for ast::CaseElseClause {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self { result, .. } = self;
        result.contains_aggregate(scope)
    }
}

/// A binary operator contains an aggregate if either operand does.
impl ContainsAggregate for ast::BinopExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self { left, right, .. } = self;
        left.contains_aggregate(scope) || right.contains_aggregate(scope)
    }
}

/// An array expression is searched through its definition.
impl ContainsAggregate for ast::ArrayExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self { definition, .. } = self;
        definition.contains_aggregate(scope)
    }
}

/// Searches the body of an array expression: element lists are searched,
/// query-based definitions are not.
impl ContainsAggregate for ast::ArrayDefinition {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        match self {
            Self::Elements(elements) => elements.contains_aggregate(scope),
            // We never look into sub-queries.
            Self::Query { .. } => false,
        }
    }
}

/// A struct expression contains an aggregate if any of its fields does.
impl ContainsAggregate for ast::StructExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self { fields, .. } = self;
        fields.contains_aggregate(scope)
    }
}

/// A special date function call is searched through its arguments only.
impl ContainsAggregate for ast::SpecialDateFunctionCall {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        // The function name does not need to be checked here; only the
        // arguments are searched.
        let Self { args, .. } = self;
        args.contains_aggregate(scope)
    }
}

/// Date parts never contain an aggregate; expressions are searched normally.
impl ContainsAggregate for ast::ExpressionOrDatePart {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        match self {
            Self::DatePart(_) => false,
            Self::Expression(expression) => expression.contains_aggregate(scope),
        }
    }
}

/// A function call contains an aggregate if it calls an aggregate function
/// without an `OVER` clause, or if any of its arguments contains one.
impl ContainsAggregate for ast::FunctionCall {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        // Look up the function being called. A function that cannot be
        // resolved in `scope` is treated as a non-aggregate.
        let calls_aggregate = matches!(
            scope.get_function_type(&self.name),
            Ok(func_ty) if func_ty.is_aggregate()
        );

        // If we have an OVER clause, we're not a normal aggregate.
        if calls_aggregate && self.over_clause.is_none() {
            true
        } else {
            // Otherwise the answer depends on our arguments.
            self.args.contains_aggregate(scope)
        }
    }
}

/// An index expression contains an aggregate if either the indexed
/// expression or the index itself does.
impl ContainsAggregate for ast::IndexExpression {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        let Self {
            expression, index, ..
        } = self;
        expression.contains_aggregate(scope) || index.contains_aggregate(scope)
    }
}

/// Every form of index offset wraps a single expression, which is what gets
/// searched.
impl ContainsAggregate for ast::IndexOffset {
    fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
        match self {
            Self::Simple(offset) => offset.contains_aggregate(scope),
            Self::Offset {
                expression: offset, ..
            } => offset.contains_aggregate(scope),
            Self::Ordinal {
                expression: offset, ..
            } => offset.contains_aggregate(scope),
        }
    }
}
70 changes: 5 additions & 65 deletions src/infer.rs → src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use std::collections::HashSet;

use derive_visitor::{Drive, Visitor};
use tracing::trace;

use crate::{
Expand All @@ -14,6 +13,10 @@ use crate::{
unification::{UnificationTable, Unify},
};

use self::contains_aggregate::ContainsAggregate;

mod contains_aggregate;

// TODO: Remember this rather scary example. Verify BigQuery supports it
// and that we need it.
//
Expand Down Expand Up @@ -319,7 +322,7 @@ impl InferTypes for ast::SelectExpression {
column_set_scope = column_set_scope
.try_transform(|column_set| column_set.group_by(&group_by_names))?;
trace!(columns = %column_set_scope.column_set(), "columns after GROUP BY");
} else if contains_aggregate(&column_set_scope, &*select_list) {
} else if select_list.contains_aggregate(&column_set_scope) {
// If we have aggregates but no GROUP BY, we need to add a synthetic
// GROUP BY of the empty set of columns.
column_set_scope =
Expand Down Expand Up @@ -1138,69 +1141,6 @@ fn nyi(spanned: &dyn Spanned, name: &str) -> Error {
)
}

/// Check an [`ast::Expression`] for aggregate functions.
///
/// This is used to detect implicit `GROUP BY` clauses, as in `SELECT SUM(x)
/// FROM t`.
fn contains_aggregate(scope: &ColumnSetScope, node: &ast::NodeVec<ast::SelectListItem>) -> bool {
let mut visitor = ContainsAggregate::new(scope);
node.drive(&mut visitor);
visitor.contains_aggregate
}

/// Visitor state used while walking an AST in search of aggregate functions.
///
/// The visitor is entered for `ARRAY_AGG` expressions, `COUNT` expressions
/// and ordinary function calls, and records whether any of them is an
/// aggregate.
///
/// TODO: We probably need to be a lot more careful about sub-queries here.
///
/// TODO: We may want to have `trait ContainsAggregate` at some point.
#[derive(Debug, Visitor)]
#[visitor(
    ast::ArrayAggExpression(enter),
    ast::CountExpression(enter),
    ast::FunctionCall(enter)
)]
struct ContainsAggregate<'scope> {
    /// Scope used to look up the type of each called function.
    scope: &'scope ColumnSetScope,
    /// Set to `true` once an aggregate has been seen.
    contains_aggregate: bool,
}

impl<'scope> ContainsAggregate<'scope> {
    /// Create a visitor that has not yet seen an aggregate.
    fn new(scope: &'scope ColumnSetScope) -> Self {
        let contains_aggregate = false;
        Self {
            scope,
            contains_aggregate,
        }
    }

    /// `ARRAY_AGG` always counts as an aggregate.
    fn enter_array_agg_expression(&mut self, _array_agg: &ast::ArrayAggExpression) {
        self.contains_aggregate = true;
    }

    /// `COUNT` always counts as an aggregate.
    fn enter_count_expression(&mut self, _count: &ast::CountExpression) {
        self.contains_aggregate = true;
    }

    /// Record an aggregate if `fcall` calls an aggregate function without an
    /// `OVER` clause.
    fn enter_function_call(&mut self, fcall: &ast::FunctionCall) {
        // Nothing more to learn once an aggregate has been found. OVER
        // clauses may contain aggregate functions, but they don't normally
        // trigger an implicit GROUP BY.
        if self.contains_aggregate || fcall.over_clause.is_some() {
            return;
        }

        // If we're not sure what this function is, assume it's not an
        // aggregate. We'll call `infer_types` later to check.
        if matches!(
            self.scope.get_function_type(&fcall.name),
            Ok(func_ty) if func_ty.is_aggregate()
        ) {
            self.contains_aggregate = true;
        }
    }
}

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
Expand Down
26 changes: 26 additions & 0 deletions tests/sql/queries/group_by/implicit_aggregates.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
-- Implicit aggregates: an aggregate function in the SELECT list with no
-- explicit GROUP BY clause.
--
-- NOTE(review): each `__resultN` table is presumably compared against the
-- matching `__expectedN` table by the test harness — confirm against the
-- runner.

-- Shared fixture: three rows whose values sum to 6.
CREATE TEMP TABLE implicit_agg (val INT64);
INSERT INTO implicit_agg VALUES (1), (2), (3);

-- Case 1: `SUM` appears directly in the SELECT list, so the query is
-- expected to produce the single aggregated row (6).
CREATE OR REPLACE TABLE __result1 AS
SELECT SUM(val) AS sum_val
FROM implicit_agg;

CREATE OR REPLACE TABLE __expected1 (
sum_val INT64,
);
INSERT INTO __expected1 VALUES
(6);

-- Now make sure we _don't_ trigger if the aggregate appears in a subquery.
-- Case 2: the only aggregate lives inside a scalar subquery, so the outer
-- query is expected to keep its row and plain column `x` as-is.
CREATE OR REPLACE TABLE __result2 AS
SELECT x, (SELECT SUM(val) FROM implicit_agg) AS sum_val
FROM (SELECT 1 AS x);

CREATE OR REPLACE TABLE __expected2 (
x INT64,
sum_val INT64,
);
INSERT INTO __expected2 VALUES
(1, 6);

0 comments on commit d0beb1d

Please sign in to comment.