From 7fc7b7c7689cda1d3b4c58865deb9a93be523067 Mon Sep 17 00:00:00 2001 From: Dan Hansen Date: Sun, 10 Mar 2024 04:38:02 +0000 Subject: [PATCH] [Aggregate] Correctly handle ordering multiple fields; don't crash on nil (#174) * [Aggregate] Correctly handle ordering multiple fields; dont crash on nil * review feedback * lint --- internal/function_aggregate.go | 100 ++++++++++++++++----------------- query_test.go | 21 +++++++ 2 files changed, 70 insertions(+), 51 deletions(-) diff --git a/internal/function_aggregate.go b/internal/function_aggregate.go index 7bd8747..e1d26ad 100644 --- a/internal/function_aggregate.go +++ b/internal/function_aggregate.go @@ -80,21 +80,7 @@ func (f *ARRAY_AGG) Step(v Value, opt *AggregatorOption) error { } func (f *ARRAY_AGG) Done() (Value, error) { - if f.opt != nil && len(f.opt.OrderBy) != 0 { - for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { - if f.opt.OrderBy[orderBy].IsAsc { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } else { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } - } - } + f.values = sortAggregatedValues(f.values, f.opt) if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) if *f.opt.Limit < minLen { @@ -122,31 +108,16 @@ func (f *ARRAY_CONCAT_AGG) Step(v *ArrayValue, opt *AggregatorOption) error { return fmt.Errorf("ARRAY_CONCAT_AGG: NULL value unsupported") } f.once.Do(func() { f.opt = opt }) - for _, vv := range v.values { - f.values = append(f.values, &OrderedValue{ - OrderBy: opt.OrderBy, - Value: vv, - }) - } + f.values = append(f.values, &OrderedValue{ + OrderBy: opt.OrderBy, + Value: v, + }) return nil } func (f *ARRAY_CONCAT_AGG) Done() (Value, error) { - if f.opt != nil && len(f.opt.OrderBy) != 0 { - for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { - if f.opt.OrderBy[orderBy].IsAsc { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } else { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } - } - } + f.values = sortAggregatedValues(f.values, f.opt) + if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) if *f.opt.Limit < minLen { @@ -154,10 +125,16 @@ func (f *ARRAY_CONCAT_AGG) Done() (Value, error) { } f.values = f.values[:minLen] } - values := make([]Value, 0, len(f.values)) + + var values []Value for _, v := range f.values { - values = append(values, v.Value) + a, err := v.Value.ToArray() + if err != nil { + return nil, err + } + values = append(values, a.values...) } + return &ArrayValue{ values: values, }, nil @@ -470,22 +447,43 @@ func (f *STRING_AGG) Step(v Value, delim string, opt *AggregatorOption) error { return nil } -func (f *STRING_AGG) Done() (Value, error) { - if f.opt != nil && len(f.opt.OrderBy) != 0 { - for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { - if f.opt.OrderBy[orderBy].IsAsc { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value) - return v - }) +func sortAggregatedValues(values []*OrderedValue, opt *AggregatorOption) []*OrderedValue { + if opt != nil && len(opt.OrderBy) == 0 { + return values + } + + sort.Slice(values, func(i, j int) bool { + for orderBy := 0; orderBy < len(values[0].OrderBy); orderBy++ { + iV := values[i].OrderBy[orderBy].Value + jV := values[j].OrderBy[orderBy].Value + isAsc := values[0].OrderBy[orderBy].IsAsc + if iV == nil { + return isAsc + } + if jV == nil { + return !isAsc + } + isEqual, _ := iV.EQ(jV) + if isEqual { + // break tie with subsequent fields + continue + } + if isAsc { + cond, _ := iV.LT(jV) + return cond } else { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value) - return v - }) + cond, _ := iV.GT(jV) + return cond } } - } + return false + }) + return values +} + +func (f *STRING_AGG) Done() (Value, error) { + f.values = sortAggregatedValues(f.values, f.opt) + if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) if *f.opt.Limit < minLen { diff --git a/query_test.go b/query_test.go index ebbc938..85046e9 100644 --- a/query_test.go +++ b/query_test.go @@ -637,6 +637,13 @@ FROM Items`, query: `SELECT ARRAY_AGG(x) AS array_agg FROM UNNEST([NULL, 1, -2, 3, -2, 1, NULL]) AS x`, expectedErr: "ARRAY_AGG: input value must be not null", }, + { + name: "array_agg with null in order by", + query: `WITH toks AS (SELECT '1' AS x, '1' as y UNION ALL SELECT '2', null) SELECT ARRAY_AGG(x ORDER BY y) FROM toks`, + expectedRows: [][]interface{}{{ + []interface{}{"2", "1"}, + }}, + }, { name: "array_agg with struct", query: `SELECT b, ARRAY_AGG(a) FROM UNNEST([STRUCT(1 AS a, 2 AS b), STRUCT(NULL AS a, 2 AS b)]) GROUP BY b`, @@ -686,6 +693,20 @@ SELECT ARRAY_CONCAT_AGG(x) AS array_concat_agg FROM ( []interface{}{nil, int64(1), int64(2), int64(3), int64(4), int64(5), int64(6), int64(7), int64(8), int64(9)}, }}, }, + { + name: "array_concat_agg with null in order by", + query: `WITH toks AS (SELECT ['1'] AS x, '1' as y UNION ALL SELECT ['2', '3'], null) SELECT ARRAY_CONCAT_AGG(x ORDER BY y) FROM toks`, + expectedRows: [][]interface{}{{ + []interface{}{"2", "3", "1"}, + }}, + }, + { + name: "array_concat_agg with limt", + query: `WITH toks AS (SELECT ['1'] AS x, '1' as y UNION ALL SELECT ['2', '3'], null) SELECT ARRAY_CONCAT_AGG(x ORDER BY y LIMIT 1) FROM toks`, + expectedRows: [][]interface{}{{ + []interface{}{"2", "3"}, + }}, + }, { name: "array_concat_agg with format", query: `SELECT FORMAT("%T", ARRAY_CONCAT_AGG(x)) AS array_concat_agg FROM (