diff --git a/pkg/enforcement/bulk_check.go b/pkg/enforcement/bulk_check.go index a5174e2..e8e5545 100644 --- a/pkg/enforcement/bulk_check.go +++ b/pkg/enforcement/bulk_check.go @@ -131,3 +131,32 @@ func (e *PermitEnforcer) BulkCheck(requests ...CheckRequest) ([]bool, error) { } return allowResults, nil } + +func (e *PermitEnforcer) FilterObjects(user User, action Action, context map[string]string, resources ...ResourceI) ([]ResourceI, error) { + requests := make([]CheckRequest, len(resources)) + for i, resource := range resources { + permitResource := ResourceBuilder(resource.GetType()). + WithID(resource.GetID()). + WithContext(resource.GetContext()). + WithAttributes(resource.GetAttributes()). + WithTenant(resource.GetTenant()). + Build() + requests[i] = *NewCheckRequest(user, + action, + permitResource, + context, + ) + } + results, err := e.BulkCheck(requests...) + if err != nil { + return nil, err + } + filteredResources := make([]ResourceI, 0) + for i, result := range results { + if result { + filteredResources = append(filteredResources, resources[i]) + } + } + + return filteredResources, nil +} diff --git a/pkg/enforcement/resource.go b/pkg/enforcement/resource.go index 1435fb1..bd69dd5 100644 --- a/pkg/enforcement/resource.go +++ b/pkg/enforcement/resource.go @@ -1,5 +1,13 @@ package enforcement +type ResourceI interface { + GetID() string + GetType() string + GetTenant() string + GetAttributes() map[string]string + GetContext() map[string]string +} + type Resource struct { Type string `json:"type,omitempty"` ID string `json:"id,omitempty"` @@ -8,6 +16,26 @@ type Resource struct { Context map[string]string `json:"context,omitempty"` } +func (r *Resource) GetID() string { + return r.ID +} + +func (r *Resource) GetType() string { + return r.Type +} + +func (r *Resource) GetTenant() string { + return r.Tenant +} + +func (r *Resource) GetAttributes() map[string]string { + return r.Attributes +} + +func (r *Resource) GetContext() map[string]string { + return r.Context +} + func ResourceBuilder(resourceType string) *Resource { return &Resource{ Type: resourceType, diff --git a/pkg/permit/permit.go b/pkg/permit/permit.go index 4a39e23..487638a 100644 --- a/pkg/permit/permit.go +++ b/pkg/permit/permit.go @@ -40,6 +40,10 @@ func (c *Client) BulkCheck(requests ...enforcement.CheckRequest) ([]bool, error) return c.enforcement.BulkCheck(requests...) } +func (c *Client) FilterObjects(user enforcement.User, action enforcement.Action, context map[string]string, resources ...enforcement.ResourceI) ([]enforcement.ResourceI, error) { + return c.enforcement.FilterObjects(user, action, context, resources...) +} + func (c *Client) AllTenantsCheck(user enforcement.User, action enforcement.Action, resource enforcement.Resource) ([]enforcement.TenantDetails, error) { return c.enforcement.AllTenantsCheck(user, action, resource) } @@ -47,6 +51,7 @@ func (c *Client) AllTenantsCheck(user enforcement.User, action enforcement.Actio type PermitInterface interface { Check(user enforcement.User, action enforcement.Action, resource enforcement.Resource) (bool, error) BulkCheck(requests ...enforcement.CheckRequest) ([]bool, error) + FilterObjects(user enforcement.User, action enforcement.Action, context map[string]string, resources ...enforcement.ResourceI) ([]enforcement.ResourceI, error) AllTenantsCheck(request enforcement.CheckRequest) ([]enforcement.TenantDetails, error) SyncUser(ctx context.Context, user models.UserCreate) (*models.UserRead, error) } diff --git a/pkg/tests/integration_test.go b/pkg/tests/integration_test.go index b6c9888..7de69d9 100644 --- a/pkg/tests/integration_test.go +++ b/pkg/tests/integration_test.go @@ -12,6 +12,7 @@ import ( "go.uber.org/zap" "math/rand" "os" + "reflect" "testing" "time" ) @@ -20,6 +21,39 @@ func init() { rand.Seed(time.Now().UnixNano()) } +type MyResource struct { + UniqueID string + Type string + Organization string +} + +func (m MyResource) GetID() string { + return m.UniqueID +} + +func (m MyResource) GetType() string { + if m.Type != "" { + return m.Type + } + if t := reflect.TypeOf(m); t.Kind() == reflect.Ptr { + return t.Elem().Name() + } else { + return t.Name() + } +} + +func (m MyResource) GetTenant() string { + return m.Organization +} + +func (m MyResource) GetAttributes() map[string]string { + return make(map[string]string) +} + +func (m MyResource) GetContext() map[string]string { + return make(map[string]string) +} + var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") func randKey(prefix string) string { @@ -76,7 +110,8 @@ func checkBulk(ctx context.Context, t *testing.T, permitClient *permit.Client, r Resource: enforcement.ResourceBuilder(resourceKey).WithTenant(tenantKey).Build(), Context: nil, } - results, _ := permitClient.BulkCheck(requests...) + results, err := permitClient.BulkCheck(requests...) + assert.NoError(t, err) assert.Len(t, results, len(bulkAssignments)+1) for i := 0; i <= len(bulkAssignments); i++ { if i == len(bulkAssignments) { @@ -247,6 +282,22 @@ func TestIntegration(t *testing.T) { assert.NoError(t, err) assert.True(t, allowed) + myResources := []enforcement.ResourceI{ + MyResource{ + UniqueID: "my-random-id", + Organization: tenantKey, + Type: resourceKey, + }, + MyResource{ + UniqueID: "my-random-id-2", + Organization: tenantKey, + }, + } + filteredResources, err := permitClient.FilterObjects(userCheck, "read", nil, myResources...) + assert.NoError(t, err) + assert.Len(t, filteredResources, 1) + assert.True(t, assert.ObjectsAreEqual(&filteredResources[0], &myResources[0])) + allowedTenants, err := permitClient.AllTenantsCheck( userCheck, "read",