Skip to content

Commit 23ad5d1

Browse files
committed
feat: migrate CreatePolicyForProject API to Connect RPC
1 parent 62ae3f2 commit 23ad5d1

File tree

3 files changed

+300
-0
lines changed

3 files changed

+300
-0
lines changed

internal/api/v1beta1connect/errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,5 @@ var (
4949
ErrInvalidRoleID = errors.New("role id is invalid")
5050
ErrNamespaceSplitNotation = errors.New("subject/object should be provided as 'namespace:uuid'")
5151
ErrPolicyNotFound = errors.New("policy doesn't exist")
52+
ErrProjectNotFound = errors.New("project doesn't exist")
5253
)

internal/api/v1beta1connect/policy.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,45 @@ func (h *ConnectHandler) DeletePolicy(ctx context.Context, request *connect.Requ
117117
return connect.NewResponse(&frontierv1beta1.DeletePolicyResponse{}), nil
118118
}
119119

120+
func (h *ConnectHandler) CreatePolicyForProject(ctx context.Context, request *connect.Request[frontierv1beta1.CreatePolicyForProjectRequest]) (*connect.Response[frontierv1beta1.CreatePolicyForProjectResponse], error) {
121+
if request.Msg.GetBody() == nil || request.Msg.GetBody().GetRoleId() == "" || request.Msg.GetBody().GetPrincipal() == "" {
122+
return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest)
123+
}
124+
125+
principalType, principalID, err := schema.SplitNamespaceAndResourceID(request.Msg.GetBody().GetPrincipal())
126+
if err != nil {
127+
return nil, connect.NewError(connect.CodeInvalidArgument, ErrNamespaceSplitNotation)
128+
}
129+
130+
project, err := h.projectService.Get(ctx, request.Msg.GetProjectId())
131+
if err != nil {
132+
return nil, connect.NewError(connect.CodeNotFound, ErrProjectNotFound)
133+
}
134+
135+
p := policy.Policy{
136+
RoleID: request.Msg.GetBody().GetRoleId(),
137+
PrincipalType: principalType,
138+
PrincipalID: principalID,
139+
ResourceID: project.ID,
140+
ResourceType: schema.ProjectNamespace,
141+
}
142+
143+
newPolicy, err := h.policyService.Create(ctx, p)
144+
if err != nil {
145+
switch {
146+
case errors.Is(err, role.ErrInvalidID):
147+
return nil, connect.NewError(connect.CodeInvalidArgument, ErrInvalidRoleID)
148+
case errors.Is(err, policy.ErrInvalidDetail):
149+
return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest)
150+
default:
151+
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
152+
}
153+
}
154+
155+
auditPolicyCreationEvent(ctx, newPolicy)
156+
return connect.NewResponse(&frontierv1beta1.CreatePolicyForProjectResponse{}), nil
157+
}
158+
120159
func (h *ConnectHandler) ListPolicies(ctx context.Context, request *connect.Request[frontierv1beta1.ListPoliciesRequest]) (*connect.Response[frontierv1beta1.ListPoliciesResponse], error) {
121160
var policies []*frontierv1beta1.Policy
122161

internal/api/v1beta1connect/policy_test.go

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import (
99
"connectrpc.com/connect"
1010
"github.com/raystack/frontier/core/namespace"
1111
"github.com/raystack/frontier/core/policy"
12+
"github.com/raystack/frontier/core/project"
1213
"github.com/raystack/frontier/core/role"
14+
projectMocks "github.com/raystack/frontier/internal/api/v1beta1/mocks"
1315
"github.com/raystack/frontier/internal/api/v1beta1connect/mocks"
1416
"github.com/raystack/frontier/pkg/metadata"
1517
"github.com/raystack/frontier/pkg/utils"
@@ -760,3 +762,261 @@ func TestConnectHandler_DeletePolicy(t *testing.T) {
760762
})
761763
}
762764
}
765+
766+
func TestConnectHandler_CreatePolicyForProject(t *testing.T) {
767+
fixedTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
768+
testPolicyID := utils.NewString()
769+
testUserID := utils.NewString()
770+
testProjectID := utils.NewString()
771+
testRoleID := "admin"
772+
773+
testProject := project.Project{
774+
ID: testProjectID,
775+
Name: "test-project",
776+
CreatedAt: fixedTime,
777+
UpdatedAt: fixedTime,
778+
}
779+
780+
tests := []struct {
781+
name string
782+
setupPolicy func(ps *mocks.PolicyService)
783+
setupProject func(ps *projectMocks.ProjectService)
784+
request *connect.Request[frontierv1beta1.CreatePolicyForProjectRequest]
785+
want *connect.Response[frontierv1beta1.CreatePolicyForProjectResponse]
786+
wantErr error
787+
errCode connect.Code
788+
}{
789+
{
790+
name: "should return invalid argument error when body is nil",
791+
setupPolicy: func(ps *mocks.PolicyService) {
792+
// No expectations as we return early
793+
},
794+
setupProject: func(ps *projectMocks.ProjectService) {
795+
// No expectations as we return early
796+
},
797+
request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{
798+
ProjectId: testProjectID,
799+
Body: nil,
800+
}),
801+
want: nil,
802+
wantErr: ErrBadRequest,
803+
errCode: connect.CodeInvalidArgument,
804+
},
805+
{
806+
name: "should return invalid argument error when role ID is empty",
807+
setupPolicy: func(ps *mocks.PolicyService) {
808+
// No expectations as we return early
809+
},
810+
setupProject: func(ps *projectMocks.ProjectService) {
811+
// No expectations as we return early
812+
},
813+
request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{
814+
ProjectId: testProjectID,
815+
Body: &frontierv1beta1.CreatePolicyForProjectBody{
816+
RoleId: "",
817+
Principal: "user:" + testUserID,
818+
},
819+
}),
820+
want: nil,
821+
wantErr: ErrBadRequest,
822+
errCode: connect.CodeInvalidArgument,
823+
},
824+
{
825+
name: "should return invalid argument error when principal is empty",
826+
setupPolicy: func(ps *mocks.PolicyService) {
827+
// No expectations as we return early
828+
},
829+
setupProject: func(ps *projectMocks.ProjectService) {
830+
// No expectations as we return early
831+
},
832+
request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{
833+
ProjectId: testProjectID,
834+
Body: &frontierv1beta1.CreatePolicyForProjectBody{
835+
RoleId: testRoleID,
836+
Principal: "",
837+
},
838+
}),
839+
want: nil,
840+
wantErr: ErrBadRequest,
841+
errCode: connect.CodeInvalidArgument,
842+
},
843+
{
844+
name: "should return invalid argument error when principal namespace splitting fails",
845+
setupPolicy: func(ps *mocks.PolicyService) {
846+
// No expectations as we return early
847+
},
848+
setupProject: func(ps *projectMocks.ProjectService) {
849+
// No expectations as we return early
850+
},
851+
request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{
852+
ProjectId: testProjectID,
853+
Body: &frontierv1beta1.CreatePolicyForProjectBody{
854+
RoleId: testRoleID,
855+
Principal: "invalid-principal-format",
856+
},
857+
}),
858+
want: nil,
859+
wantErr: ErrNamespaceSplitNotation,
860+
errCode: connect.CodeInvalidArgument,
861+
},
862+
{
863+
name: "should return not found error when project doesn't exist",
864+
setupPolicy: func(ps *mocks.PolicyService) {
865+
// No expectations as we return early
866+
},
867+
setupProject: func(ps *projectMocks.ProjectService) {
868+
ps.On("Get", mock.Anything, testProjectID).Return(project.Project{}, errors.New("project not found"))
869+
},
870+
request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{
871+
ProjectId: testProjectID,
872+
Body: &frontierv1beta1.CreatePolicyForProjectBody{
873+
RoleId: testRoleID,
874+
Principal: "user:" + testUserID,
875+
},
876+
}),
877+
want: nil,
878+
wantErr: ErrProjectNotFound,
879+
errCode: connect.CodeNotFound,
880+
},
881+
{
882+
name: "should return invalid argument error when role ID is invalid",
883+
setupPolicy: func(ps *mocks.PolicyService) {
884+
ps.On("Create", mock.Anything, policy.Policy{
885+
RoleID: testRoleID,
886+
PrincipalType: "app/user",
887+
PrincipalID: testUserID,
888+
ResourceID: testProjectID,
889+
ResourceType: "app/project",
890+
}).Return(policy.Policy{}, role.ErrInvalidID)
891+
},
892+
setupProject: func(ps *projectMocks.ProjectService) {
893+
ps.On("Get", mock.Anything, testProjectID).Return(testProject, nil)
894+
},
895+
request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{
896+
ProjectId: testProjectID,
897+
Body: &frontierv1beta1.CreatePolicyForProjectBody{
898+
RoleId: testRoleID,
899+
Principal: "user:" + testUserID,
900+
},
901+
}),
902+
want: nil,
903+
wantErr: ErrInvalidRoleID,
904+
errCode: connect.CodeInvalidArgument,
905+
},
906+
{
907+
name: "should return invalid argument error when policy details are invalid",
908+
setupPolicy: func(ps *mocks.PolicyService) {
909+
ps.On("Create", mock.Anything, policy.Policy{
910+
RoleID: testRoleID,
911+
PrincipalType: "app/user",
912+
PrincipalID: testUserID,
913+
ResourceID: testProjectID,
914+
ResourceType: "app/project",
915+
}).Return(policy.Policy{}, policy.ErrInvalidDetail)
916+
},
917+
setupProject: func(ps *projectMocks.ProjectService) {
918+
ps.On("Get", mock.Anything, testProjectID).Return(testProject, nil)
919+
},
920+
request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{
921+
ProjectId: testProjectID,
922+
Body: &frontierv1beta1.CreatePolicyForProjectBody{
923+
RoleId: testRoleID,
924+
Principal: "user:" + testUserID,
925+
},
926+
}),
927+
want: nil,
928+
wantErr: ErrBadRequest,
929+
errCode: connect.CodeInvalidArgument,
930+
},
931+
{
932+
name: "should return internal server error when policy service returns unknown error",
933+
setupPolicy: func(ps *mocks.PolicyService) {
934+
ps.On("Create", mock.Anything, policy.Policy{
935+
RoleID: testRoleID,
936+
PrincipalType: "app/user",
937+
PrincipalID: testUserID,
938+
ResourceID: testProjectID,
939+
ResourceType: "app/project",
940+
}).Return(policy.Policy{}, errors.New("service error"))
941+
},
942+
setupProject: func(ps *projectMocks.ProjectService) {
943+
ps.On("Get", mock.Anything, testProjectID).Return(testProject, nil)
944+
},
945+
request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{
946+
ProjectId: testProjectID,
947+
Body: &frontierv1beta1.CreatePolicyForProjectBody{
948+
RoleId: testRoleID,
949+
Principal: "user:" + testUserID,
950+
},
951+
}),
952+
want: nil,
953+
wantErr: ErrInternalServerError,
954+
errCode: connect.CodeInternal,
955+
},
956+
{
957+
name: "should successfully create policy for project",
958+
setupPolicy: func(ps *mocks.PolicyService) {
959+
ps.On("Create", mock.Anything, policy.Policy{
960+
RoleID: testRoleID,
961+
PrincipalType: "app/user",
962+
PrincipalID: testUserID,
963+
ResourceID: testProjectID,
964+
ResourceType: "app/project",
965+
}).Return(policy.Policy{
966+
ID: testPolicyID,
967+
RoleID: testRoleID,
968+
PrincipalType: "app/user",
969+
PrincipalID: testUserID,
970+
ResourceID: testProjectID,
971+
ResourceType: "app/project",
972+
CreatedAt: fixedTime,
973+
UpdatedAt: fixedTime,
974+
}, nil)
975+
},
976+
setupProject: func(ps *projectMocks.ProjectService) {
977+
ps.On("Get", mock.Anything, testProjectID).Return(testProject, nil)
978+
},
979+
request: connect.NewRequest(&frontierv1beta1.CreatePolicyForProjectRequest{
980+
ProjectId: testProjectID,
981+
Body: &frontierv1beta1.CreatePolicyForProjectBody{
982+
RoleId: testRoleID,
983+
Principal: "user:" + testUserID,
984+
},
985+
}),
986+
want: connect.NewResponse(&frontierv1beta1.CreatePolicyForProjectResponse{}),
987+
wantErr: nil,
988+
},
989+
}
990+
991+
for _, tt := range tests {
992+
t.Run(tt.name, func(t *testing.T) {
993+
mockPolicyService := &mocks.PolicyService{}
994+
mockProjectService := &projectMocks.ProjectService{}
995+
996+
if tt.setupPolicy != nil {
997+
tt.setupPolicy(mockPolicyService)
998+
}
999+
if tt.setupProject != nil {
1000+
tt.setupProject(mockProjectService)
1001+
}
1002+
1003+
handler := &ConnectHandler{
1004+
policyService: mockPolicyService,
1005+
projectService: mockProjectService,
1006+
}
1007+
1008+
got, err := handler.CreatePolicyForProject(context.Background(), tt.request)
1009+
if tt.wantErr != nil {
1010+
assert.Error(t, err)
1011+
assert.Equal(t, tt.errCode, connect.CodeOf(err))
1012+
assert.Nil(t, got)
1013+
} else {
1014+
assert.NoError(t, err)
1015+
assert.Equal(t, tt.want, got)
1016+
}
1017+
1018+
mockPolicyService.AssertExpectations(t)
1019+
mockProjectService.AssertExpectations(t)
1020+
})
1021+
}
1022+
}

0 commit comments

Comments
 (0)