# feat: Aggregate filter alias targeting (#3252)
## Relevant issue(s)

Resolves #3195

## Description

This PR enables filters to target aggregate fields by their alias.
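
For illustration, here is the shape of request this enables, modeled on the updated integration test at the bottom of this diff. The test's request body is elided there, so the query text and filter value below are a best guess rather than a verbatim copy:

```go
// Hypothetical integration-test fragment (testUtils is assumed to be
// github.com/sourcenetwork/defradb/tests/integration). The filter value
// {_gt: 0} is a guess consistent with the test matching both authors.
testUtils.Request{
	Request: `query {
		Author(filter: {publishedCount: {_gt: 0}}) {
			name
			publishedCount: _count(published: {})
		}
	}`,
	Results: map[string]any{
		"Author": []map[string]any{
			{"name": "Cornelia Funke", "publishedCount": 1},
			{"name": "John Grisham", "publishedCount": 2},
		},
	},
},
```

The condition on `publishedCount` is applied after the count has been computed, which is what the planner changes below implement.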

Blocked by #3253

## Tasks

- [x] I made sure the code is well commented, particularly
hard-to-understand areas.
- [x] I made sure the repository-held documentation is changed
accordingly.
- [x] I made sure the pull request title adheres to the conventional
commit style (the subset used in the project can be found in
[tools/configs/chglog/config.yml](tools/configs/chglog/config.yml)).
- [x] I made sure to discuss its limitations such as threats to
validity, vulnerability to mistake and misuse, robustness to
invalidation of assumptions, resource requirements, ...

## How has this been tested?

Added and updated integration tests.

Specify the platform(s) on which this was tested:
- macOS
nasdf authored Dec 5, 2024
1 parent 2776f1e commit e4599e8
Showing 9 changed files with 350 additions and 22 deletions.
6 changes: 5 additions & 1 deletion internal/planner/average.go
```diff
@@ -28,6 +28,8 @@ type averageNode struct {
 	virtualFieldIndex int
 
 	execInfo averageExecInfo
+
+	aggregateFilter *mapper.Filter
 }
 
 type averageExecInfo struct {
@@ -37,6 +39,7 @@ type averageExecInfo struct {
 
 func (p *Planner) Average(
 	field *mapper.Aggregate,
+	filter *mapper.Filter,
 ) (*averageNode, error) {
 	var sumField *mapper.Aggregate
 	var countField *mapper.Aggregate
@@ -57,6 +60,7 @@
 		countFieldIndex:   countField.Index,
 		virtualFieldIndex: field.Index,
 		docMapper:         docMapper{field.DocumentMapping},
+		aggregateFilter:   filter,
 	}, nil
 }
 
@@ -102,7 +106,7 @@ func (n *averageNode) Next() (bool, error) {
 		return false, client.NewErrUnhandledType("sum", sumProp)
 	}
 
-	return true, nil
+	return mapper.RunFilter(n.currentValue, n.aggregateFilter)
 }
 
 func (n *averageNode) SetPlan(p planNode) { n.plan = p }
```
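The same two changes repeat in count.go, max.go, min.go, and sum.go below: the node stores an optional `aggregateFilter`, and `Next()` ends with `mapper.RunFilter(n.currentValue, n.aggregateFilter)` instead of an unconditional `return true, nil`. The contract inferred from these call sites (the mapper implementation itself is not part of this diff) is that the filter runs against the document after the aggregate has been written into its virtual field, and that a nil filter always matches. A toy sketch of that shape, with all names hypothetical:

```go
package main

import "fmt"

// doc stands in for the planner's current document value; the computed
// aggregate is written into one of its fields before filtering.
type doc struct{ fields map[string]int }

// runFilter mimics the inferred mapper.RunFilter contract: a nil filter
// means "no aggregate filter", so the document is always kept.
func runFilter(d doc, filter func(doc) bool) (bool, error) {
	if filter == nil {
		return true, nil // preserves the old `return true, nil` behavior
	}
	return filter(d), nil
}

func main() {
	d := doc{fields: map[string]int{}}
	d.fields["_count"] = 2 // aggregate computed first, then filtered
	keep, _ := runFilter(d, func(d doc) bool { return d.fields["_count"] > 0 })
	fmt.Println(keep) // true
}
```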
6 changes: 4 additions & 2 deletions internal/planner/count.go
```diff
@@ -35,6 +35,7 @@ type countNode struct {
 
 	virtualFieldIndex int
 	aggregateMapping  []mapper.AggregateTarget
+	aggregateFilter   *mapper.Filter
 
 	execInfo countExecInfo
 }
@@ -44,11 +45,12 @@ type countExecInfo struct {
 	iterations uint64
 }
 
-func (p *Planner) Count(field *mapper.Aggregate, host *mapper.Select) (*countNode, error) {
+func (p *Planner) Count(field *mapper.Aggregate, host *mapper.Select, filter *mapper.Filter) (*countNode, error) {
 	return &countNode{
 		p:                 p,
 		virtualFieldIndex: field.Index,
 		aggregateMapping:  field.AggregateTargets,
+		aggregateFilter:   filter,
 		docMapper:         docMapper{field.DocumentMapping},
 	}, nil
 }
@@ -181,7 +183,7 @@ func (n *countNode) Next() (bool, error) {
 	}
 
 	n.currentValue.Fields[n.virtualFieldIndex] = count
-	return true, nil
+	return mapper.RunFilter(n.currentValue, n.aggregateFilter)
 }
 
 // countDocs counts the number of documents in a slice, skipping over hidden items
```
5 changes: 4 additions & 1 deletion internal/planner/max.go
```diff
@@ -33,6 +33,7 @@ type maxNode struct {
 	// that contains the result of the aggregate.
 	virtualFieldIndex int
 	aggregateMapping  []mapper.AggregateTarget
+	aggregateFilter   *mapper.Filter
 
 	execInfo maxExecInfo
 }
@@ -45,11 +46,13 @@ type maxExecInfo struct {
 func (p *Planner) Max(
 	field *mapper.Aggregate,
 	parent *mapper.Select,
+	filter *mapper.Filter,
 ) (*maxNode, error) {
 	return &maxNode{
 		p:                 p,
 		parent:            parent,
 		aggregateMapping:  field.AggregateTargets,
+		aggregateFilter:   filter,
 		virtualFieldIndex: field.Index,
 		docMapper:         docMapper{field.DocumentMapping},
 	}, nil
@@ -252,5 +255,5 @@ func (n *maxNode) Next() (bool, error) {
 		res, _ := max.Int64()
 		n.currentValue.Fields[n.virtualFieldIndex] = res
 	}
-	return true, nil
+	return mapper.RunFilter(n.currentValue, n.aggregateFilter)
 }
```
5 changes: 4 additions & 1 deletion internal/planner/min.go
```diff
@@ -33,6 +33,7 @@ type minNode struct {
 	// that contains the result of the aggregate.
 	virtualFieldIndex int
 	aggregateMapping  []mapper.AggregateTarget
+	aggregateFilter   *mapper.Filter
 
 	execInfo minExecInfo
 }
@@ -45,11 +46,13 @@ type minExecInfo struct {
 func (p *Planner) Min(
 	field *mapper.Aggregate,
 	parent *mapper.Select,
+	filter *mapper.Filter,
 ) (*minNode, error) {
 	return &minNode{
 		p:                 p,
 		parent:            parent,
 		aggregateMapping:  field.AggregateTargets,
+		aggregateFilter:   filter,
 		virtualFieldIndex: field.Index,
 		docMapper:         docMapper{field.DocumentMapping},
 	}, nil
@@ -252,5 +255,5 @@ func (n *minNode) Next() (bool, error) {
 		res, _ := min.Int64()
 		n.currentValue.Fields[n.virtualFieldIndex] = res
 	}
-	return true, nil
+	return mapper.RunFilter(n.currentValue, n.aggregateFilter)
 }
```
14 changes: 9 additions & 5 deletions internal/planner/select.go
```diff
@@ -19,6 +19,7 @@ import (
 	"github.com/sourcenetwork/defradb/internal/core"
 	"github.com/sourcenetwork/defradb/internal/db/base"
 	"github.com/sourcenetwork/defradb/internal/keys"
+	"github.com/sourcenetwork/defradb/internal/planner/filter"
 	"github.com/sourcenetwork/defradb/internal/planner/mapper"
 )
 
@@ -344,18 +345,21 @@ func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, error) {
 		case *mapper.Aggregate:
 			var plan aggregateNode
 			var aggregateError error
+			var aggregateFilter *mapper.Filter
 
+			// extract aggregate filters from the select
+			selectReq.Filter, aggregateFilter = filter.SplitByFields(selectReq.Filter, f.Field)
 			switch f.Name {
 			case request.CountFieldName:
-				plan, aggregateError = n.planner.Count(f, selectReq)
+				plan, aggregateError = n.planner.Count(f, selectReq, aggregateFilter)
 			case request.SumFieldName:
-				plan, aggregateError = n.planner.Sum(f, selectReq)
+				plan, aggregateError = n.planner.Sum(f, selectReq, aggregateFilter)
 			case request.AverageFieldName:
-				plan, aggregateError = n.planner.Average(f)
+				plan, aggregateError = n.planner.Average(f, aggregateFilter)
 			case request.MaxFieldName:
-				plan, aggregateError = n.planner.Max(f, selectReq)
+				plan, aggregateError = n.planner.Max(f, selectReq, aggregateFilter)
 			case request.MinFieldName:
-				plan, aggregateError = n.planner.Min(f, selectReq)
+				plan, aggregateError = n.planner.Min(f, selectReq, aggregateFilter)
 			}
 
 			if aggregateError != nil {
```
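The key enabler here is `filter.SplitByFields` from the blocking PR #3253: conditions that target the aggregate's field are peeled off the select's filter (they cannot run there, since the aliased value only exists once the aggregate node has computed it) and handed to the aggregate node, which applies them via `mapper.RunFilter` as shown above. A rough sketch of the assumed splitting behavior, using plain maps rather than the real mapper types:

```go
// Hypothetical sketch of the assumed SplitByFields behavior; the real
// implementation lives in internal/planner/filter and is not shown in
// this diff. Given {"age": ..., "publishedCount": ...} and the aggregate
// field "publishedCount", it returns the remaining select-level filter
// and the extracted aggregate-level filter.
func splitByFields(conditions map[string]any, field string) (remaining, extracted map[string]any) {
	remaining = map[string]any{}
	extracted = map[string]any{}
	for name, cond := range conditions {
		if name == field {
			extracted[name] = cond // runs after the aggregate is computed
		} else {
			remaining[name] = cond // runs in the normal select filter
		}
	}
	return remaining, extracted
}
```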
6 changes: 4 additions & 2 deletions internal/planner/sum.go
```diff
@@ -30,6 +30,7 @@ type sumNode struct {
 	isFloat           bool
 	virtualFieldIndex int
 	aggregateMapping  []mapper.AggregateTarget
+	aggregateFilter   *mapper.Filter
 
 	execInfo sumExecInfo
 }
@@ -42,6 +43,7 @@ type sumExecInfo struct {
 func (p *Planner) Sum(
 	field *mapper.Aggregate,
 	parent *mapper.Select,
+	filter *mapper.Filter,
 ) (*sumNode, error) {
 	isFloat := false
 	for _, target := range field.AggregateTargets {
@@ -60,6 +62,7 @@
 		p:                 p,
 		isFloat:           isFloat,
 		aggregateMapping:  field.AggregateTargets,
+		aggregateFilter:   filter,
 		virtualFieldIndex: field.Index,
 		docMapper:         docMapper{field.DocumentMapping},
 	}, nil
@@ -310,8 +313,7 @@ func (n *sumNode) Next() (bool, error) {
 		typedSum = int64(sum)
 	}
 	n.currentValue.Fields[n.virtualFieldIndex] = typedSum
-
-	return true, nil
+	return mapper.RunFilter(n.currentValue, n.aggregateFilter)
 }
 
 func (n *sumNode) SetPlan(p planNode) { n.plan = p }
```
10 changes: 5 additions & 5 deletions internal/planner/top.go
```diff
@@ -199,15 +199,15 @@ func (p *Planner) Top(m *mapper.Select) (*topLevelNode, error) {
 		var err error
 		switch field.GetName() {
 		case request.CountFieldName:
-			child, err = p.Count(f, m)
+			child, err = p.Count(f, m, nil)
 		case request.SumFieldName:
-			child, err = p.Sum(f, m)
+			child, err = p.Sum(f, m, nil)
 		case request.AverageFieldName:
-			child, err = p.Average(f)
+			child, err = p.Average(f, nil)
 		case request.MaxFieldName:
-			child, err = p.Max(f, m)
+			child, err = p.Max(f, m, nil)
 		case request.MinFieldName:
-			child, err = p.Min(f, m)
+			child, err = p.Min(f, m, nil)
 		}
 		if err != nil {
 			return nil, err
```
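Top-level aggregates pass `nil` for the new filter parameter: at the query root there is no enclosing select whose filter could target the aggregate by alias, so (assuming a nil filter always matches, per the sketch above) these nodes keep their previous unfiltered behavior.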
17 changes: 12 additions & 5 deletions tests/integration/query/one_to_many/with_count_test.go
```diff
@@ -119,11 +119,9 @@ func TestQueryOneToManyWithCount(t *testing.T) {
 	}
 }
 
-// This test documents the behavior of aggregate alias targeting which is not yet implemented.
-// https://github.com/sourcenetwork/defradb/issues/3195
-func TestQueryOneToMany_WithCountAliasFilter_ShouldFilterAll(t *testing.T) {
+func TestQueryOneToMany_WithCountAliasFilter_ShouldMatchAll(t *testing.T) {
 	test := testUtils.TestCase{
-		Description: "One-to-many relation query from many side with count",
+		Description: "One-to-many relation query from many side with count alias",
 		Actions: []any{
 			testUtils.CreateDoc{
 				CollectionID: 1,
@@ -173,7 +171,16 @@
 					}
 				}`,
 				Results: map[string]any{
-					"Author": []map[string]any{},
+					"Author": []map[string]any{
+						{
+							"name":           "Cornelia Funke",
+							"publishedCount": 1,
+						},
+						{
+							"name":           "John Grisham",
+							"publishedCount": 2,
+						},
+					},
 				},
 			},
 		},
```
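The updated test drops the old comment marking alias targeting as unimplemented (issue #3195) and now asserts the implemented behavior: a filter targeting the `publishedCount` alias returns the matching authors instead of an empty result.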