Skip to content

Commit

Permalink
[Aggregate] Correctly handle ordering multiple fields; don't crash on…
Browse files Browse the repository at this point in the history
… nil (#174)

* [Aggregate] Correctly handle ordering multiple fields; dont crash on nil

* review feedback

* lint
  • Loading branch information
ohaibbq authored Mar 10, 2024
1 parent 5837a05 commit 7fc7b7c
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 51 deletions.
100 changes: 49 additions & 51 deletions internal/function_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,7 @@ func (f *ARRAY_AGG) Step(v Value, opt *AggregatorOption) error {
}

func (f *ARRAY_AGG) Done() (Value, error) {
if f.opt != nil && len(f.opt.OrderBy) != 0 {
for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ {
if f.opt.OrderBy[orderBy].IsAsc {
sort.Slice(f.values, func(i, j int) bool {
v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value)
return v
})
} else {
sort.Slice(f.values, func(i, j int) bool {
v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value)
return v
})
}
}
}
f.values = sortAggregatedValues(f.values, f.opt)
if f.opt != nil && f.opt.Limit != nil {
minLen := int64(len(f.values))
if *f.opt.Limit < minLen {
Expand Down Expand Up @@ -122,42 +108,33 @@ func (f *ARRAY_CONCAT_AGG) Step(v *ArrayValue, opt *AggregatorOption) error {
return fmt.Errorf("ARRAY_CONCAT_AGG: NULL value unsupported")
}
f.once.Do(func() { f.opt = opt })
for _, vv := range v.values {
f.values = append(f.values, &OrderedValue{
OrderBy: opt.OrderBy,
Value: vv,
})
}
f.values = append(f.values, &OrderedValue{
OrderBy: opt.OrderBy,
Value: v,
})
return nil
}

func (f *ARRAY_CONCAT_AGG) Done() (Value, error) {
if f.opt != nil && len(f.opt.OrderBy) != 0 {
for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ {
if f.opt.OrderBy[orderBy].IsAsc {
sort.Slice(f.values, func(i, j int) bool {
v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value)
return v
})
} else {
sort.Slice(f.values, func(i, j int) bool {
v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value)
return v
})
}
}
}
f.values = sortAggregatedValues(f.values, f.opt)

if f.opt != nil && f.opt.Limit != nil {
minLen := int64(len(f.values))
if *f.opt.Limit < minLen {
minLen = *f.opt.Limit
}
f.values = f.values[:minLen]
}
values := make([]Value, 0, len(f.values))

var values []Value
for _, v := range f.values {
values = append(values, v.Value)
a, err := v.Value.ToArray()
if err != nil {
return nil, err
}
values = append(values, a.values...)
}

return &ArrayValue{
values: values,
}, nil
Expand Down Expand Up @@ -470,22 +447,43 @@ func (f *STRING_AGG) Step(v Value, delim string, opt *AggregatorOption) error {
return nil
}

func (f *STRING_AGG) Done() (Value, error) {
if f.opt != nil && len(f.opt.OrderBy) != 0 {
for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ {
if f.opt.OrderBy[orderBy].IsAsc {
sort.Slice(f.values, func(i, j int) bool {
v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value)
return v
})
func sortAggregatedValues(values []*OrderedValue, opt *AggregatorOption) []*OrderedValue {
if opt != nil && len(opt.OrderBy) == 0 {
return values
}

sort.Slice(values, func(i, j int) bool {
for orderBy := 0; orderBy < len(values[0].OrderBy); orderBy++ {
iV := values[i].OrderBy[orderBy].Value
jV := values[j].OrderBy[orderBy].Value
isAsc := values[0].OrderBy[orderBy].IsAsc
if iV == nil {
return isAsc
}
if jV == nil {
return !isAsc
}
isEqual, _ := iV.EQ(jV)
if isEqual {
// break tie with subsequent fields
continue
}
if isAsc {
cond, _ := iV.LT(jV)
return cond
} else {
sort.Slice(f.values, func(i, j int) bool {
v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value)
return v
})
cond, _ := iV.GT(jV)
return cond
}
}
}
return false
})
return values
}

func (f *STRING_AGG) Done() (Value, error) {
f.values = sortAggregatedValues(f.values, f.opt)

if f.opt != nil && f.opt.Limit != nil {
minLen := int64(len(f.values))
if *f.opt.Limit < minLen {
Expand Down
21 changes: 21 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,13 @@ FROM Items`,
query: `SELECT ARRAY_AGG(x) AS array_agg FROM UNNEST([NULL, 1, -2, 3, -2, 1, NULL]) AS x`,
expectedErr: "ARRAY_AGG: input value must be not null",
},
{
name: "array_agg with null in order by",
query: `WITH toks AS (SELECT '1' AS x, '1' as y UNION ALL SELECT '2', null) SELECT ARRAY_AGG(x ORDER BY y) FROM toks`,
expectedRows: [][]interface{}{{
[]interface{}{"2", "1"},
}},
},
{
name: "array_agg with struct",
query: `SELECT b, ARRAY_AGG(a) FROM UNNEST([STRUCT(1 AS a, 2 AS b), STRUCT(NULL AS a, 2 AS b)]) GROUP BY b`,
Expand Down Expand Up @@ -686,6 +693,20 @@ SELECT ARRAY_CONCAT_AGG(x) AS array_concat_agg FROM (
[]interface{}{nil, int64(1), int64(2), int64(3), int64(4), int64(5), int64(6), int64(7), int64(8), int64(9)},
}},
},
{
name: "array_concat_agg with null in order by",
query: `WITH toks AS (SELECT ['1'] AS x, '1' as y UNION ALL SELECT ['2', '3'], null) SELECT ARRAY_CONCAT_AGG(x ORDER BY y) FROM toks`,
expectedRows: [][]interface{}{{
[]interface{}{"2", "3", "1"},
}},
},
{
name: "array_concat_agg with limt",
query: `WITH toks AS (SELECT ['1'] AS x, '1' as y UNION ALL SELECT ['2', '3'], null) SELECT ARRAY_CONCAT_AGG(x ORDER BY y LIMIT 1) FROM toks`,
expectedRows: [][]interface{}{{
[]interface{}{"2", "3"},
}},
},
{
name: "array_concat_agg with format",
query: `SELECT FORMAT("%T", ARRAY_CONCAT_AGG(x)) AS array_concat_agg FROM (
Expand Down

0 comments on commit 7fc7b7c

Please sign in to comment.