Skip to content

Commit

Permalink
Add some ARRAY_AGG features
Browse files Browse the repository at this point in the history
One promising approach to implementing ARRAY[SELECT ...] is use
ARRAY_AGG.
  • Loading branch information
emk committed Oct 22, 2023
1 parent bc74c91 commit d46740a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 2 deletions.
56 changes: 56 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ pub enum Expression {
Struct(StructExpression),
Count(CountExpression),
CurrentDate(CurrentDate),
ArrayAgg(ArrayAggExpression),
SpecialDateFunctionCall(SpecialDateFunctionCall),
FunctionCall(FunctionCall),
Index(IndexExpression),
Expand Down Expand Up @@ -1028,6 +1029,44 @@ pub enum ExpressionOrDatePart {
DatePart(DatePart),
}

/// An `ARRAY_AGG` expression.
#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, ToTokens)]
pub struct ArrayAggExpression {
pub array_agg_token: CaseInsensitiveIdent,
pub paren1: Punct,
pub distinct: Option<Distinct>,
pub expression: Box<Expression>,
pub order_by: Option<OrderBy>,
pub paren2: Punct,
}

impl Emit for ArrayAggExpression {
fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> {
match self {
// Snowflake formats ORDER BY as `ARRAY_AGG(expression) WITHIN GROUP
// (ORDER BY ...)`.
ArrayAggExpression {
array_agg_token,
paren1,
distinct,
expression,
order_by: Some(order_by),
paren2,
} if t == Target::Snowflake => {
array_agg_token.emit(t, f)?;
paren1.emit(t, f)?;
distinct.emit(t, f)?;
expression.emit(t, f)?;
paren2.emit(t, f)?;
f.write_token_start("WITHIN GROUP(")?;
order_by.emit(t, f)?;
f.write_token_start(")")
}
_ => self.emit_default(t, f),
}
}
}

/// A function call.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, ToTokens)]
pub struct FunctionCall {
Expand Down Expand Up @@ -1954,6 +1993,7 @@ peg::parser! {
null_token:k("NULL") { Expression::Null { null_token } }
interval_expression:interval_expression() { Expression::Interval(interval_expression) }
cast:cast() { Expression::Cast(cast) }
array_agg:array_agg() { Expression::ArrayAgg(array_agg) }
special_date_function_call:special_date_function_call() { Expression::SpecialDateFunctionCall(special_date_function_call) }
// Things from here down might start with arbitrary identifiers, so
// we need to be careful about the order.
Expand Down Expand Up @@ -2121,6 +2161,22 @@ peg::parser! {
EmptyParens { paren1, paren2 }
}

rule array_agg() -> ArrayAggExpression
= array_agg_token:ti("ARRAY_AGG") paren1:p("(") distinct:distinct()?
expression:expression()
order_by:order_by()?
paren2:p(")")
{
ArrayAggExpression {
array_agg_token,
paren1,
distinct,
expression: Box::new(expression),
order_by,
paren2,
}
}

rule special_date_function_call() -> SpecialDateFunctionCall
= function_name:special_date_function_name() paren1:p("(")
args:sep(<expression_or_date_part()>, ",") paren2:p(")") {
Expand Down
14 changes: 14 additions & 0 deletions tests/sql/functions/aggregate/array_agg.sql
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,17 @@ CREATE OR REPLACE TABLE __expected1 (
-- Snowflake does not allow array constants in VALUES.
INSERT INTO __expected1
SELECT 'a', [1, 1, 2, 0];

-- Now test DISTINCT.
CREATE OR REPLACE TABLE __result2 AS
SELECT grp, ARRAY_AGG(DISTINCT x ORDER BY x) AS arr
FROM array_agg_data
GROUP BY grp;

CREATE OR REPLACE TABLE __expected2 (
grp STRING,
arr ARRAY<INT64>,
);
-- Snowflake does not allow array constants in VALUES.
INSERT INTO __expected2
SELECT 'a', [0, 1, 2];
4 changes: 2 additions & 2 deletions tests/sql/operators/array_select.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
-- ARRAY(SELECT ...)

CREATE TEMP TABLE array_items (idx INT64, val INT64);
INSERT INTO array_items VALUES (1, 1), (2, 3);
INSERT INTO array_items VALUES (1, 1), (2, 3), (3, 3);

CREATE OR REPLACE TABLE __result1 AS
SELECT ARRAY(SELECT val FROM array_items ORDER BY idx) AS arr;
SELECT ARRAY(SELECT DISTINCT val FROM array_items ORDER BY idx) AS arr;

CREATE OR REPLACE TABLE __expected1 (
arr ARRAY<INT64>
Expand Down

0 comments on commit d46740a

Please sign in to comment.