From d46740ac078ba74292ea8d13e7c65b192672fb84 Mon Sep 17 00:00:00 2001 From: Eric Kidd Date: Sun, 22 Oct 2023 13:01:11 -0400 Subject: [PATCH] Add some ARRAY_AGG features One promising approach to implementing ARRAY[SELECT ...] is use ARRAY_AGG. --- src/ast.rs | 56 +++++++++++++++++++++ tests/sql/functions/aggregate/array_agg.sql | 14 ++++++ tests/sql/operators/array_select.sql | 4 +- 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 4af7c54..c77cb07 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -794,6 +794,7 @@ pub enum Expression { Struct(StructExpression), Count(CountExpression), CurrentDate(CurrentDate), + ArrayAgg(ArrayAggExpression), SpecialDateFunctionCall(SpecialDateFunctionCall), FunctionCall(FunctionCall), Index(IndexExpression), @@ -1028,6 +1029,44 @@ pub enum ExpressionOrDatePart { DatePart(DatePart), } +/// An `ARRAY_AGG` expression. +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)] +pub struct ArrayAggExpression { + pub array_agg_token: CaseInsensitiveIdent, + pub paren1: Punct, + pub distinct: Option, + pub expression: Box, + pub order_by: Option, + pub paren2: Punct, +} + +impl Emit for ArrayAggExpression { + fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> { + match self { + // Snowflake formats ORDER BY as `ARRAY_AGG(expression) WITHIN GROUP + // (ORDER BY ...)`. + ArrayAggExpression { + array_agg_token, + paren1, + distinct, + expression, + order_by: Some(order_by), + paren2, + } if t == Target::Snowflake => { + array_agg_token.emit(t, f)?; + paren1.emit(t, f)?; + distinct.emit(t, f)?; + expression.emit(t, f)?; + paren2.emit(t, f)?; + f.write_token_start("WITHIN GROUP(")?; + order_by.emit(t, f)?; + f.write_token_start(")") + } + _ => self.emit_default(t, f), + } + } +} + /// A function call. #[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)] pub struct FunctionCall { @@ -1954,6 +1993,7 @@ peg::parser! { null_token:k("NULL") { Expression::Null { null_token } } interval_expression:interval_expression() { Expression::Interval(interval_expression) } cast:cast() { Expression::Cast(cast) } + array_agg:array_agg() { Expression::ArrayAgg(array_agg) } special_date_function_call:special_date_function_call() { Expression::SpecialDateFunctionCall(special_date_function_call) } // Things from here down might start with arbitrary identifiers, so // we need to be careful about the order. @@ -2121,6 +2161,22 @@ peg::parser! { EmptyParens { paren1, paren2 } } + rule array_agg() -> ArrayAggExpression + = array_agg_token:ti("ARRAY_AGG") paren1:p("(") distinct:distinct()? + expression:expression() + order_by:order_by()? + paren2:p(")") + { + ArrayAggExpression { + array_agg_token, + paren1, + distinct, + expression: Box::new(expression), + order_by, + paren2, + } + } + rule special_date_function_call() -> SpecialDateFunctionCall = function_name:special_date_function_name() paren1:p("(") args:sep(, ",") paren2:p(")") { diff --git a/tests/sql/functions/aggregate/array_agg.sql b/tests/sql/functions/aggregate/array_agg.sql index 0cac432..521b516 100644 --- a/tests/sql/functions/aggregate/array_agg.sql +++ b/tests/sql/functions/aggregate/array_agg.sql @@ -20,3 +20,17 @@ CREATE OR REPLACE TABLE __expected1 ( -- Snowflake does not allow array constants in VALUES. INSERT INTO __expected1 SELECT 'a', [1, 1, 2, 0]; + +-- Now test DISTINCT. +CREATE OR REPLACE TABLE __result2 AS +SELECT grp, ARRAY_AGG(DISTINCT x ORDER BY x) AS arr +FROM array_agg_data +GROUP BY grp; + +CREATE OR REPLACE TABLE __expected2 ( + grp STRING, + arr ARRAY, +); +-- Snowflake does not allow array constants in VALUES. +INSERT INTO __expected2 +SELECT 'a', [0, 1, 2]; \ No newline at end of file diff --git a/tests/sql/operators/array_select.sql b/tests/sql/operators/array_select.sql index 4fd6889..09183b1 100644 --- a/tests/sql/operators/array_select.sql +++ b/tests/sql/operators/array_select.sql @@ -5,10 +5,10 @@ -- ARRAY(SELECT ...) CREATE TEMP TABLE array_items (idx INT64, val INT64); -INSERT INTO array_items VALUES (1, 1), (2, 3); +INSERT INTO array_items VALUES (1, 1), (2, 3), (3, 3); CREATE OR REPLACE TABLE __result1 AS -SELECT ARRAY(SELECT val FROM array_items ORDER BY idx) AS arr; +SELECT ARRAY(SELECT DISTINCT val FROM array_items ORDER BY idx) AS arr; CREATE OR REPLACE TABLE __expected1 ( arr ARRAY