diff --git a/internal/api/v1beta1connect/errors.go b/internal/api/v1beta1connect/errors.go index 03fe21581..74d7799fd 100644 --- a/internal/api/v1beta1connect/errors.go +++ b/internal/api/v1beta1connect/errors.go @@ -46,4 +46,8 @@ var ( ErrBillingProviderNotSupported = errors.New("provider not supported") ErrInsufficientCredits = errors.New("insufficient credits") ErrAlreadyApplied = errors.New("credits already applied") + ErrInvalidRoleID = errors.New("role id is invalid") + ErrNamespaceSplitNotation = errors.New("subject/object should be provided as 'namespace:uuid'") + ErrPolicyNotFound = errors.New("policy doesn't exist") + ErrProjectNotFound = errors.New("project doesn't exist") ) diff --git a/internal/api/v1beta1connect/mocks/policy_service.go b/internal/api/v1beta1connect/mocks/policy_service.go new file mode 100644 index 000000000..81352b1ce --- /dev/null +++ b/internal/api/v1beta1connect/mocks/policy_service.go @@ -0,0 +1,321 @@ +// Code generated by mockery v2.45.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + policy "github.com/raystack/frontier/core/policy" + mock "github.com/stretchr/testify/mock" + + role "github.com/raystack/frontier/core/role" +) + +// PolicyService is an autogenerated mock type for the PolicyService type +type PolicyService struct { + mock.Mock +} + +type PolicyService_Expecter struct { + mock *mock.Mock +} + +func (_m *PolicyService) EXPECT() *PolicyService_Expecter { + return &PolicyService_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, pol +func (_m *PolicyService) Create(ctx context.Context, pol policy.Policy) (policy.Policy, error) { + ret := _m.Called(ctx, pol) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 policy.Policy + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, policy.Policy) (policy.Policy, error)); ok { + return rf(ctx, pol) + } + if rf, ok := ret.Get(0).(func(context.Context, policy.Policy) policy.Policy); ok { + r0 = rf(ctx, pol) + } else { + r0 = ret.Get(0).(policy.Policy) + } + + if rf, ok := ret.Get(1).(func(context.Context, policy.Policy) error); ok { + r1 = rf(ctx, pol) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PolicyService_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type PolicyService_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - pol policy.Policy +func (_e *PolicyService_Expecter) Create(ctx interface{}, pol interface{}) *PolicyService_Create_Call { + return &PolicyService_Create_Call{Call: _e.mock.On("Create", ctx, pol)} +} + +func (_c *PolicyService_Create_Call) Run(run func(ctx context.Context, pol policy.Policy)) *PolicyService_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(policy.Policy)) + }) + return _c +} + +func (_c *PolicyService_Create_Call) Return(_a0 policy.Policy, _a1 error) *PolicyService_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PolicyService_Create_Call) RunAndReturn(run func(context.Context, policy.Policy) (policy.Policy, error)) *PolicyService_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, id +func (_m *PolicyService) Delete(ctx context.Context, id string) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PolicyService_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type PolicyService_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *PolicyService_Expecter) Delete(ctx interface{}, id interface{}) *PolicyService_Delete_Call { + return &PolicyService_Delete_Call{Call: _e.mock.On("Delete", ctx, id)} +} + +func (_c *PolicyService_Delete_Call) Run(run func(ctx context.Context, id string)) *PolicyService_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *PolicyService_Delete_Call) Return(_a0 error) *PolicyService_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *PolicyService_Delete_Call) RunAndReturn(run func(context.Context, string) error) *PolicyService_Delete_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function with given fields: ctx, id +func (_m *PolicyService) Get(ctx context.Context, id string) (policy.Policy, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 policy.Policy + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (policy.Policy, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) policy.Policy); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(policy.Policy) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PolicyService_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type PolicyService_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *PolicyService_Expecter) Get(ctx interface{}, id interface{}) *PolicyService_Get_Call { + return &PolicyService_Get_Call{Call: _e.mock.On("Get", ctx, id)} +} + +func (_c *PolicyService_Get_Call) Run(run func(ctx context.Context, id string)) *PolicyService_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *PolicyService_Get_Call) Return(_a0 policy.Policy, _a1 error) *PolicyService_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PolicyService_Get_Call) RunAndReturn(run func(context.Context, string) (policy.Policy, error)) *PolicyService_Get_Call { + _c.Call.Return(run) + return _c +} + +// List provides a mock function with given fields: ctx, f +func (_m *PolicyService) List(ctx context.Context, f policy.Filter) ([]policy.Policy, error) { + ret := _m.Called(ctx, f) + + if len(ret) == 0 { + panic("no return value specified for List") + } + + var r0 []policy.Policy + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, policy.Filter) ([]policy.Policy, error)); ok { + return rf(ctx, f) + } + if rf, ok := ret.Get(0).(func(context.Context, policy.Filter) []policy.Policy); ok { + r0 = rf(ctx, f) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]policy.Policy) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, policy.Filter) error); ok { + r1 = rf(ctx, f) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PolicyService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type PolicyService_List_Call struct { + *mock.Call +} + +// List is a helper method to define mock.On call +// - ctx context.Context +// - f policy.Filter +func (_e *PolicyService_Expecter) List(ctx interface{}, f interface{}) *PolicyService_List_Call { + return &PolicyService_List_Call{Call: _e.mock.On("List", ctx, f)} +} + +func (_c *PolicyService_List_Call) Run(run func(ctx context.Context, f policy.Filter)) *PolicyService_List_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(policy.Filter)) + }) + return _c +} + +func (_c *PolicyService_List_Call) Return(_a0 []policy.Policy, _a1 error) *PolicyService_List_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PolicyService_List_Call) RunAndReturn(run func(context.Context, policy.Filter) ([]policy.Policy, error)) *PolicyService_List_Call { + _c.Call.Return(run) + return _c +} + +// ListRoles provides a mock function with given fields: ctx, principalType, principalID, objectNamespace, objectID +func (_m *PolicyService) ListRoles(ctx context.Context, principalType string, principalID string, objectNamespace string, objectID string) ([]role.Role, error) { + ret := _m.Called(ctx, principalType, principalID, objectNamespace, objectID) + + if len(ret) == 0 { + panic("no return value specified for ListRoles") + } + + var r0 []role.Role + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) ([]role.Role, error)); ok { + return rf(ctx, principalType, principalID, objectNamespace, objectID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) []role.Role); ok { + r0 = rf(ctx, principalType, principalID, objectNamespace, objectID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]role.Role) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, string) error); ok { + r1 = rf(ctx, principalType, principalID, objectNamespace, objectID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PolicyService_ListRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListRoles' +type PolicyService_ListRoles_Call struct { + *mock.Call +} + +// ListRoles is a helper method to define mock.On call +// - ctx context.Context +// - principalType string +// - principalID string +// - objectNamespace string +// - objectID string +func (_e *PolicyService_Expecter) ListRoles(ctx interface{}, principalType interface{}, principalID interface{}, objectNamespace interface{}, objectID interface{}) *PolicyService_ListRoles_Call { + return &PolicyService_ListRoles_Call{Call: _e.mock.On("ListRoles", ctx, principalType, principalID, objectNamespace, objectID)} +} + +func (_c *PolicyService_ListRoles_Call) Run(run func(ctx context.Context, principalType string, principalID string, objectNamespace string, objectID string)) *PolicyService_ListRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *PolicyService_ListRoles_Call) Return(_a0 []role.Role, _a1 error) *PolicyService_ListRoles_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PolicyService_ListRoles_Call) RunAndReturn(run func(context.Context, string, string, string, string) ([]role.Role, error)) *PolicyService_ListRoles_Call { + _c.Call.Return(run) + return _c +} + +// NewPolicyService creates a new instance of PolicyService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPolicyService(t interface { + mock.TestingT + Cleanup(func()) +}) *PolicyService { + mock := &PolicyService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/api/v1beta1connect/policy.go b/internal/api/v1beta1connect/policy.go new file mode 100644 index 000000000..f7d2e31b1 --- /dev/null +++ b/internal/api/v1beta1connect/policy.go @@ -0,0 +1,260 @@ +package v1beta1connect + +import ( + "context" + "errors" + + "connectrpc.com/connect" + "github.com/raystack/frontier/core/audit" + "github.com/raystack/frontier/core/namespace" + "github.com/raystack/frontier/core/policy" + "github.com/raystack/frontier/core/role" + "github.com/raystack/frontier/internal/bootstrap/schema" + "github.com/raystack/frontier/pkg/metadata" + "github.com/raystack/frontier/pkg/utils" + frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type PolicyService interface { + Get(ctx context.Context, id string) (policy.Policy, error) + List(ctx context.Context, f policy.Filter) ([]policy.Policy, error) + Create(ctx context.Context, pol policy.Policy) (policy.Policy, error) + Delete(ctx context.Context, id string) error +} + +func (h *ConnectHandler) CreatePolicy(ctx context.Context, request *connect.Request[frontierv1beta1.CreatePolicyRequest]) (*connect.Response[frontierv1beta1.CreatePolicyResponse], error) { + var metaDataMap metadata.Metadata + if request.Msg.GetBody().GetMetadata() != nil { + metaDataMap = metadata.Build(request.Msg.GetBody().GetMetadata().AsMap()) + } + + resourceType, resourceID, err := schema.SplitNamespaceAndResourceID(request.Msg.GetBody().GetResource()) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, ErrNamespaceSplitNotation) + } + principalType, principalID, err := schema.SplitNamespaceAndResourceID(request.Msg.GetBody().GetPrincipal()) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, ErrNamespaceSplitNotation) + } + + newPolicy, err := h.policyService.Create(ctx, policy.Policy{ + RoleID: request.Msg.GetBody().GetRoleId(), + ResourceID: resourceID, + ResourceType: resourceType, + PrincipalID: principalID, + PrincipalType: principalType, + Metadata: metaDataMap, + }) + if err != nil { + switch { + case errors.Is(err, role.ErrInvalidID): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrInvalidRoleID) + case errors.Is(err, policy.ErrInvalidDetail): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + default: + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + } + + policyPB, err := transformPolicyToPB(newPolicy) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + + auditPolicyCreationEvent(ctx, newPolicy) + return connect.NewResponse(&frontierv1beta1.CreatePolicyResponse{Policy: policyPB}), nil +} + +func (h *ConnectHandler) GetPolicy(ctx context.Context, request *connect.Request[frontierv1beta1.GetPolicyRequest]) (*connect.Response[frontierv1beta1.GetPolicyResponse], error) { + fetchedPolicy, err := h.policyService.Get(ctx, request.Msg.GetId()) + if err != nil { + switch { + case errors.Is(err, policy.ErrNotExist), + errors.Is(err, policy.ErrInvalidUUID), + errors.Is(err, policy.ErrInvalidID): + return nil, connect.NewError(connect.CodeNotFound, ErrPolicyNotFound) + default: + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + } + + policyPB, err := transformPolicyToPB(fetchedPolicy) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + + return connect.NewResponse(&frontierv1beta1.GetPolicyResponse{Policy: policyPB}), nil +} + +func (h *ConnectHandler) DeletePolicy(ctx context.Context, request *connect.Request[frontierv1beta1.DeletePolicyRequest]) (*connect.Response[frontierv1beta1.DeletePolicyResponse], error) { + err := h.policyService.Delete(ctx, request.Msg.GetId()) + if err != nil { + switch { + case errors.Is(err, policy.ErrNotExist), + errors.Is(err, policy.ErrInvalidID), + errors.Is(err, policy.ErrInvalidUUID): + return nil, connect.NewError(connect.CodeNotFound, ErrPolicyNotFound) + case errors.Is(err, policy.ErrInvalidDetail), + errors.Is(err, namespace.ErrNotExist): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + case errors.Is(err, policy.ErrConflict): + return nil, connect.NewError(connect.CodeAlreadyExists, ErrConflictRequest) + default: + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + } + + audit.GetAuditor(ctx, schema.PlatformOrgID.String()).Log(audit.PolicyDeletedEvent, audit.Target{ + ID: request.Msg.GetId(), + Type: "app/policy", + }) + return connect.NewResponse(&frontierv1beta1.DeletePolicyResponse{}), nil +} + +func (h *ConnectHandler) CreatePolicyForProject(ctx context.Context, request *connect.Request[frontierv1beta1.CreatePolicyForProjectRequest]) (*connect.Response[frontierv1beta1.CreatePolicyForProjectResponse], error) { + if request.Msg.GetBody() == nil || request.Msg.GetBody().GetRoleId() == "" || request.Msg.GetBody().GetPrincipal() == "" { + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + } + + principalType, principalID, err := schema.SplitNamespaceAndResourceID(request.Msg.GetBody().GetPrincipal()) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, ErrNamespaceSplitNotation) + } + + project, err := h.projectService.Get(ctx, request.Msg.GetProjectId()) + if err != nil { + return nil, connect.NewError(connect.CodeNotFound, ErrProjectNotFound) + } + + p := policy.Policy{ + RoleID: request.Msg.GetBody().GetRoleId(), + PrincipalType: principalType, + PrincipalID: principalID, + ResourceID: project.ID, + ResourceType: schema.ProjectNamespace, + } + + newPolicy, err := h.policyService.Create(ctx, p) + if err != nil { + switch { + case errors.Is(err, role.ErrInvalidID): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrInvalidRoleID) + case errors.Is(err, policy.ErrInvalidDetail): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + default: + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + } + + auditPolicyCreationEvent(ctx, newPolicy) + return connect.NewResponse(&frontierv1beta1.CreatePolicyForProjectResponse{}), nil +} + +func (h *ConnectHandler) ListPolicies(ctx context.Context, request *connect.Request[frontierv1beta1.ListPoliciesRequest]) (*connect.Response[frontierv1beta1.ListPoliciesResponse], error) { + var policies []*frontierv1beta1.Policy + + filter, err := h.resolveFilter(ctx, request.Msg) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + } + + policyList, err := h.policyService.List(ctx, filter) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + + for _, p := range policyList { + policyPB, err := transformPolicyToPB(p) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + policies = append(policies, policyPB) + } + + return connect.NewResponse(&frontierv1beta1.ListPoliciesResponse{Policies: policies}), nil +} + +func transformPolicyToPB(policy policy.Policy) (*frontierv1beta1.Policy, error) { + var metadata *structpb.Struct + var err error + if len(policy.Metadata) > 0 { + metadata, err = structpb.NewStruct(policy.Metadata) + if err != nil { + return nil, err + } + } + + pbPol := &frontierv1beta1.Policy{ + Id: policy.ID, + RoleId: policy.RoleID, + Resource: schema.JoinNamespaceAndResourceID(policy.ResourceType, policy.ResourceID), + Principal: schema.JoinNamespaceAndResourceID(policy.PrincipalType, policy.PrincipalID), + Metadata: metadata, + } + if !policy.CreatedAt.IsZero() { + pbPol.CreatedAt = timestamppb.New(policy.CreatedAt) + } + if !policy.UpdatedAt.IsZero() { + pbPol.UpdatedAt = timestamppb.New(policy.UpdatedAt) + } + return pbPol, nil +} + +func auditPolicyCreationEvent(ctx context.Context, policyCreated policy.Policy) { + audit.GetAuditor(ctx, schema.PlatformOrgID.String()). + LogWithAttrs(audit.PolicyCreatedEvent, audit.Target{ + ID: policyCreated.ResourceID, + Type: policyCreated.ResourceType, + }, map[string]string{ + "role_id": policyCreated.RoleID, + "principal_id": policyCreated.PrincipalID, + "principal_type": policyCreated.PrincipalType, + }) +} + +// resolveFilter resolves the filter from the request and returns a policy filter +// if the filter fields are not valid UUIDs, it will try to resolve them to their names and then return the filter. Note the group id is not resolved to name as it is not unique +func (h *ConnectHandler) resolveFilter(ctx context.Context, request *frontierv1beta1.ListPoliciesRequest) (policy.Filter, error) { + var filter policy.Filter + orgID := request.GetOrgId() + if orgID != "" && !utils.IsValidUUID(orgID) { + org, err := h.orgService.Get(ctx, orgID) + if err != nil { + return filter, err + } + orgID = org.ID + } + roleId := request.GetRoleId() + if roleId != "" && !utils.IsValidUUID(roleId) { + role, err := h.roleService.Get(ctx, roleId) + if err != nil { + return filter, err + } + roleId = role.ID + } + projectId := request.GetProjectId() + if projectId != "" && !utils.IsValidUUID(projectId) { + project, err := h.projectService.Get(ctx, projectId) + if err != nil { + return filter, err + } + projectId = project.ID + } + userId := request.GetUserId() + if userId != "" && !utils.IsValidUUID(userId) { + user, err := h.userService.GetByID(ctx, userId) + if err != nil { + return filter, err + } + userId = user.ID + } + return policy.Filter{ + PrincipalID: userId, + OrgID: orgID, + ProjectID: projectId, + GroupID: request.GetGroupId(), + RoleID: roleId, + }, nil +} diff --git a/internal/api/v1beta1connect/policy_test.go b/internal/api/v1beta1connect/policy_test.go new file mode 100644 index 000000000..f2bf57bac --- /dev/null +++ b/internal/api/v1beta1connect/policy_test.go @@ -0,0 +1,994 @@ +package v1beta1connect + +import ( + "context" + "errors" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/raystack/frontier/core/namespace" + "github.com/raystack/frontier/core/policy" + "github.com/raystack/frontier/core/project" + "github.com/raystack/frontier/core/role" + projectMocks "github.com/raystack/frontier/internal/api/v1beta1/mocks" + "github.com/raystack/frontier/internal/api/v1beta1connect/mocks" + "github.com/raystack/frontier/pkg/metadata" + "github.com/raystack/frontier/pkg/utils" + frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestConnectHandler_CreatePolicy(t *testing.T) { + fixedTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + testPolicyID := utils.NewString() + testUserID := utils.NewString() + testResourceID := utils.NewString() + testGroupID := utils.NewString() + + tests := []struct { + name string + setup func(ps *mocks.PolicyService) + request *connect.Request[frontierv1beta1.CreatePolicyRequest] + want *connect.Response[frontierv1beta1.CreatePolicyResponse] + wantErr error + errCode connect.Code + }{ + { + name: "should return invalid argument error when resource namespace splitting fails", + setup: func(ps *mocks.PolicyService) { + // No expectations as we return early on resource splitting error + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyRequest{ + Body: &frontierv1beta1.PolicyRequestBody{ + RoleId: "admin", + Resource: "invalid-resource-format", + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return invalid argument error when principal namespace splitting fails", + setup: func(ps *mocks.PolicyService) { + // No expectations as we return early on principal splitting error + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyRequest{ + Body: &frontierv1beta1.PolicyRequestBody{ + RoleId: "admin", + Resource: "project:" + testResourceID, + Principal: "invalid-principal-format", + }, + }), + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return invalid argument error when role ID is invalid", + setup: func(ps *mocks.PolicyService) { + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: "invalid-role", + ResourceID: testResourceID, + ResourceType: "app/project", + PrincipalID: testUserID, + PrincipalType: "app/user", + Metadata: metadata.Metadata(nil), + }).Return(policy.Policy{}, role.ErrInvalidID) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyRequest{ + Body: &frontierv1beta1.PolicyRequestBody{ + RoleId: "invalid-role", + Resource: "project:" + testResourceID, + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: role.ErrInvalidID, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return invalid argument error when policy details are invalid", + setup: func(ps *mocks.PolicyService) { + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: "admin", + ResourceID: testResourceID, + ResourceType: "app/project", + PrincipalID: testUserID, + PrincipalType: "app/user", + Metadata: metadata.Metadata(nil), + }).Return(policy.Policy{}, policy.ErrInvalidDetail) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyRequest{ + Body: &frontierv1beta1.PolicyRequestBody{ + RoleId: "admin", + Resource: "project:" + testResourceID, + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return internal server error when policy service returns unknown error", + setup: func(ps *mocks.PolicyService) { + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: "admin", + ResourceID: testResourceID, + ResourceType: "app/project", + PrincipalID: testUserID, + PrincipalType: "app/user", + Metadata: metadata.Metadata(nil), + }).Return(policy.Policy{}, errors.New("service error")) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyRequest{ + Body: &frontierv1beta1.PolicyRequestBody{ + RoleId: "admin", + Resource: "project:" + testResourceID, + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: ErrInternalServerError, + errCode: connect.CodeInternal, + }, + { + name: "should successfully create policy with basic data", + setup: func(ps *mocks.PolicyService) { + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: "admin", + ResourceID: testResourceID, + ResourceType: "app/project", + PrincipalID: testUserID, + PrincipalType: "app/user", + Metadata: metadata.Metadata(nil), + }).Return(policy.Policy{ + ID: testPolicyID, + RoleID: "admin", + ResourceID: testResourceID, + ResourceType: "app/project", + PrincipalID: testUserID, + PrincipalType: "app/user", + Metadata: metadata.Metadata{}, + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyRequest{ + Body: &frontierv1beta1.PolicyRequestBody{ + RoleId: "admin", + Resource: "project:" + testResourceID, + Principal: "user:" + testUserID, + }, + }), + want: connect.NewResponse(&frontierv1beta1.CreatePolicyResponse{ + Policy: &frontierv1beta1.Policy{ + Id: testPolicyID, + RoleId: "admin", + Resource: "app/project:" + testResourceID, + Principal: "app/user:" + testUserID, + Metadata: nil, + CreatedAt: timestamppb.New(fixedTime), + UpdatedAt: timestamppb.New(fixedTime), + }, + }), + }, + { + name: "should successfully create policy with metadata", + setup: func(ps *mocks.PolicyService) { + metadataMap := map[string]interface{}{ + "description": "Test policy", + "priority": "high", + } + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: "viewer", + ResourceID: testResourceID, + ResourceType: "app/organization", + PrincipalID: testUserID, + PrincipalType: "app/user", + Metadata: metadata.Build(metadataMap), + }).Return(policy.Policy{ + ID: testPolicyID, + RoleID: "viewer", + ResourceID: testResourceID, + ResourceType: "app/organization", + PrincipalID: testUserID, + PrincipalType: "app/user", + Metadata: metadata.Build(metadataMap), + CreatedAt: fixedTime, + }, nil) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyRequest{ + Body: &frontierv1beta1.PolicyRequestBody{ + RoleId: "viewer", + Resource: "organization:" + testResourceID, + Principal: "user:" + testUserID, + Metadata: func() *structpb.Struct { + s, _ := structpb.NewStruct(map[string]interface{}{ + "description": "Test policy", + "priority": "high", + }) + return s + }(), + }, + }), + want: connect.NewResponse(&frontierv1beta1.CreatePolicyResponse{ + Policy: &frontierv1beta1.Policy{ + Id: testPolicyID, + RoleId: "viewer", + Resource: "app/organization:" + testResourceID, + Principal: "app/user:" + testUserID, + Metadata: func() *structpb.Struct { + s, _ := structpb.NewStruct(map[string]interface{}{ + "description": "Test policy", + "priority": "high", + }) + return s + }(), + CreatedAt: timestamppb.New(fixedTime), + }, + }), + }, + { + name: "should successfully create policy for group principal", + setup: func(ps *mocks.PolicyService) { + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: "editor", + ResourceID: testResourceID, + ResourceType: "app/project", + PrincipalID: testGroupID, + PrincipalType: "app/group", + Metadata: metadata.Metadata(nil), + }).Return(policy.Policy{ + ID: testPolicyID, + RoleID: "editor", + ResourceID: testResourceID, + ResourceType: "app/project", + PrincipalID: testGroupID, + PrincipalType: "app/group", + Metadata: metadata.Metadata{}, + CreatedAt: fixedTime, + }, nil) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyRequest{ + Body: &frontierv1beta1.PolicyRequestBody{ + RoleId: "editor", + Resource: "project:" + testResourceID, + Principal: "group:" + testGroupID, + }, + }), + want: connect.NewResponse(&frontierv1beta1.CreatePolicyResponse{ + Policy: &frontierv1beta1.Policy{ + Id: testPolicyID, + RoleId: "editor", + Resource: "app/project:" + testResourceID, + Principal: "app/group:" + testGroupID, + Metadata: nil, + CreatedAt: timestamppb.New(fixedTime), + }, + }), + }, + { + name: "should return internal error when transformPolicyToPB fails due to metadata error", + setup: func(ps *mocks.PolicyService) { + // Create policy with metadata that will fail structpb conversion + invalidMetadata := metadata.Metadata{"invalid": make(chan int)} // channels can't be converted to protobuf + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: "admin", + ResourceID: testResourceID, + ResourceType: "app/project", + PrincipalID: testUserID, + PrincipalType: "app/user", + Metadata: metadata.Metadata(nil), + }).Return(policy.Policy{ + ID: testPolicyID, + RoleID: "admin", + ResourceID: testResourceID, + ResourceType: "app/project", + PrincipalID: testUserID, + PrincipalType: "app/user", + Metadata: invalidMetadata, + CreatedAt: fixedTime, + }, nil) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyRequest{ + Body: &frontierv1beta1.PolicyRequestBody{ + RoleId: "admin", + Resource: "project:" + testResourceID, + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: ErrInternalServerError, + errCode: connect.CodeInternal, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicyService := &mocks.PolicyService{} + if tt.setup != nil { + tt.setup(mockPolicyService) + } + + handler := &ConnectHandler{ + policyService: mockPolicyService, + } + + got, err := handler.CreatePolicy(context.Background(), tt.request) + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.errCode, connect.CodeOf(err)) + assert.Nil(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + + mockPolicyService.AssertExpectations(t) + }) + } +} + +func TestConnectHandler_GetPolicy(t *testing.T) { + fixedTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + testPolicyID := utils.NewString() + testUserID := utils.NewString() + testResourceID := utils.NewString() + testPolicyResourceType := "app/compute" + + testPolicy := policy.Policy{ + ID: testPolicyID, + PrincipalType: "app/user", + PrincipalID: testUserID, + ResourceID: testResourceID, + ResourceType: testPolicyResourceType, + RoleID: "reader", + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + } + + tests := []struct { + name string + setup func(ps *mocks.PolicyService) + request *connect.Request[frontierv1beta1.GetPolicyRequest] + want *connect.Response[frontierv1beta1.GetPolicyResponse] + wantErr error + errCode connect.Code + }{ + { + name: "should return internal server error when policy service returns generic error", + setup: func(ps *mocks.PolicyService) { + ps.On("Get", mock.Anything, testPolicyID).Return(policy.Policy{}, errors.New("service error")) + }, + request: connect.NewRequest(&frontierv1beta1.GetPolicyRequest{ + Id: testPolicyID, + }), + want: nil, + wantErr: ErrInternalServerError, + errCode: connect.CodeInternal, + }, + { + name: "should return not found error when ID is empty", + setup: func(ps *mocks.PolicyService) { + ps.On("Get", mock.Anything, "").Return(policy.Policy{}, policy.ErrInvalidID) + }, + request: connect.NewRequest(&frontierv1beta1.GetPolicyRequest{}), + want: nil, + wantErr: ErrPolicyNotFound, + errCode: connect.CodeNotFound, + }, + { + name: "should return not found error when ID is not UUID", + setup: func(ps *mocks.PolicyService) { + ps.On("Get", mock.Anything, "some-id").Return(policy.Policy{}, policy.ErrInvalidUUID) + }, + request: connect.NewRequest(&frontierv1beta1.GetPolicyRequest{ + Id: "some-id", + }), + want: nil, + wantErr: ErrPolicyNotFound, + errCode: connect.CodeNotFound, + }, + { + name: "should return not found error when policy doesn't exist", + setup: func(ps *mocks.PolicyService) { + ps.On("Get", mock.Anything, testPolicyID).Return(policy.Policy{}, policy.ErrNotExist) + }, + request: connect.NewRequest(&frontierv1beta1.GetPolicyRequest{ + Id: testPolicyID, + }), + want: nil, + wantErr: ErrPolicyNotFound, + errCode: connect.CodeNotFound, + }, + { + name: "should successfully get policy with basic data", + setup: func(ps *mocks.PolicyService) { + ps.On("Get", mock.Anything, testPolicyID).Return(testPolicy, nil) + }, + request: connect.NewRequest(&frontierv1beta1.GetPolicyRequest{ + Id: testPolicyID, + }), + want: connect.NewResponse(&frontierv1beta1.GetPolicyResponse{ + Policy: &frontierv1beta1.Policy{ + Id: testPolicyID, + RoleId: "reader", + Resource: testPolicyResourceType + ":" + testResourceID, + Principal: "app/user:" + testUserID, + Metadata: nil, + CreatedAt: timestamppb.New(fixedTime), + UpdatedAt: timestamppb.New(fixedTime), + }, + }), + wantErr: nil, + }, + { + name: "should return internal error when transformPolicyToPB fails due to metadata error", + setup: func(ps *mocks.PolicyService) { + invalidPolicy := testPolicy + invalidPolicy.Metadata = metadata.Metadata{"invalid": make(chan int)} // channels can't be converted to protobuf + ps.On("Get", mock.Anything, testPolicyID).Return(invalidPolicy, nil) + }, + request: connect.NewRequest(&frontierv1beta1.GetPolicyRequest{ + Id: testPolicyID, + }), + want: nil, + wantErr: ErrInternalServerError, + errCode: connect.CodeInternal, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicyService := &mocks.PolicyService{} + if tt.setup != nil { + tt.setup(mockPolicyService) + } + + handler := &ConnectHandler{ + policyService: mockPolicyService, + } + + got, err := handler.GetPolicy(context.Background(), tt.request) + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.errCode, connect.CodeOf(err)) + assert.Nil(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + + mockPolicyService.AssertExpectations(t) + }) + } +} + +func TestConnectHandler_ListPolicies(t *testing.T) { + fixedTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + testPolicyID1 := utils.NewString() + testPolicyID2 := utils.NewString() + testUserID := utils.NewString() + testResourceID1 := utils.NewString() + testResourceID2 := utils.NewString() + + testPolicies := []policy.Policy{ + { + ID: testPolicyID1, + PrincipalType: "app/user", + PrincipalID: testUserID, + ResourceID: testResourceID1, + ResourceType: "app/project", + RoleID: "admin", + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + { + ID: testPolicyID2, + PrincipalType: "app/user", + PrincipalID: testUserID, + ResourceID: testResourceID2, + ResourceType: "app/organization", + RoleID: "viewer", + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, + } + + tests := []struct { + name string + setup func(ps *mocks.PolicyService) + request *connect.Request[frontierv1beta1.ListPoliciesRequest] + want *connect.Response[frontierv1beta1.ListPoliciesResponse] + wantErr error + errCode connect.Code + }{ + { + name: "should return internal server error when policy service returns error", + setup: func(ps *mocks.PolicyService) { + ps.On("List", mock.Anything, policy.Filter{}).Return([]policy.Policy{}, errors.New("service error")) + }, + request: connect.NewRequest(&frontierv1beta1.ListPoliciesRequest{}), + want: nil, + wantErr: ErrInternalServerError, + errCode: connect.CodeInternal, + }, + { + name: "should successfully list policies with empty filter", + setup: func(ps *mocks.PolicyService) { + ps.On("List", mock.Anything, policy.Filter{}).Return(testPolicies, nil) + }, + request: connect.NewRequest(&frontierv1beta1.ListPoliciesRequest{}), + want: connect.NewResponse(&frontierv1beta1.ListPoliciesResponse{ + Policies: []*frontierv1beta1.Policy{ + { + Id: testPolicyID1, + RoleId: "admin", + Resource: "app/project:" + testResourceID1, + Principal: "app/user:" + testUserID, + Metadata: nil, + CreatedAt: timestamppb.New(fixedTime), + UpdatedAt: timestamppb.New(fixedTime), + }, + { + Id: testPolicyID2, + RoleId: "viewer", + Resource: "app/organization:" + testResourceID2, + Principal: "app/user:" + testUserID, + Metadata: nil, + CreatedAt: timestamppb.New(fixedTime), + UpdatedAt: timestamppb.New(fixedTime), + }, + }, + }), + wantErr: nil, + }, + { + name: "should successfully list empty policies", + setup: func(ps *mocks.PolicyService) { + ps.On("List", mock.Anything, policy.Filter{}).Return([]policy.Policy{}, nil) + }, + request: connect.NewRequest(&frontierv1beta1.ListPoliciesRequest{}), + want: connect.NewResponse(&frontierv1beta1.ListPoliciesResponse{ + Policies: nil, + }), + wantErr: nil, + }, + { + name: "should return internal error when transformPolicyToPB fails due to metadata error", + setup: func(ps *mocks.PolicyService) { + invalidPolicy := testPolicies[0] + invalidPolicy.Metadata = metadata.Metadata{"invalid": make(chan int)} // channels can't be converted to protobuf + ps.On("List", mock.Anything, policy.Filter{}).Return([]policy.Policy{invalidPolicy}, nil) + }, + request: connect.NewRequest(&frontierv1beta1.ListPoliciesRequest{}), + want: nil, + wantErr: ErrInternalServerError, + errCode: connect.CodeInternal, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicyService := &mocks.PolicyService{} + if tt.setup != nil { + tt.setup(mockPolicyService) + } + + handler := &ConnectHandler{ + policyService: mockPolicyService, + } + + got, err := handler.ListPolicies(context.Background(), tt.request) + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.errCode, connect.CodeOf(err)) + assert.Nil(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + + mockPolicyService.AssertExpectations(t) + }) + } +} + +func TestConnectHandler_DeletePolicy(t *testing.T) { + testPolicyID := utils.NewString() + + tests := []struct { + name string + setup func(ps *mocks.PolicyService) + request *connect.Request[frontierv1beta1.DeletePolicyRequest] + want *connect.Response[frontierv1beta1.DeletePolicyResponse] + wantErr error + errCode connect.Code + }{ + { + name: "should return not found error when policy doesn't exist", + setup: func(ps *mocks.PolicyService) { + ps.On("Delete", mock.Anything, testPolicyID).Return(policy.ErrNotExist) + }, + request: connect.NewRequest(&frontierv1beta1.DeletePolicyRequest{ + Id: testPolicyID, + }), + want: nil, + wantErr: ErrPolicyNotFound, + errCode: connect.CodeNotFound, + }, + { + name: "should return not found error when policy ID is invalid", + setup: func(ps *mocks.PolicyService) { + ps.On("Delete", mock.Anything, "invalid-id").Return(policy.ErrInvalidID) + }, + request: connect.NewRequest(&frontierv1beta1.DeletePolicyRequest{ + Id: "invalid-id", + }), + want: nil, + wantErr: ErrPolicyNotFound, + errCode: connect.CodeNotFound, + }, + { + name: "should return not found error when policy UUID is invalid", + setup: func(ps *mocks.PolicyService) { + ps.On("Delete", mock.Anything, "not-a-uuid").Return(policy.ErrInvalidUUID) + }, + request: connect.NewRequest(&frontierv1beta1.DeletePolicyRequest{ + Id: "not-a-uuid", + }), + want: nil, + wantErr: ErrPolicyNotFound, + errCode: connect.CodeNotFound, + }, + { + name: "should return invalid argument error when policy details are invalid", + setup: func(ps *mocks.PolicyService) { + ps.On("Delete", mock.Anything, testPolicyID).Return(policy.ErrInvalidDetail) + }, + request: connect.NewRequest(&frontierv1beta1.DeletePolicyRequest{ + Id: testPolicyID, + }), + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return invalid argument error when namespace doesn't exist", + setup: func(ps *mocks.PolicyService) { + ps.On("Delete", mock.Anything, testPolicyID).Return(namespace.ErrNotExist) + }, + request: connect.NewRequest(&frontierv1beta1.DeletePolicyRequest{ + Id: testPolicyID, + }), + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return already exists error when policy has conflicts", + setup: func(ps *mocks.PolicyService) { + ps.On("Delete", mock.Anything, testPolicyID).Return(policy.ErrConflict) + }, + request: connect.NewRequest(&frontierv1beta1.DeletePolicyRequest{ + Id: testPolicyID, + }), + want: nil, + wantErr: ErrConflictRequest, + errCode: connect.CodeAlreadyExists, + }, + { + name: "should return internal server error when policy service returns unknown error", + setup: func(ps *mocks.PolicyService) { + ps.On("Delete", mock.Anything, testPolicyID).Return(errors.New("service error")) + }, + request: connect.NewRequest(&frontierv1beta1.DeletePolicyRequest{ + Id: testPolicyID, + }), + want: nil, + wantErr: ErrInternalServerError, + errCode: connect.CodeInternal, + }, + { + name: "should successfully delete policy", + setup: func(ps *mocks.PolicyService) { + ps.On("Delete", mock.Anything, testPolicyID).Return(nil) + }, + request: connect.NewRequest(&frontierv1beta1.DeletePolicyRequest{ + Id: testPolicyID, + }), + want: connect.NewResponse(&frontierv1beta1.DeletePolicyResponse{}), + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicyService := &mocks.PolicyService{} + if tt.setup != nil { + tt.setup(mockPolicyService) + } + + handler := &ConnectHandler{ + policyService: mockPolicyService, + } + + got, err := handler.DeletePolicy(context.Background(), tt.request) + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.errCode, connect.CodeOf(err)) + assert.Nil(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + + mockPolicyService.AssertExpectations(t) + }) + } +} + +func TestConnectHandler_CreatePolicyForProject(t *testing.T) { + fixedTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + testPolicyID := utils.NewString() + testUserID := utils.NewString() + testProjectID := utils.NewString() + testRoleID := "admin" + + testProject := project.Project{ + ID: testProjectID, + Name: "test-project", + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + } + + tests := []struct { + name string + setupPolicy func(ps *mocks.PolicyService) + setupProject func(ps *projectMocks.ProjectService) + request *connect.Request[frontierv1beta1.CreatePolicyForProjectRequest] + want *connect.Response[frontierv1beta1.CreatePolicyForProjectResponse] + wantErr error + errCode connect.Code + }{ + { + name: "should return invalid argument error when body is nil", + setupPolicy: func(ps *mocks.PolicyService) { + // No expectations as we return early + }, + setupProject: func(ps *projectMocks.ProjectService) { + // No expectations as we return early + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{ + ProjectId: testProjectID, + Body: nil, + }), + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return invalid argument error when role ID is empty", + setupPolicy: func(ps *mocks.PolicyService) { + // No expectations as we return early + }, + setupProject: func(ps *projectMocks.ProjectService) { + // No expectations as we return early + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{ + ProjectId: testProjectID, + Body: &frontierv1beta1.CreatePolicyForProjectBody{ + RoleId: "", + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return invalid argument error when principal is empty", + setupPolicy: func(ps *mocks.PolicyService) { + // No expectations as we return early + }, + setupProject: func(ps *projectMocks.ProjectService) { + // No expectations as we return early + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{ + ProjectId: testProjectID, + Body: &frontierv1beta1.CreatePolicyForProjectBody{ + RoleId: testRoleID, + Principal: "", + }, + }), + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return invalid argument error when principal namespace splitting fails", + setupPolicy: func(ps *mocks.PolicyService) { + // No expectations as we return early + }, + setupProject: func(ps *projectMocks.ProjectService) { + // No expectations as we return early + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{ + ProjectId: testProjectID, + Body: &frontierv1beta1.CreatePolicyForProjectBody{ + RoleId: testRoleID, + Principal: "invalid-principal-format", + }, + }), + want: nil, + wantErr: ErrNamespaceSplitNotation, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return not found error when project doesn't exist", + setupPolicy: func(ps *mocks.PolicyService) { + // No expectations as we return early + }, + setupProject: func(ps *projectMocks.ProjectService) { + ps.On("Get", mock.Anything, testProjectID).Return(project.Project{}, errors.New("project not found")) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{ + ProjectId: testProjectID, + Body: &frontierv1beta1.CreatePolicyForProjectBody{ + RoleId: testRoleID, + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: ErrProjectNotFound, + errCode: connect.CodeNotFound, + }, + { + name: "should return invalid argument error when role ID is invalid", + setupPolicy: func(ps *mocks.PolicyService) { + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: testRoleID, + PrincipalType: "app/user", + PrincipalID: testUserID, + ResourceID: testProjectID, + ResourceType: "app/project", + }).Return(policy.Policy{}, role.ErrInvalidID) + }, + setupProject: func(ps *projectMocks.ProjectService) { + ps.On("Get", mock.Anything, testProjectID).Return(testProject, nil) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{ + ProjectId: testProjectID, + Body: &frontierv1beta1.CreatePolicyForProjectBody{ + RoleId: testRoleID, + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: ErrInvalidRoleID, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return invalid argument error when policy details are invalid", + setupPolicy: func(ps *mocks.PolicyService) { + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: testRoleID, + PrincipalType: "app/user", + PrincipalID: testUserID, + ResourceID: testProjectID, + ResourceType: "app/project", + }).Return(policy.Policy{}, policy.ErrInvalidDetail) + }, + setupProject: func(ps *projectMocks.ProjectService) { + ps.On("Get", mock.Anything, testProjectID).Return(testProject, nil) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{ + ProjectId: testProjectID, + Body: &frontierv1beta1.CreatePolicyForProjectBody{ + RoleId: testRoleID, + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should return internal server error when policy service returns unknown error", + setupPolicy: func(ps *mocks.PolicyService) { + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: testRoleID, + PrincipalType: "app/user", + PrincipalID: testUserID, + ResourceID: testProjectID, + ResourceType: "app/project", + }).Return(policy.Policy{}, errors.New("service error")) + }, + setupProject: func(ps *projectMocks.ProjectService) { + ps.On("Get", mock.Anything, testProjectID).Return(testProject, nil) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{ + ProjectId: testProjectID, + Body: &frontierv1beta1.CreatePolicyForProjectBody{ + RoleId: testRoleID, + Principal: "user:" + testUserID, + }, + }), + want: nil, + wantErr: ErrInternalServerError, + errCode: connect.CodeInternal, + }, + { + name: "should successfully create policy for project", + setupPolicy: func(ps *mocks.PolicyService) { + ps.On("Create", mock.Anything, policy.Policy{ + RoleID: testRoleID, + PrincipalType: "app/user", + PrincipalID: testUserID, + ResourceID: testProjectID, + ResourceType: "app/project", + }).Return(policy.Policy{ + ID: testPolicyID, + RoleID: testRoleID, + PrincipalType: "app/user", + PrincipalID: testUserID, + ResourceID: testProjectID, + ResourceType: "app/project", + CreatedAt: fixedTime, + UpdatedAt: fixedTime, + }, nil) + }, + setupProject: func(ps *projectMocks.ProjectService) { + ps.On("Get", mock.Anything, testProjectID).Return(testProject, nil) + }, + request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{ + ProjectId: testProjectID, + Body: &frontierv1beta1.CreatePolicyForProjectBody{ + RoleId: testRoleID, + Principal: "user:" + testUserID, + }, + }), + want: connect.NewResponse(&frontierv1beta1.CreatePolicyForProjectResponse{}), + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicyService := &mocks.PolicyService{} + mockProjectService := &projectMocks.ProjectService{} + + if tt.setupPolicy != nil { + tt.setupPolicy(mockPolicyService) + } + if tt.setupProject != nil { + tt.setupProject(mockProjectService) + } + + handler := &ConnectHandler{ + policyService: mockPolicyService, + projectService: mockProjectService, + } + + got, err := handler.CreatePolicyForProject(context.Background(), tt.request) + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.errCode, connect.CodeOf(err)) + assert.Nil(t, got) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + + mockPolicyService.AssertExpectations(t) + mockProjectService.AssertExpectations(t) + }) + } +} diff --git a/internal/api/v1beta1connect/relation.go b/internal/api/v1beta1connect/relation.go index 86935a1c2..0f0b915f9 100644 --- a/internal/api/v1beta1connect/relation.go +++ b/internal/api/v1beta1connect/relation.go @@ -2,7 +2,6 @@ package v1beta1connect import ( "context" - "errors" "connectrpc.com/connect" "github.com/raystack/frontier/core/relation" @@ -18,10 +17,6 @@ type RelationService interface { Delete(ctx context.Context, rel relation.Relation) error } -var ( - ErrNamespaceSplitNotation = errors.New("subject/object should be provided as 'namespace:uuid'") -) - func (h *ConnectHandler) ListRelations(ctx context.Context, request *connect.Request[frontierv1beta1.ListRelationsRequest]) (*connect.Response[frontierv1beta1.ListRelationsResponse], error) { var err error var subject relation.Subject