From c355046b1c06b8602ab098b87449730106be1a85 Mon Sep 17 00:00:00 2001 From: Oskar Hahn Date: Sat, 28 Nov 2020 16:09:08 +0100 Subject: [PATCH] Plugin permissoin service (#125) --- cmd/autoupdate/main.go | 6 +- go.mod | 1 + go.sum | 2 + internal/autoupdate/autoupdate.go | 19 ++++- internal/autoupdate/autoupdate_test.go | 2 +- internal/autoupdate/connection.go | 78 +++++++++++-------- internal/autoupdate/connection_test.go | 101 +++++++++++++++++++++++-- internal/autoupdate/feature_test.go | 2 +- internal/autoupdate/interfaces.go | 7 +- internal/autoupdate/mock_test.go | 13 +++- internal/restrict/checker.go | 13 ++-- internal/restrict/checker_test.go | 11 +-- internal/restrict/interfaces.go | 12 +-- internal/restrict/restrict.go | 7 +- internal/restrict/restrict_test.go | 11 +-- internal/test/permission_mock.go | 15 ++-- internal/test/restricter_mock.go | 7 +- 17 files changed, 228 insertions(+), 79 deletions(-) diff --git a/cmd/autoupdate/main.go b/cmd/autoupdate/main.go index 3b2c0107..06c966cc 100644 --- a/cmd/autoupdate/main.go +++ b/cmd/autoupdate/main.go @@ -13,6 +13,7 @@ import ( "path" "syscall" + "github.com/OpenSlides/openslides-permission-service/pkg/permission" "github.com/openslides/openslides-autoupdate-service/internal/auth" "github.com/openslides/openslides-autoupdate-service/internal/autoupdate" "github.com/openslides/openslides-autoupdate-service/internal/datastore" @@ -93,14 +94,13 @@ func run() error { } // Perm Service. - perms := &test.MockPermission{} - perms.Default = true + perms := permission.New(datastoreService) // Restricter Service. restricter := restrict.New(perms, restrict.RelationChecker(restrict.RelationLists, perms)) // Autoupdate Service. - service := autoupdate.New(datastoreService, restricter, closed) + service := autoupdate.New(datastoreService, restricter, perms, closed) // Auth Service. authService, err := buildAuth(env, r, closed, errHandler) diff --git a/go.mod b/go.mod index c232be25..2c271255 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.15 require ( github.com/OpenSlides/openslides-models-to-go v0.1.1-0.20201023163752-f3a92dde2a27 + github.com/OpenSlides/openslides-permission-service v0.0.0-20201106150223-db52cf71b584 github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 github.com/gomodule/redigo v1.8.3 github.com/ostcar/topic v0.3.4-0.20200613094955-61bb28837a98 diff --git a/go.sum b/go.sum index 2b8a737a..e74a408e 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/OpenSlides/openslides-models-to-go v0.1.1-0.20201023163752-f3a92dde2a27 h1:hKpGuicCgWOrbDHU6+1vsPJgzbzX1Y1el5XaeIUZMFY= github.com/OpenSlides/openslides-models-to-go v0.1.1-0.20201023163752-f3a92dde2a27/go.mod h1:CriCefW5smTixhFfVLiuA8NgyMX4PAU5e3YpJHnaZx8= +github.com/OpenSlides/openslides-permission-service v0.0.0-20201106150223-db52cf71b584 h1:5Fv6W0+eyuD4TkL+I4WjOjM2TSkX+m24vIQq9utPGO0= +github.com/OpenSlides/openslides-permission-service v0.0.0-20201106150223-db52cf71b584/go.mod h1:VPzKimi8Jz5Qdu5oeAWYRFLThXRP6pnueH/9K8wILDE= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 h1:CaO/zOnF8VvUfEbhRatPcwKVWamvbYd8tQGRWacE9kU= diff --git a/internal/autoupdate/autoupdate.go b/internal/autoupdate/autoupdate.go index e58b1c6e..6d9cbe2b 100644 --- a/internal/autoupdate/autoupdate.go +++ b/internal/autoupdate/autoupdate.go @@ -21,6 +21,11 @@ import ( // value means, that more memory is used. const pruneTime = 10 * time.Minute +// Format of keys in the topic that shows, that a full update is necessary. It +// is in the same namespace then model names. So make sure, there is no model +// with this name. +const fullUpdateFormat = "fullupdate/%d" + // Autoupdate holds the state of the autoupdate service. It has to be initialized // with autoupdate.New(). type Autoupdate struct { @@ -30,7 +35,7 @@ type Autoupdate struct { } // New creates a new autoupdate service. -func New(datastore Datastore, restricter Restricter, closed <-chan struct{}) *Autoupdate { +func New(datastore Datastore, restricter Restricter, userUdater UserUpdater, closed <-chan struct{}) *Autoupdate { a := &Autoupdate{ datastore: datastore, restricter: restricter, @@ -43,6 +48,16 @@ func New(datastore Datastore, restricter Restricter, closed <-chan struct{}) *Au for k := range data { keys = append(keys, k) } + + uids, err := userUdater.AdditionalUpdate(context.TODO(), data) + if err != nil { + return fmt.Errorf("getting addition user ids: %w", err) + } + + for _, uid := range uids { + keys = append(keys, fmt.Sprintf(fullUpdateFormat, uid)) + } + a.topic.Publish(keys...) return nil }) @@ -125,7 +140,7 @@ func (a *Autoupdate) RestrictedData(ctx context.Context, uid int, keys ...string data[key] = values[i] } - if err := a.restricter.Restrict(uid, data); err != nil { + if err := a.restricter.Restrict(ctx, uid, data); err != nil { return nil, fmt.Errorf("restrict data: %w", err) } return data, nil diff --git a/internal/autoupdate/autoupdate_test.go b/internal/autoupdate/autoupdate_test.go index b76a3e34..b9eb4d30 100644 --- a/internal/autoupdate/autoupdate_test.go +++ b/internal/autoupdate/autoupdate_test.go @@ -12,7 +12,7 @@ func TestLive(t *testing.T) { datastore := new(test.MockDatastore) closed := make(chan struct{}) defer close(closed) - s := autoupdate.New(datastore, new(test.MockRestricter), closed) + s := autoupdate.New(datastore, new(test.MockRestricter), mockUserUpdater{}, closed) kb := test.KeysBuilder{K: []string{"foo", "bar"}} ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/autoupdate/connection.go b/internal/autoupdate/connection.go index 5fcfaaf9..6bbe1d4d 100644 --- a/internal/autoupdate/connection.go +++ b/internal/autoupdate/connection.go @@ -22,33 +22,7 @@ type Connection struct { // this case, nil is returned. func (c *Connection) Next(ctx context.Context) (map[string]json.RawMessage, error) { if c.filter == nil { - // First time called - c.filter = new(filter) - if c.tid == 0 { - c.tid = c.autoupdate.topic.LastID() - } - - if err := c.kb.Update(ctx); err != nil { - return nil, fmt.Errorf("create keys for keysbuilder: %w", err) - } - - data, err := c.autoupdate.RestrictedData(ctx, c.uid, c.kb.Keys()...) - if err != nil { - return nil, fmt.Errorf("get first time restricted data: %w", err) - } - - // Delete empty values in first responce. - for k, v := range data { - if len(v) == 0 { - delete(data, k) - } - } - - if err := c.filter.filter(data); err != nil { - return nil, fmt.Errorf("filter data for the first time: %w", err) - } - - return data, nil + return c.allData(ctx) } var err error @@ -60,6 +34,21 @@ func (c *Connection) Next(ctx context.Context) (map[string]json.RawMessage, erro return nil, fmt.Errorf("get updated keys: %w", err) } + changedSlice := make(map[string]bool, len(changedKeys)) + for _, key := range changedKeys { + var uid int + if _, err := fmt.Sscanf(key, fullUpdateFormat, &uid); err == nil { + // The key is a fullUpdate key. Do not use it, excpect of a full + // update. + if uid == c.uid { + return c.allData(ctx) + } + continue + } + + changedSlice[key] = true + } + oldKeys := c.kb.Keys() // Update keysbuilder get new list of keys @@ -70,11 +59,6 @@ func (c *Connection) Next(ctx context.Context) (map[string]json.RawMessage, erro // Start with keys hat are new for the user keys := keysDiff(oldKeys, c.kb.Keys()) - changedSlice := make(map[string]bool, len(changedKeys)) - for _, key := range changedKeys { - changedSlice[key] = true - } - // Append keys that are old but have been changed. for _, key := range oldKeys { if !changedSlice[key] { @@ -107,6 +91,36 @@ func (c *Connection) Next(ctx context.Context) (map[string]json.RawMessage, erro return data, nil } +func (c *Connection) allData(ctx context.Context) (map[string]json.RawMessage, error) { + // First time called + c.filter = new(filter) + if c.tid == 0 { + c.tid = c.autoupdate.topic.LastID() + } + + if err := c.kb.Update(ctx); err != nil { + return nil, fmt.Errorf("create keys for keysbuilder: %w", err) + } + + data, err := c.autoupdate.RestrictedData(ctx, c.uid, c.kb.Keys()...) + if err != nil { + return nil, fmt.Errorf("get first time restricted data: %w", err) + } + + // Delete empty values in first responce. + for k, v := range data { + if len(v) == 0 { + delete(data, k) + } + } + + if err := c.filter.filter(data); err != nil { + return nil, fmt.Errorf("filter data for the first time: %w", err) + } + + return data, nil +} + func keysDiff(old []string, new []string) []string { keySet := make(map[string]bool, len(old)) for _, key := range old { diff --git a/internal/autoupdate/connection_test.go b/internal/autoupdate/connection_test.go index 3d7a7054..7d4db686 100644 --- a/internal/autoupdate/connection_test.go +++ b/internal/autoupdate/connection_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "testing" + "time" "github.com/openslides/openslides-autoupdate-service/internal/autoupdate" "github.com/openslides/openslides-autoupdate-service/internal/test" @@ -86,7 +87,7 @@ func TestConnectionEmptyData(t *testing.T) { closed := make(chan struct{}) defer close(closed) - s := autoupdate.New(datastore, new(test.MockRestricter), closed) + s := autoupdate.New(datastore, new(test.MockRestricter), mockUserUpdater{}, closed) kb := test.KeysBuilder{K: test.Str(doesExistKey, doesNotExistKey)} @@ -189,7 +190,7 @@ func TestConnectionFilterData(t *testing.T) { closed := make(chan struct{}) defer close(closed) - s := autoupdate.New(datastore, new(test.MockRestricter), closed) + s := autoupdate.New(datastore, new(test.MockRestricter), mockUserUpdater{}, closed) kb := test.KeysBuilder{K: test.Str("user/1/name")} c := s.Connect(1, kb) if _, err := c.Next(context.Background()); err != nil { @@ -214,7 +215,7 @@ func TestConntectionFilterOnlyOneKey(t *testing.T) { datastore := new(test.MockDatastore) closed := make(chan struct{}) close(closed) - s := autoupdate.New(datastore, new(test.MockRestricter), closed) + s := autoupdate.New(datastore, new(test.MockRestricter), mockUserUpdater{}, closed) kb := test.KeysBuilder{K: test.Str("user/1/name")} c := s.Connect(1, kb) if _, err := c.Next(context.Background()); err != nil { @@ -239,12 +240,102 @@ func TestConntectionFilterOnlyOneKey(t *testing.T) { } } +func TestFullUpdate(t *testing.T) { + datastore := new(test.MockDatastore) + closed := make(chan struct{}) + defer close(closed) + userUpdater := new(mockUserUpdater) + s := autoupdate.New(datastore, new(test.MockRestricter), userUpdater, closed) + kb := test.KeysBuilder{K: test.Str("user/1/name")} + + t.Run("other user", func(t *testing.T) { + c := s.Connect(1, kb) + if _, err := c.Next(context.Background()); err != nil { + t.Errorf("c.Next() returned an error: %v", err) + } + + // send fulldata for other user + userUpdater.userIDs = []int{2} + datastore.Send(test.Str("some/5/data")) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var data map[string]json.RawMessage + var err error + isBlocking := blocking(func() { + data, err = c.Next(ctx) + }) + + if !isBlocking { + t.Fatalf("fulldataupdate did not block") + } + + if err != nil { + t.Errorf("Got unexpected error: %v", err) + } + + if len(data) != 0 { + t.Errorf("Got %v, expected no key update", data) + } + }) + + t.Run("same user", func(t *testing.T) { + c := s.Connect(1, kb) + if _, err := c.Next(context.Background()); err != nil { + t.Errorf("c.Next() returned an error: %v", err) + } + + // Send fulldata for same user. + userUpdater.userIDs = []int{1} + datastore.Send(test.Str("some/5/data")) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var data map[string]json.RawMessage + var err error + isBlocking := blocking(func() { + data, err = c.Next(ctx) + }) + + if isBlocking { + t.Fatalf("fulldataupdate did block") + } + + if err != nil { + t.Errorf("Got unexpected error: %v", err) + } + + if len(data) != 1 || string(data["user/1/name"]) != `"Hello World"` { + t.Errorf("Got %v, expected [user/1/name: Hello World]", data) + } + }) +} + +func blocking(f func()) bool { + done := make(chan struct{}) + go func() { + f() + close(done) + }() + + timer := time.NewTimer(time.Millisecond) + defer timer.Stop() + select { + case <-done: + return false + case <-timer.C: + return true + } +} + func BenchmarkFilterChanging(b *testing.B) { const keyCount = 100 datastore := new(test.MockDatastore) closed := make(chan struct{}) defer close(closed) - s := autoupdate.New(datastore, new(test.MockRestricter), closed) + s := autoupdate.New(datastore, new(test.MockRestricter), mockUserUpdater{}, closed) keys := make([]string, 0, keyCount) for i := 0; i < keyCount; i++ { @@ -271,7 +362,7 @@ func BenchmarkFilterNotChanging(b *testing.B) { datastore := new(test.MockDatastore) closed := make(chan struct{}) defer close(closed) - s := autoupdate.New(datastore, new(test.MockRestricter), closed) + s := autoupdate.New(datastore, new(test.MockRestricter), mockUserUpdater{}, closed) keys := make([]string, 0, keyCount) for i := 0; i < keyCount; i++ { diff --git a/internal/autoupdate/feature_test.go b/internal/autoupdate/feature_test.go index 54952bb5..f2632094 100644 --- a/internal/autoupdate/feature_test.go +++ b/internal/autoupdate/feature_test.go @@ -69,7 +69,7 @@ func TestFeatures(t *testing.T) { datastore.OnlyData = true closed := make(chan struct{}) defer close(closed) - s := autoupdate.New(datastore, new(test.MockRestricter), closed) + s := autoupdate.New(datastore, new(test.MockRestricter), mockUserUpdater{}, closed) for _, tt := range []struct { name string diff --git a/internal/autoupdate/interfaces.go b/internal/autoupdate/interfaces.go index 00eb922e..a39a221c 100644 --- a/internal/autoupdate/interfaces.go +++ b/internal/autoupdate/interfaces.go @@ -14,7 +14,7 @@ type Datastore interface { // Restricter restricts keys. type Restricter interface { // Restrict manipulates the values for the user with the given id. - Restrict(uid int, data map[string]json.RawMessage) error + Restrict(ctx context.Context, uid int, data map[string]json.RawMessage) error } // KeysBuilder holds the keys that are requested by a user. @@ -22,3 +22,8 @@ type KeysBuilder interface { Update(ctx context.Context) error Keys() []string } + +// UserUpdater has a function to get user_ids, that should get a full update. +type UserUpdater interface { + AdditionalUpdate(ctx context.Context, updated map[string]json.RawMessage) ([]int, error) +} diff --git a/internal/autoupdate/mock_test.go b/internal/autoupdate/mock_test.go index d2e33f3b..aee3b9e0 100644 --- a/internal/autoupdate/mock_test.go +++ b/internal/autoupdate/mock_test.go @@ -1,15 +1,26 @@ package autoupdate_test import ( + "context" + "encoding/json" + "github.com/openslides/openslides-autoupdate-service/internal/autoupdate" "github.com/openslides/openslides-autoupdate-service/internal/test" ) func getConnection(closed <-chan struct{}) (*autoupdate.Connection, *test.MockDatastore) { datastore := new(test.MockDatastore) - s := autoupdate.New(datastore, new(test.MockRestricter), closed) + s := autoupdate.New(datastore, new(test.MockRestricter), mockUserUpdater{}, closed) kb := test.KeysBuilder{K: test.Str("user/1/name")} c := s.Connect(1, kb) return c, datastore } + +type mockUserUpdater struct { + userIDs []int +} + +func (u mockUserUpdater) AdditionalUpdate(ctx context.Context, updated map[string]json.RawMessage) ([]int, error) { + return u.userIDs, nil +} diff --git a/internal/restrict/checker.go b/internal/restrict/checker.go index 15945b27..764142cb 100644 --- a/internal/restrict/checker.go +++ b/internal/restrict/checker.go @@ -2,6 +2,7 @@ package restrict //go:generate sh -c "go run gendef/main.go > def.go && go fmt def.go" import ( + "context" "encoding/json" "fmt" "strings" @@ -39,7 +40,7 @@ type relationList struct { model string } -func (r *relationList) Check(uid int, key string, value json.RawMessage) (json.RawMessage, error) { +func (r *relationList) Check(ctx context.Context, uid int, key string, value json.RawMessage) (json.RawMessage, error) { var ids []int if err := json.Unmarshal(value, &ids); err != nil { return nil, fmt.Errorf("decoding %s=%s: %w", key, value, err) @@ -52,7 +53,7 @@ func (r *relationList) Check(uid int, key string, value json.RawMessage) (json.R keyToID[keys[i]] = id } - allowed, err := r.permer.CheckFQIDs(uid, keys) + allowed, err := r.permer.RestrictFQIDs(ctx, uid, keys) if err != nil { return nil, fmt.Errorf("check fqids: %w", err) } @@ -75,7 +76,7 @@ type genericRelationList struct { permer Permissioner } -func (g *genericRelationList) Check(uid int, key string, value json.RawMessage) (json.RawMessage, error) { +func (g *genericRelationList) Check(ctx context.Context, uid int, key string, value json.RawMessage) (json.RawMessage, error) { var fqids []string if err := json.Unmarshal(value, &fqids); err != nil { return nil, fmt.Errorf("decoding %s=%s: %w", key, value, err) @@ -86,7 +87,7 @@ func (g *genericRelationList) Check(uid int, key string, value json.RawMessage) keys[i] = fqid } - allowed, err := g.permer.CheckFQIDs(uid, keys) + allowed, err := g.permer.RestrictFQIDs(ctx, uid, keys) if err != nil { return nil, fmt.Errorf("check fqids: %w", err) } @@ -109,7 +110,7 @@ type templateField struct { permer Permissioner } -func (s *templateField) Check(uid int, key string, value json.RawMessage) (json.RawMessage, error) { +func (s *templateField) Check(ctx context.Context, uid int, key string, value json.RawMessage) (json.RawMessage, error) { var replacments []string if err := json.Unmarshal(value, &replacments); err != nil { return nil, fmt.Errorf("decoding key %s=%s: %w", key, value, err) @@ -122,7 +123,7 @@ func (s *templateField) Check(uid int, key string, value json.RawMessage) (json. keyToReplacement[keys[i]] = r } - allowed, err := s.permer.CheckFQFields(uid, keys) + allowed, err := s.permer.RestrictFQFields(ctx, uid, keys) if err != nil { return nil, fmt.Errorf("check generated structured fields: %w", err) } diff --git a/internal/restrict/checker_test.go b/internal/restrict/checker_test.go index 96b6800f..36b04b61 100644 --- a/internal/restrict/checker_test.go +++ b/internal/restrict/checker_test.go @@ -1,6 +1,7 @@ package restrict_test import ( + "context" "testing" "github.com/openslides/openslides-autoupdate-service/internal/restrict" @@ -45,7 +46,7 @@ func TestRelationChecker(t *testing.T) { "otherModel/2": false, } - v, err := checker["model/relation_ids"].Check(1, "model/1/relation_ids", []byte("[1,2]")) + v, err := checker["model/relation_ids"].Check(context.Background(), 1, "model/1/relation_ids", []byte("[1,2]")) if err != nil { t.Fatalf("Check returned an error: %v", err) @@ -62,7 +63,7 @@ func TestRelationChecker(t *testing.T) { "other_foo/2": false, } - v, err := checker["model/generic_relation_ids"].Check(1, "model/1/generic_relation_ids", []byte(`["foo/1","other_foo/2"]`)) + v, err := checker["model/generic_relation_ids"].Check(context.Background(), 1, "model/1/generic_relation_ids", []byte(`["foo/1","other_foo/2"]`)) if err != nil { t.Errorf("Check returned an error: %v", err) @@ -79,7 +80,7 @@ func TestRelationChecker(t *testing.T) { "model/1/template_$2_ids": false, } - v, err := checker["model/template_$_ids"].Check(1, "model/1/template_$_ids", []byte(`["1","2"]`)) + v, err := checker["model/template_$_ids"].Check(context.Background(), 1, "model/1/template_$_ids", []byte(`["1","2"]`)) if err != nil { t.Errorf("Check returned an error: %v", err) @@ -96,7 +97,7 @@ func TestRelationChecker(t *testing.T) { "otherModel/2": false, } - v, err := checker["model/template_"].Check(1, "model/1/template_$1_ids", []byte(`[1,2]`)) + v, err := checker["model/template_"].Check(context.Background(), 1, "model/1/template_$1_ids", []byte(`[1,2]`)) if err != nil { t.Errorf("Check returned an error: %v", err) @@ -113,7 +114,7 @@ func TestRelationChecker(t *testing.T) { "other_foo/2": false, } - v, err := checker["model/generic_template_"].Check(1, "model/1/generic_template_$1_ids", []byte(`["foo/1","other_foo/2"]`)) + v, err := checker["model/generic_template_"].Check(context.Background(), 1, "model/1/generic_template_$1_ids", []byte(`["foo/1","other_foo/2"]`)) if err != nil { t.Errorf("Check returned an error: %v", err) diff --git a/internal/restrict/interfaces.go b/internal/restrict/interfaces.go index 7b63b3bf..f706c3ce 100644 --- a/internal/restrict/interfaces.go +++ b/internal/restrict/interfaces.go @@ -7,8 +7,8 @@ import ( // Permissioner tells the restricter, if a user has the required permissions. type Permissioner interface { - CheckFQIDs(uid int, fqids []string) (map[string]bool, error) - CheckFQFields(uid int, fqfields []string) (map[string]bool, error) + RestrictFQIDs(ctx context.Context, uid int, fqids []string) (map[string]bool, error) + RestrictFQFields(ctx context.Context, uid int, fqfields []string) (map[string]bool, error) } // Datastore informs the restricter about changed data. @@ -21,13 +21,13 @@ type Datastore interface { // gets replaced with the returned value. Check has to return nil, if the user // is not allowed to see the key. type Checker interface { - Check(uid int, key string, value json.RawMessage) (json.RawMessage, error) + Check(ctx context.Context, uid int, key string, value json.RawMessage) (json.RawMessage, error) } // CheckerFunc is a function that implements the Checker interface. -type CheckerFunc func(uid int, key string, value json.RawMessage) (json.RawMessage, error) +type CheckerFunc func(ctx context.Context, uid int, key string, value json.RawMessage) (json.RawMessage, error) // Check calls the function. -func (f CheckerFunc) Check(uid int, key string, value json.RawMessage) (json.RawMessage, error) { - return f(uid, key, value) +func (f CheckerFunc) Check(ctx context.Context, uid int, key string, value json.RawMessage) (json.RawMessage, error) { + return f(ctx, uid, key, value) } diff --git a/internal/restrict/restrict.go b/internal/restrict/restrict.go index 5e41e21b..0084284f 100644 --- a/internal/restrict/restrict.go +++ b/internal/restrict/restrict.go @@ -3,6 +3,7 @@ package restrict import ( + "context" "encoding/json" "fmt" "strings" @@ -31,12 +32,12 @@ func New(permer Permissioner, checker map[string]Checker) *Restricter { // replaced with a new value. If the user does not have the permission to see // one key, it is not allowed to remove that key, the value has to be set to // nil. -func (r *Restricter) Restrict(uid int, data map[string]json.RawMessage) error { +func (r *Restricter) Restrict(ctx context.Context, uid int, data map[string]json.RawMessage) error { keys := make([]string, 0, len(data)) for k := range data { keys = append(keys, k) } - allowed, err := r.permer.CheckFQFields(uid, keys) + allowed, err := r.permer.RestrictFQFields(ctx, uid, keys) if err != nil { return fmt.Errorf("check permissions: %w", err) } @@ -56,7 +57,7 @@ func (r *Restricter) Restrict(uid int, data map[string]json.RawMessage) error { continue } - nv, err := checker.Check(uid, k, v) + nv, err := checker.Check(ctx, uid, k, v) if err != nil { return fmt.Errorf("checker for key %s: %w", k, err) } diff --git a/internal/restrict/restrict_test.go b/internal/restrict/restrict_test.go index 333e9819..3e9a6830 100644 --- a/internal/restrict/restrict_test.go +++ b/internal/restrict/restrict_test.go @@ -1,6 +1,7 @@ package restrict_test import ( + "context" "encoding/json" "testing" @@ -19,7 +20,7 @@ func TestRestrict(t *testing.T) { "user/1/name": []byte("uwe"), "user/1/password": []byte("easy"), } - if err := r.Restrict(1, data); err != nil { + if err := r.Restrict(context.Background(), 1, data); err != nil { t.Errorf("Restrict returned unexpected error: %v", err) } @@ -42,15 +43,15 @@ func TestChecker(t *testing.T) { called := make(map[string]bool) checker := map[string]restrict.Checker{ - "user/name": restrict.CheckerFunc(func(uid int, key string, value json.RawMessage) (json.RawMessage, error) { + "user/name": restrict.CheckerFunc(func(ctx context.Context, uid int, key string, value json.RawMessage) (json.RawMessage, error) { called[key] = true return []byte("touched"), nil }), - "user/password": restrict.CheckerFunc(func(uid int, key string, value json.RawMessage) (json.RawMessage, error) { + "user/password": restrict.CheckerFunc(func(ctx context.Context, uid int, key string, value json.RawMessage) (json.RawMessage, error) { called[key] = true return []byte("touched"), nil }), - "user/first_name": restrict.CheckerFunc(func(uid int, key string, value json.RawMessage) (json.RawMessage, error) { + "user/first_name": restrict.CheckerFunc(func(ctx context.Context, uid int, key string, value json.RawMessage) (json.RawMessage, error) { called[key] = true return []byte("touched"), nil }), @@ -62,7 +63,7 @@ func TestChecker(t *testing.T) { "user/1/password": []byte("easy"), "user/1/first_name": nil, } - if err := r.Restrict(1, data); err != nil { + if err := r.Restrict(context.Background(), 1, data); err != nil { t.Errorf("Restrict returned unexpected error: %v", err) } diff --git a/internal/test/permission_mock.go b/internal/test/permission_mock.go index bfddaa34..fe74dbb2 100644 --- a/internal/test/permission_mock.go +++ b/internal/test/permission_mock.go @@ -1,6 +1,9 @@ package test -import "sync" +import ( + "context" + "sync" +) //MockPermission mocks the permission api. type MockPermission struct { @@ -10,8 +13,8 @@ type MockPermission struct { Default bool } -// CheckFQIDs returns the fields where p.Data is true. -func (p *MockPermission) CheckFQIDs(uid int, fqids []string) (map[string]bool, error) { +// RestrictFQIDs returns the fields where p.Data is true. +func (p *MockPermission) RestrictFQIDs(ctx context.Context, uid int, fqids []string) (map[string]bool, error) { p.mu.Lock() defer p.mu.Unlock() @@ -35,7 +38,7 @@ func (p *MockPermission) CheckFQIDs(uid int, fqids []string) (map[string]bool, e return out, nil } -// CheckFQFields calls CheckFQIDs. -func (p *MockPermission) CheckFQFields(uid int, fqfields []string) (map[string]bool, error) { - return p.CheckFQIDs(uid, fqfields) +// RestrictFQFields calls RestrictFQIDs. +func (p *MockPermission) RestrictFQFields(ctx context.Context, uid int, fqfields []string) (map[string]bool, error) { + return p.RestrictFQIDs(ctx, uid, fqfields) } diff --git a/internal/test/restricter_mock.go b/internal/test/restricter_mock.go index a2a69de4..f295615b 100644 --- a/internal/test/restricter_mock.go +++ b/internal/test/restricter_mock.go @@ -1,11 +1,14 @@ package test -import "encoding/json" +import ( + "context" + "encoding/json" +) // MockRestricter implements the restricter interface. type MockRestricter struct{} // Restrict does currently nothing. -func (r *MockRestricter) Restrict(uid int, data map[string]json.RawMessage) error { +func (r *MockRestricter) Restrict(ctx context.Context, uid int, data map[string]json.RawMessage) error { return nil }