From 3a4bb962aab62009ab977a001e12fa8ebc396984 Mon Sep 17 00:00:00 2001 From: Ronen Hilewicz Date: Fri, 6 Dec 2024 15:18:57 -0500 Subject: [PATCH] Use sync.Pool for Relation slices. (#60) --- cache/cache.go | 12 ++++++++---- cache/check.go | 12 +++++++++--- graph/check.go | 37 +++++++++++++++++++++++++---------- graph/check_test.go | 6 +++++- graph/objects.go | 20 +++++++++++-------- graph/objects_test.go | 5 ++++- graph/search.go | 30 +++++++++++++++++----------- graph/subjects.go | 39 +++++++++++++++++++++++++------------ graph/subjects_test.go | 5 ++++- graph/utils_test.go | 11 +++++++---- internal/mempool/mempool.go | 34 ++++++++++++++++++++++++++++++++ 11 files changed, 156 insertions(+), 55 deletions(-) create mode 100644 internal/mempool/mempool.go diff --git a/cache/cache.go b/cache/cache.go index 596751b..0bc8cd3 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -3,6 +3,8 @@ package cache import ( "sync" + "github.com/aserto-dev/azm/graph" + "github.com/aserto-dev/azm/internal/mempool" "github.com/aserto-dev/azm/model" "github.com/aserto-dev/azm/model/diff" stts "github.com/aserto-dev/azm/stats" @@ -17,15 +19,17 @@ type ( ) type Cache struct { - model *model.Model - mtx sync.RWMutex + model *model.Model + mtx sync.RWMutex + relsPool *graph.RelationsPool } // New, create new model cache instance. func New(m *model.Model) *Cache { return &Cache{ - model: m, - mtx: sync.RWMutex{}, + model: m, + mtx: sync.RWMutex{}, + relsPool: mempool.NewSlicePool[*dsc.Relation](), } } diff --git a/cache/check.go b/cache/check.go index cb800fe..2743fd6 100644 --- a/cache/check.go +++ b/cache/check.go @@ -9,7 +9,10 @@ import ( ) func (c *Cache) Check(req *dsr.CheckRequest, relReader graph.RelationReader) (*dsr.CheckResponse, error) { - checker := graph.NewCheck(c.model, req, relReader) + c.mtx.RLock() + defer c.mtx.RUnlock() + + checker := graph.NewCheck(c.model, req, relReader, c.relsPool) ctx := pb.NewStruct() @@ -26,15 +29,18 @@ type graphSearch interface { } func (c *Cache) GetGraph(req *dsr.GetGraphRequest, relReader graph.RelationReader) (*dsr.GetGraphResponse, error) { + c.mtx.RLock() + defer c.mtx.RUnlock() + var ( search graphSearch err error ) if req.ObjectId == "" { - search, err = graph.NewObjectSearch(c.model, req, relReader) + search, err = graph.NewObjectSearch(c.model, req, relReader, c.relsPool) } else { - search, err = graph.NewSubjectSearch(c.model, req, relReader) + search, err = graph.NewSubjectSearch(c.model, req, relReader, c.relsPool) } if err != nil { diff --git a/graph/check.go b/graph/check.go index 7ca3f85..ee3e0cf 100644 --- a/graph/check.go +++ b/graph/check.go @@ -15,9 +15,10 @@ type Checker struct { getRels RelationReader memo *checkMemo + pool *RelationsPool } -func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader) *Checker { +func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader, pool *RelationsPool) *Checker { return &Checker{ m: m, params: &relation{ @@ -29,6 +30,7 @@ func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader) *Che }, getRels: reader, memo: newCheckMemo(req.Trace), + pool: pool, } } @@ -88,7 +90,16 @@ func (c *Checker) checkRelation(params *relation) (checkStatus, error) { r := c.m.Objects[params.ot].Relations[params.rel] steps := c.m.StepRelation(r, params.st) + // Reuse the same slice in all steps. + relsPtr := c.pool.Get() + defer func() { + *relsPtr = (*relsPtr)[:0] + c.pool.Put(relsPtr) + }() + for _, step := range steps { + *relsPtr = (*relsPtr)[:0] + req := &dsc.Relation{ ObjectType: params.ot.String(), ObjectId: params.oid.String(), @@ -103,27 +114,26 @@ func (c *Checker) checkRelation(params *relation) (checkStatus, error) { req.SubjectRelation = step.Relation.String() } - rels, err := c.getRels(req) - if err != nil { + if err := c.getRels(req, relsPtr); err != nil { return checkStatusFalse, err } switch { case step.IsDirect(): - for _, rel := range rels { + for _, rel := range *relsPtr { if rel.SubjectId == params.sid.String() { return checkStatusTrue, nil } } case step.IsWildcard(): - if len(rels) > 0 { + if len(*relsPtr) > 0 { // We have a wildcard match. return checkStatusTrue, nil } case step.IsSubject(): - for _, rel := range rels { + for _, rel := range *relsPtr { if status, err := c.check(&relation{ ot: step.Object, oid: ObjectID(rel.SubjectId), @@ -190,17 +200,21 @@ func (c *Checker) checkPermission(params *relation) (checkStatus, error) { func (c *Checker) expandTerm(pt *model.PermissionTerm, params *relation) (relations, error) { if pt.IsArrow() { - // Resolve the base of the arrow. - rels, err := c.getRels(&dsc.Relation{ + query := &dsc.Relation{ ObjectType: params.ot.String(), ObjectId: params.oid.String(), Relation: pt.Base.String(), - }) + } + + relsPtr := c.pool.Get() + + // Resolve the base of the arrow. + err := c.getRels(query, relsPtr) if err != nil { return relations{}, err } - expanded := lo.Map(rels, func(rel *dsc.Relation, _ int) *relation { + expanded := lo.Map(*relsPtr, func(rel *dsc.Relation, _ int) *relation { return &relation{ ot: model.ObjectName(rel.SubjectType), oid: ObjectID(rel.SubjectId), @@ -210,6 +224,9 @@ func (c *Checker) expandTerm(pt *model.PermissionTerm, params *relation) (relati } }) + *relsPtr = (*relsPtr)[:0] + c.pool.Put(relsPtr) + return expanded, nil } diff --git a/graph/check_test.go b/graph/check_test.go index abe0f4d..6283033 100644 --- a/graph/check_test.go +++ b/graph/check_test.go @@ -5,7 +5,9 @@ import ( "testing" azmgraph "github.com/aserto-dev/azm/graph" + "github.com/aserto-dev/azm/internal/mempool" v3 "github.com/aserto-dev/azm/v3" + dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3" "github.com/stretchr/testify/assert" ) @@ -67,11 +69,13 @@ func TestCheck(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, m) + pool := mempool.NewSlicePool[*dsc.Relation]() + for _, test := range tests { t.Run(test.check, func(tt *testing.T) { assert := assert.New(tt) - checker := azmgraph.NewCheck(m, checkReq(test.check), rels.GetRelations) + checker := azmgraph.NewCheck(m, checkReq(test.check), rels.GetRelations, pool) res, err := checker.Check() assert.NoError(err) diff --git a/graph/objects.go b/graph/objects.go index 8da33e1..52ba0ab 100644 --- a/graph/objects.go +++ b/graph/objects.go @@ -16,7 +16,7 @@ type ObjectSearch struct { wildcardSearch *SubjectSearch } -func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader) (*ObjectSearch, error) { +func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader, pool *RelationsPool) (*ObjectSearch, error) { params := searchParams(req) if err := validate(m, params); err != nil { return nil, err @@ -40,6 +40,7 @@ func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationRe getRels: invertedRelationReader(im, reader), memo: newSearchMemo(req.Trace), explain: req.Explain, + pool: pool, }}, wildcardSearch: &SubjectSearch{graphSearch{ m: im, @@ -47,6 +48,7 @@ func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationRe getRels: invertedRelationReader(im, reader), memo: newSearchMemo(req.Trace), explain: req.Explain, + pool: pool, }}, }, nil } @@ -125,22 +127,24 @@ func wildcardParams(params *relation) *relation { } func invertedRelationReader(m *model.Model, reader RelationReader) RelationReader { - return func(r *dsc.Relation) ([]*dsc.Relation, error) { + return func(r *dsc.Relation, out *Relations) error { ir := uninvertRelation(m, relationFromProto(r)) - res, err := reader(ir.asProto()) - if err != nil { - return nil, err + if err := reader(ir.asProto(), out); err != nil { + return err } - return lo.Map(res, func(r *dsc.Relation, _ int) *dsc.Relation { - return &dsc.Relation{ + res := *out + for i, r := range res { + res[i] = &dsc.Relation{ ObjectType: r.SubjectType, ObjectId: r.SubjectId, Relation: r.Relation, SubjectType: r.ObjectType, SubjectId: r.ObjectId, } - }), nil + } + + return nil } } diff --git a/graph/objects_test.go b/graph/objects_test.go index 3d9bf0e..ddf0d30 100644 --- a/graph/objects_test.go +++ b/graph/objects_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/aserto-dev/azm/graph" + "github.com/aserto-dev/azm/internal/mempool" "github.com/aserto-dev/azm/model" v3 "github.com/aserto-dev/azm/v3" dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3" @@ -35,11 +36,13 @@ func TestSearchObjects(t *testing.T) { im.Validate(model.SkipNameValidation, model.AllowPermissionInArrowBase), ) + pool := mempool.NewSlicePool[*dsc.Relation]() + for _, test := range searchObjectsTests { t.Run(test.search, func(tt *testing.T) { assert := assert.New(tt) - objSearch, err := graph.NewObjectSearch(m, graphReq(test.search), rels.GetRelations) + objSearch, err := graph.NewObjectSearch(m, graphReq(test.search), rels.GetRelations, pool) assert.NoError(err) res, err := objSearch.Search() diff --git a/graph/search.go b/graph/search.go index 49cb3a8..5d39637 100644 --- a/graph/search.go +++ b/graph/search.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/aserto-dev/azm/internal/mempool" "github.com/aserto-dev/azm/model" dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3" dsr "github.com/aserto-dev/go-directory/aserto/directory/reader/v3" @@ -12,21 +13,27 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -type ObjectID = model.ObjectID +type ( + ObjectID = model.ObjectID -// RelationReader retrieves relations that match the given filter. -type RelationReader func(*dsc.Relation) ([]*dsc.Relation, error) + Relations = []*dsc.Relation -type searchPath relations + // RelationReader retrieves relations that match the given filter. + RelationReader func(*dsc.Relation, *Relations) error -type object struct { - Type model.ObjectName - ID ObjectID -} + RelationsPool = mempool.Pool[*Relations] + + searchPath relations -// The results of a search is a map where the key is a matching relations -// and the value is a list of paths that connect the search object and subject. -type searchResults map[object][]searchPath + object struct { + Type model.ObjectName + ID ObjectID + } + + // The results of a search is a map where the key is a matching relations + // and the value is a list of paths that connect the search object and subject. + searchResults map[object][]searchPath +) // Objects returns the objects from the search results. func (r searchResults) Objects() []*dsc.ObjectIdentifier { @@ -92,6 +99,7 @@ type graphSearch struct { memo *searchMemo explain bool + pool *RelationsPool } func validate(m *model.Model, params *relation) error { diff --git a/graph/subjects.go b/graph/subjects.go index 71d02cc..22403be 100644 --- a/graph/subjects.go +++ b/graph/subjects.go @@ -12,7 +12,7 @@ type SubjectSearch struct { graphSearch } -func NewSubjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader) (*SubjectSearch, error) { +func NewSubjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader, pool *RelationsPool) (*SubjectSearch, error) { params := searchParams(req) if err := validate(m, params); err != nil { return nil, err @@ -24,6 +24,7 @@ func NewSubjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationR getRels: reader, memo: newSearchMemo(req.Trace), explain: req.Explain, + pool: pool, }}, nil } @@ -118,12 +119,12 @@ func (s *SubjectSearch) findNeighbor(step *model.RelationRef, params *relation) results := searchResults{} - rels, err := s.getRels(req) - if err != nil { + relsPtr := s.pool.Get() + if err := s.getRels(req, relsPtr); err != nil { return results, err } - for _, rel := range rels { + for _, rel := range *relsPtr { if rel.SubjectId != "*" && params.oid != ObjectID(rel.ObjectId) { continue } @@ -146,6 +147,9 @@ func (s *SubjectSearch) findNeighbor(step *model.RelationRef, params *relation) results[*subj] = path } + *relsPtr = (*relsPtr)[:0] + s.pool.Put(relsPtr) + return results, nil } @@ -159,12 +163,17 @@ func (s *SubjectSearch) searchSubjectRelation(step *model.RelationRef, params *r SubjectType: step.Object.String(), SubjectRelation: step.Relation.String(), } - rels, err := s.getRels(req) - if err != nil { + + relsPtr := s.pool.Get() + if err := s.getRels(req, relsPtr); err != nil { return results, err } + defer func() { + *relsPtr = (*relsPtr)[:0] + s.pool.Put(relsPtr) + }() - for _, rel := range rels { + for _, rel := range *relsPtr { current := relationFromProto(rel) if params.srel == model.RelationName(rel.SubjectRelation) && params.st == model.ObjectName(rel.SubjectType) { @@ -281,17 +290,20 @@ func (s *SubjectSearch) expandTerm(o *model.Object, pt *model.PermissionTerm, pa } func (s *SubjectSearch) expandRelationArrow(pt *model.PermissionTerm, params *relation) ([]*relation, error) { - // Resolve the base of the arrow. - rels, err := s.getRels(&dsc.Relation{ + relsPtr := s.pool.Get() + + req := &dsc.Relation{ ObjectType: params.ot.String(), ObjectId: params.oid.String(), Relation: pt.Base.String(), - }) - if err != nil { + } + + // Resolve the base of the arrow. + if err := s.getRels(req, relsPtr); err != nil { return []*relation{}, err } - expanded := lo.Map(rels, func(rel *dsc.Relation, _ int) *relation { + expanded := lo.Map(*relsPtr, func(rel *dsc.Relation, _ int) *relation { return &relation{ ot: model.ObjectName(rel.SubjectType), oid: ObjectID(rel.SubjectId), @@ -302,6 +314,9 @@ func (s *SubjectSearch) expandRelationArrow(pt *model.PermissionTerm, params *re } }) + *relsPtr = (*relsPtr)[:0] + s.pool.Put(relsPtr) + return expanded, nil } diff --git a/graph/subjects_test.go b/graph/subjects_test.go index 8e38164..873bfff 100644 --- a/graph/subjects_test.go +++ b/graph/subjects_test.go @@ -4,6 +4,7 @@ import ( "testing" azmgraph "github.com/aserto-dev/azm/graph" + "github.com/aserto-dev/azm/internal/mempool" "github.com/aserto-dev/azm/model" v3 "github.com/aserto-dev/azm/v3" dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3" @@ -18,11 +19,13 @@ func TestSearchSubjects(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, m) + pool := mempool.NewSlicePool[*dsc.Relation]() + for _, test := range searchSubjectsTests { t.Run(test.search, func(tt *testing.T) { assert := assert.New(tt) - subjSearch, err := azmgraph.NewSubjectSearch(m, graphReq(test.search), rels.GetRelations) + subjSearch, err := azmgraph.NewSubjectSearch(m, graphReq(test.search), rels.GetRelations, pool) assert.NoError(err) res, err := subjSearch.Search() diff --git a/graph/utils_test.go b/graph/utils_test.go index 0b68bd7..419e16f 100644 --- a/graph/utils_test.go +++ b/graph/utils_test.go @@ -4,6 +4,7 @@ import ( "regexp" "testing" + "github.com/aserto-dev/azm/graph" "github.com/aserto-dev/azm/model" dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3" dsr "github.com/aserto-dev/go-directory/aserto/directory/reader/v3" @@ -96,7 +97,7 @@ func NewRelationsReader(rels ...string) RelationsReader { }) } -func (r RelationsReader) GetRelations(req *dsc.Relation) ([]*dsc.Relation, error) { +func (r RelationsReader) GetRelations(req *dsc.Relation, out *graph.Relations) error { ot := model.ObjectName(req.ObjectType) oid := model.ObjectID(req.ObjectId) rn := model.RelationName(req.Relation) @@ -113,9 +114,11 @@ func (r RelationsReader) GetRelations(req *dsc.Relation) ([]*dsc.Relation, error (sr == "" || rel.SubjectRelation == sr) }) - return lo.Map(matches, func(r *relation, _ int) *dsc.Relation { - return r.proto() - }), nil + for _, rel := range matches { + *out = append(*out, rel.proto()) + } + + return nil } type parseTest struct { diff --git a/internal/mempool/mempool.go b/internal/mempool/mempool.go new file mode 100644 index 0000000..afeebb7 --- /dev/null +++ b/internal/mempool/mempool.go @@ -0,0 +1,34 @@ +package mempool + +import "sync" + +const defaultSliceCapacity = 128 + +type Pool[T any] struct { + sync.Pool +} + +func (p *Pool[T]) Get() T { + return p.Pool.Get().(T) +} + +func (p *Pool[T]) Put(x T) { + p.Pool.Put(x) +} + +func NewPool[T any](newF func() T) *Pool[T] { + return &Pool[T]{ + Pool: sync.Pool{ + New: func() interface{} { + return newF() + }, + }, + } +} + +func NewSlicePool[T any]() *Pool[*[]T] { + return NewPool(func() *[]T { + s := make([]T, 0, defaultSliceCapacity) + return &s + }) +}