From 2f6149c47de974384a6da7d016a4c194ad3bb096 Mon Sep 17 00:00:00 2001 From: Kian Parvin <46668016+kian99@users.noreply.github.com> Date: Fri, 29 Nov 2024 15:08:23 +0200 Subject: [PATCH 1/3] feat: add group to role rebac admin handlers (#1472) --- internal/jimmhttp/rebac_admin/groups.go | 111 +++++++++++++++- .../rebac_admin/groups_integration_test.go | 61 +++++++++ internal/jimmhttp/rebac_admin/groups_test.go | 119 ++++++++++++++++++ 3 files changed, 288 insertions(+), 3 deletions(-) diff --git a/internal/jimmhttp/rebac_admin/groups.go b/internal/jimmhttp/rebac_admin/groups.go index d29db1c29..fbd26b77a 100644 --- a/internal/jimmhttp/rebac_admin/groups.go +++ b/internal/jimmhttp/rebac_admin/groups.go @@ -14,6 +14,7 @@ import ( "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimmhttp/rebac_admin/utils" "github.com/canonical/jimm/v3/internal/jujuapi" + "github.com/canonical/jimm/v3/internal/openfga" ofganames "github.com/canonical/jimm/v3/internal/openfga/names" apiparams "github.com/canonical/jimm/v3/pkg/api/params" jimmnames "github.com/canonical/jimm/v3/pkg/names" @@ -236,12 +237,116 @@ func (s *groupsService) PatchGroupIdentities(ctx context.Context, groupId string // GetGroupRoles returns a page of Roles for Group `groupId`. func (s *groupsService) GetGroupRoles(ctx context.Context, groupId string, params *resources.GetGroupsItemRolesParams) (*resources.PaginatedResponse[resources.Role], error) { - return nil, v1.NewNotImplementedError("get group roles not implemented") + user, err := utils.GetUserFromContext(ctx) + if err != nil { + return nil, err + } + + if !jimmnames.IsValidGroupId(groupId) { + return nil, v1.NewValidationError("invalid group ID") + } + + filter := utils.CreateTokenPaginationFilter(params.Size, params.NextToken, params.NextPageToken) + groupTag := jimmnames.NewGroupTag(groupId) + _, err = s.jimm.GetGroupByUUID(ctx, user, groupId) + if err != nil { + if errors.ErrorCode(err) == errors.CodeNotFound { + return nil, v1.NewNotFoundError("group not found") + } + return nil, err + } + + tuple := apiparams.RelationshipTuple{ + Object: ofganames.WithMemberRelation(groupTag), + Relation: ofganames.AssigneeRelation.String(), + TargetObject: openfga.RoleType.String(), + } + roles, nextToken, err := s.jimm.ListRelationshipTuples(ctx, user, tuple, int32(filter.Limit()), filter.Token()) // #nosec G115 accept integer conversion + if err != nil { + return nil, err + } + + data := make([]resources.Role, 0, len(roles)) + for _, role := range roles { + roleUUID := role.Target.ID + roleEntry, err := s.jimm.GetRoleManager().GetRoleByUUID(ctx, user, roleUUID) + if err != nil { + // If a role does not exist in the database but a linger tuple exists, drop the role from the results. + if errors.ErrorCode(err) == errors.CodeNotFound { + continue + } + return nil, err + } + data = append(data, resources.Role{ + Id: &roleUUID, + Name: roleEntry.Name, + }, + ) + } + + originalToken := filter.Token() + resp := resources.PaginatedResponse[resources.Role]{ + Meta: resources.ResponseMeta{ + Size: len(data), + PageToken: &originalToken, + }, + Data: data, + } + if nextToken != "" { + resp.Next = resources.Next{ + PageToken: &nextToken, + } + } + return &resp, nil } -// PatchGroupRoles performs addition or removal of a Role to/from a Group identified by `groupId`. +// PatchGroupRoles performs addition or removal of a group to/from a role identified by `groupId`. func (s *groupsService) PatchGroupRoles(ctx context.Context, groupId string, rolePatches []resources.GroupRolesPatchItem) (bool, error) { - return false, v1.NewNotImplementedError("patch group roles not implemented") + user, err := utils.GetUserFromContext(ctx) + if err != nil { + return false, err + } + if !jimmnames.IsValidGroupId(groupId) { + return false, v1.NewValidationError("invalid group ID") + } + + groupTag := jimmnames.NewGroupTag(groupId) + tuple := apiparams.RelationshipTuple{ + Object: ofganames.WithMemberRelation(groupTag), + Relation: ofganames.AssigneeRelation.String(), + } + + var toRemove []apiparams.RelationshipTuple + var toAdd []apiparams.RelationshipTuple + for _, rolePatch := range rolePatches { + if !jimmnames.IsValidRoleId(rolePatch.Role) { + return false, v1.NewValidationError(fmt.Sprintf("invalid role ID: %s", rolePatch.Role)) + } + role := jimmnames.NewRoleTag(rolePatch.Role) + if rolePatch.Op == resources.GroupRolesPatchItemOpAdd { + t := tuple + t.TargetObject = role.String() + toAdd = append(toAdd, t) + } else { + t := tuple + t.TargetObject = role.String() + toRemove = append(toRemove, t) + } + } + + if toAdd != nil { + err := s.jimm.AddRelation(ctx, user, toAdd) + if err != nil { + return false, err + } + } + if toRemove != nil { + err := s.jimm.RemoveRelation(ctx, user, toRemove) + if err != nil { + return false, err + } + } + return true, nil } // GetGroupEntitlements returns a page of Entitlements for Group `groupId`. diff --git a/internal/jimmhttp/rebac_admin/groups_integration_test.go b/internal/jimmhttp/rebac_admin/groups_integration_test.go index 9a28ebb9e..de1e89628 100644 --- a/internal/jimmhttp/rebac_admin/groups_integration_test.go +++ b/internal/jimmhttp/rebac_admin/groups_integration_test.go @@ -151,6 +151,67 @@ func (s rebacAdminSuite) TestPatchGroupIdentitiesIntegration(c *gc.C) { c.Assert(allowed, gc.Equals, true) } +func (s rebacAdminSuite) TestGetGroupRolesIntegration(c *gc.C) { + ctx := context.Background() + group := s.AddGroup(c, "test-group") + role := s.AddRole(c, "test-role") + tuple := openfga.Tuple{ + Object: ofganames.ConvertTagWithRelation(jimmnames.NewGroupTag(group.UUID), ofganames.MemberRelation), + Relation: ofganames.AssigneeRelation, + Target: ofganames.ConvertTag(jimmnames.NewRoleTag(role.UUID)), + } + err := s.JIMM.OpenFGAClient.AddRelation(ctx, tuple) + c.Assert(err, gc.IsNil) + + params := &resources.GetGroupsItemRolesParams{} + ctx = rebac_handlers.ContextWithIdentity(ctx, s.AdminUser) + res, err := s.groupSvc.GetGroupRoles(ctx, group.UUID, params) + c.Assert(err, gc.IsNil) + c.Assert(res, gc.Not(gc.IsNil)) + c.Assert(res.Meta.Size, gc.Equals, 1) + c.Assert(*res.Meta.PageToken, gc.Equals, "") + c.Assert(res.Next.PageToken, gc.IsNil) + c.Assert(res.Data, gc.HasLen, 1) + c.Assert(res.Data[0].Id, gc.Not(gc.IsNil)) + c.Assert(*res.Data[0].Id, gc.Equals, role.UUID) + c.Assert(res.Data[0].Name, gc.Equals, role.Name) +} + +func (s rebacAdminSuite) TestPatchGroupRolesIntegration(c *gc.C) { + ctx := context.Background() + group := s.AddGroup(c, "test-group") + role := s.AddRole(c, "test-role") + + // Assign the role to the group. + rolePatches := []resources.GroupRolesPatchItem{ + {Role: role.UUID, Op: resources.GroupRolesPatchItemOpAdd}, + } + ctx = rebac_handlers.ContextWithIdentity(ctx, s.AdminUser) + res, err := s.groupSvc.PatchGroupRoles(ctx, group.UUID, rolePatches) + c.Assert(err, gc.IsNil) + c.Assert(res, gc.Equals, true) + + checkTuple := openfga.Tuple{ + Object: ofganames.ConvertTagWithRelation(group.ResourceTag(), ofganames.MemberRelation), + Relation: ofganames.AssigneeRelation, + Target: ofganames.ConvertTag(role.ResourceTag()), + } + allowed, err := s.JIMM.OpenFGAClient.CheckRelation(ctx, checkTuple, false) + c.Assert(err, gc.IsNil) + c.Assert(allowed, gc.Equals, true) + + // Remove the role from the group. + rolePatches[0].Op = resources.GroupRolesPatchItemOpRemove + ctx = rebac_handlers.ContextWithIdentity(ctx, s.AdminUser) + res, err = s.groupSvc.PatchGroupRoles(ctx, group.UUID, rolePatches) + c.Assert(err, gc.IsNil) + c.Assert(res, gc.Equals, true) + + allowed, err = s.JIMM.OpenFGAClient.CheckRelation(ctx, checkTuple, false) + c.Assert(err, gc.IsNil) + c.Assert(allowed, gc.Equals, false) +} + func (s rebacAdminSuite) TestGetGroupEntitlementsIntegration(c *gc.C) { ctx := context.Background() group, err := s.JIMM.AddGroup(ctx, s.AdminUser, "test-group") diff --git a/internal/jimmhttp/rebac_admin/groups_test.go b/internal/jimmhttp/rebac_admin/groups_test.go index 02a515e1c..3c91270b8 100644 --- a/internal/jimmhttp/rebac_admin/groups_test.go +++ b/internal/jimmhttp/rebac_admin/groups_test.go @@ -15,8 +15,11 @@ import ( "github.com/canonical/jimm/v3/internal/common/pagination" "github.com/canonical/jimm/v3/internal/dbmodel" + jimmerr "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/jimm" "github.com/canonical/jimm/v3/internal/jimmhttp/rebac_admin" "github.com/canonical/jimm/v3/internal/openfga" + ofganames "github.com/canonical/jimm/v3/internal/openfga/names" "github.com/canonical/jimm/v3/internal/testutils/jimmtest" "github.com/canonical/jimm/v3/internal/testutils/jimmtest/mocks" "github.com/canonical/jimm/v3/pkg/api/params" @@ -238,6 +241,122 @@ func TestPatchGroupIdentities(t *testing.T) { c.Assert(err, qt.ErrorMatches, "foo") } +func TestGetGroupRoles(t *testing.T) { + c := qt.New(t) + var listTuplesErr error + var getGroupErr error + var getRoleErr error + var continuationToken string + + testTuple := openfga.Tuple{ + Object: &ofga.Entity{Kind: "group", ID: "foo", Relation: ofganames.MemberRelation}, + Relation: ofganames.AssigneeRelation, + Target: &ofga.Entity{Kind: "role", ID: "my-role"}, + } + roleManager := mocks.RoleManager{ + GetRoleByUUID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.RoleEntry, error) { + return &dbmodel.RoleEntry{}, getRoleErr + }, + } + jimm := jimmtest.JIMM{ + GetRoleManager_: func() jimm.RoleManager { + return roleManager + }, + GroupService: mocks.GroupService{ + GetGroupByUUID_: func(ctx context.Context, user *openfga.User, uuid string) (*dbmodel.GroupEntry, error) { + return nil, getGroupErr + }, + }, + RelationService: mocks.RelationService{ + ListRelationshipTuples_: func(ctx context.Context, user *openfga.User, tuple params.RelationshipTuple, pageSize int32, ct string) ([]openfga.Tuple, string, error) { + return []openfga.Tuple{testTuple}, continuationToken, listTuplesErr + }, + }, + } + + user := openfga.User{} + ctx := context.Background() + ctx = rebac_handlers.ContextWithIdentity(ctx, &user) + groupSvc := rebac_admin.NewGroupService(&jimm) + + _, err := groupSvc.GetGroupRoles(ctx, "invalid-group-id", &resources.GetGroupsItemRolesParams{}) + c.Assert(err, qt.ErrorMatches, ".*invalid group ID") + + newUUID := uuid.New() + getGroupErr = errors.New("group doesn't exist") + _, err = groupSvc.GetGroupRoles(ctx, newUUID.String(), &resources.GetGroupsItemRolesParams{}) + c.Assert(err, qt.ErrorMatches, ".*group doesn't exist") + getGroupErr = nil + + getRoleErr = jimmerr.E("role doesn't exist", jimmerr.CodeNotFound) + _, err = groupSvc.GetGroupRoles(ctx, newUUID.String(), &resources.GetGroupsItemRolesParams{}) + c.Assert(err, qt.IsNil) + getRoleErr = nil + + getRoleErr = errors.New("could not connect to DB") + _, err = groupSvc.GetGroupRoles(ctx, newUUID.String(), &resources.GetGroupsItemRolesParams{}) + c.Assert(err, qt.ErrorMatches, ".*could not connect to DB") + getRoleErr = nil + + continuationToken = "continuation-token" + res, err := groupSvc.GetGroupRoles(ctx, newUUID.String(), &resources.GetGroupsItemRolesParams{}) + c.Assert(err, qt.IsNil) + c.Assert(res, qt.IsNotNil) + c.Assert(res.Data, qt.HasLen, 1) + c.Assert(*res.Next.PageToken, qt.Equals, "continuation-token") + + continuationToken = "" + res, err = groupSvc.GetGroupRoles(ctx, newUUID.String(), &resources.GetGroupsItemRolesParams{}) + c.Assert(err, qt.IsNil) + c.Assert(res, qt.IsNotNil) + c.Assert(res.Next.PageToken, qt.IsNil) + + listTuplesErr = errors.New("foo") + _, err = groupSvc.GetGroupRoles(ctx, newUUID.String(), &resources.GetGroupsItemRolesParams{}) + c.Assert(err, qt.ErrorMatches, "foo") +} + +func TestPatchGroupRoles(t *testing.T) { + c := qt.New(t) + var patchTuplesErr error + jimm := jimmtest.JIMM{ + RelationService: mocks.RelationService{ + AddRelation_: func(ctx context.Context, user *openfga.User, tuples []params.RelationshipTuple) error { + return patchTuplesErr + }, + RemoveRelation_: func(ctx context.Context, user *openfga.User, tuples []params.RelationshipTuple) error { + return patchTuplesErr + }, + }, + } + user := openfga.User{} + ctx := context.Background() + ctx = rebac_handlers.ContextWithIdentity(ctx, &user) + groupSvc := rebac_admin.NewGroupService(&jimm) + + _, err := groupSvc.PatchGroupRoles(ctx, "invalid-group-id", nil) + c.Assert(err, qt.ErrorMatches, ".* invalid group ID") + + newUUID := uuid.New() + operations := []resources.GroupRolesPatchItem{ + {Role: uuid.NewString(), Op: resources.GroupRolesPatchItemOpAdd}, + {Role: uuid.NewString(), Op: resources.GroupRolesPatchItemOpRemove}, + } + res, err := groupSvc.PatchGroupRoles(ctx, newUUID.String(), operations) + c.Assert(err, qt.IsNil) + c.Assert(res, qt.IsTrue) + + operationsWithInvalidIdentity := []resources.GroupRolesPatchItem{ + {Role: "foo_", Op: resources.GroupRolesPatchItemOpAdd}, + } + _, err = groupSvc.PatchGroupRoles(ctx, newUUID.String(), operationsWithInvalidIdentity) + c.Assert(err, qt.ErrorMatches, ".*invalid role ID.*") + + patchTuplesErr = errors.New("foo") + _, err = groupSvc.PatchGroupRoles(ctx, newUUID.String(), operations) + c.Assert(err, qt.ErrorMatches, "foo") +} + func TestGetGroupEntitlements(t *testing.T) { c := qt.New(t) var listRelationsErr error From 37afcf56d9489be06797f2ddfee45c3f4bf01254 Mon Sep 17 00:00:00 2001 From: Kian Parvin <46668016+kian99@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:11:09 +0200 Subject: [PATCH 2/3] chore: add role entitlements and capabilities (#1473) Referred to the full list of capabilities from https://github.com/canonical/rebac-admin-ui-handlers/blob/main/v1/capabilities.go Looked at the OpenFGA auth model to determine the entitlement list --- internal/common/pagination/entitlement.go | 1 + internal/jimmhttp/rebac_admin/capabilities.go | 22 +++++++++++++++++++ .../jimmhttp/rebac_admin/capabilities_test.go | 2 +- internal/jimmhttp/rebac_admin/entitlements.go | 11 ++++++++++ .../jimmhttp/rebac_admin/entitlements_test.go | 4 ++-- 5 files changed, 37 insertions(+), 3 deletions(-) diff --git a/internal/common/pagination/entitlement.go b/internal/common/pagination/entitlement.go index 276a07e38..fd26e528e 100644 --- a/internal/common/pagination/entitlement.go +++ b/internal/common/pagination/entitlement.go @@ -23,6 +23,7 @@ var entitlementResources = []openfga.Kind{ openfga.ModelType, openfga.ApplicationOfferType, openfga.GroupType, + openfga.RoleType, openfga.ServiceAccountType, } diff --git a/internal/jimmhttp/rebac_admin/capabilities.go b/internal/jimmhttp/rebac_admin/capabilities.go index 6b8303436..2e89453bb 100644 --- a/internal/jimmhttp/rebac_admin/capabilities.go +++ b/internal/jimmhttp/rebac_admin/capabilities.go @@ -89,6 +89,28 @@ var capabilities = []resources.Capability{ "PATCH", }, }, + { + Endpoint: "/groups/{id}/roles", + Methods: []resources.CapabilityMethods{ + "GET", + "PATCH", + }, + }, + { + Endpoint: "/roles", + Methods: []resources.CapabilityMethods{ + "GET", + "POST", + }, + }, + { + Endpoint: "/roles/{id}", + Methods: []resources.CapabilityMethods{ + "GET", + "PUT", + "DELETE", + }, + }, { Endpoint: "/entitlements", Methods: []resources.CapabilityMethods{ diff --git a/internal/jimmhttp/rebac_admin/capabilities_test.go b/internal/jimmhttp/rebac_admin/capabilities_test.go index e304ea1ed..dc0aac495 100644 --- a/internal/jimmhttp/rebac_admin/capabilities_test.go +++ b/internal/jimmhttp/rebac_admin/capabilities_test.go @@ -46,7 +46,7 @@ func TestCapabilities(t *testing.T) { defer resp.Body.Close() // 404 is for not found endpoints and 501 is for "not implemented" endpoints in the rebac-admin-ui-handlers library isNotFound := resp.StatusCode == 404 || resp.StatusCode == 501 - c.Assert(isNotFound, qt.IsFalse) + c.Assert(isNotFound, qt.IsFalse, qt.Commentf("failed for url %s, method %s", url, m)) }) } diff --git a/internal/jimmhttp/rebac_admin/entitlements.go b/internal/jimmhttp/rebac_admin/entitlements.go index b7c9a1459..773c7854f 100644 --- a/internal/jimmhttp/rebac_admin/entitlements.go +++ b/internal/jimmhttp/rebac_admin/entitlements.go @@ -25,28 +25,35 @@ var entitlementsList = []resources.EntitlementSchema{ {Entitlement: "administrator", ReceiverType: "user", EntityType: ApplicationOffer}, {Entitlement: "administrator", ReceiverType: "user:*", EntityType: ApplicationOffer}, {Entitlement: "administrator", ReceiverType: "group#member", EntityType: ApplicationOffer}, + {Entitlement: "administrator", ReceiverType: "role#assignee", EntityType: ApplicationOffer}, {Entitlement: "consumer", ReceiverType: "user", EntityType: ApplicationOffer}, {Entitlement: "consumer", ReceiverType: "user:*", EntityType: ApplicationOffer}, {Entitlement: "consumer", ReceiverType: "group#member", EntityType: ApplicationOffer}, + {Entitlement: "consumer", ReceiverType: "role#assignee", EntityType: ApplicationOffer}, {Entitlement: "reader", ReceiverType: "user", EntityType: ApplicationOffer}, {Entitlement: "reader", ReceiverType: "user:*", EntityType: ApplicationOffer}, {Entitlement: "reader", ReceiverType: "group#member", EntityType: ApplicationOffer}, + {Entitlement: "reader", ReceiverType: "role#assignee", EntityType: ApplicationOffer}, // cloud {Entitlement: "administrator", ReceiverType: "user", EntityType: Cloud}, {Entitlement: "administrator", ReceiverType: "user:*", EntityType: Cloud}, {Entitlement: "administrator", ReceiverType: "group#member", EntityType: Cloud}, + {Entitlement: "administrator", ReceiverType: "role#assignee", EntityType: Cloud}, {Entitlement: "can_addmodel", ReceiverType: "user", EntityType: Cloud}, {Entitlement: "can_addmodel", ReceiverType: "user:*", EntityType: Cloud}, {Entitlement: "can_addmodel", ReceiverType: "group#member", EntityType: Cloud}, + {Entitlement: "can_addmodel", ReceiverType: "role#assignee", EntityType: Cloud}, // controller {Entitlement: "administrator", ReceiverType: "user", EntityType: Controller}, {Entitlement: "administrator", ReceiverType: "user:*", EntityType: Controller}, {Entitlement: "administrator", ReceiverType: "group#member", EntityType: Controller}, + {Entitlement: "administrator", ReceiverType: "role#assignee", EntityType: Controller}, {Entitlement: "audit_log_viewer", ReceiverType: "user", EntityType: Controller}, {Entitlement: "audit_log_viewer", ReceiverType: "user:*", EntityType: Controller}, {Entitlement: "audit_log_viewer", ReceiverType: "group#member", EntityType: Controller}, + {Entitlement: "audit_log_viewer", ReceiverType: "role#assignee", EntityType: Controller}, // group {Entitlement: "member", ReceiverType: "user", EntityType: Group}, @@ -57,17 +64,21 @@ var entitlementsList = []resources.EntitlementSchema{ {Entitlement: "administrator", ReceiverType: "user", EntityType: Model}, {Entitlement: "administrator", ReceiverType: "user:*", EntityType: Model}, {Entitlement: "administrator", ReceiverType: "group#member", EntityType: Model}, + {Entitlement: "administrator", ReceiverType: "role#assignee", EntityType: Model}, {Entitlement: "reader", ReceiverType: "user", EntityType: Model}, {Entitlement: "reader", ReceiverType: "user:*", EntityType: Model}, {Entitlement: "reader", ReceiverType: "group#member", EntityType: Model}, + {Entitlement: "reader", ReceiverType: "role#assignee", EntityType: Model}, {Entitlement: "writer", ReceiverType: "user", EntityType: Model}, {Entitlement: "writer", ReceiverType: "user:*", EntityType: Model}, {Entitlement: "writer", ReceiverType: "group#member", EntityType: Model}, + {Entitlement: "writer", ReceiverType: "role#assignee", EntityType: Model}, // serviceaccount {Entitlement: "administrator", ReceiverType: "user", EntityType: ServiceAccount}, {Entitlement: "administrator", ReceiverType: "user:*", EntityType: ServiceAccount}, {Entitlement: "administrator", ReceiverType: "group#member", EntityType: ServiceAccount}, + {Entitlement: "administrator", ReceiverType: "role#assignee", EntityType: ServiceAccount}, } // entitlementsService implements the `entitlementsService` interface from rebac-admin-ui-handlers library diff --git a/internal/jimmhttp/rebac_admin/entitlements_test.go b/internal/jimmhttp/rebac_admin/entitlements_test.go index 2273adf56..03da40eab 100644 --- a/internal/jimmhttp/rebac_admin/entitlements_test.go +++ b/internal/jimmhttp/rebac_admin/entitlements_test.go @@ -26,13 +26,13 @@ func TestEntitlements(t *testing.T) { params.Filter = &match entitlements, err = entitlementSvc.ListEntitlements(ctx, params) c.Assert(err, qt.IsNil) - c.Assert(entitlements, qt.HasLen, 15) + c.Assert(entitlements, qt.HasLen, 20) match = "cloud" params.Filter = &match entitlements, err = entitlementSvc.ListEntitlements(ctx, params) c.Assert(err, qt.IsNil) - c.Assert(entitlements, qt.HasLen, 6) + c.Assert(entitlements, qt.HasLen, 8) match = "#member" params.Filter = &match From 63be70f0b76127013fec001746e4b62bf1080592 Mon Sep 17 00:00:00 2001 From: Kian Parvin <46668016+kian99@users.noreply.github.com> Date: Mon, 2 Dec 2024 09:25:32 +0200 Subject: [PATCH 3/3] fix: flaky test in TestProxySocketsAdminFacade (#1474) * fix: flaky test in TestProxySocketsAdminFacade The test was incorrectly cancelling the context. When the proxySockets function encounters an error like a failure to connect to a Juju controller it returns with an error. Otherwise it returns a nil error when the context is cancelled. Here we were expecting the former for one of the tests cases but sometimes encountering the latter in a race condition. * chore: remove expectProxyError bool from test --- internal/rpc/apiproxy_test.go | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/internal/rpc/apiproxy_test.go b/internal/rpc/apiproxy_test.go index 3a830c241..cf4fe793f 100644 --- a/internal/rpc/apiproxy_test.go +++ b/internal/rpc/apiproxy_test.go @@ -70,7 +70,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { expectedClientResponse *message expectedControllerMessage *message oauthAuthenticatorError error - expectConnectTofail bool + expectedProxyError string }{{ about: "login device call - client gets response with both user code and verification uri", messageToSend: message{ @@ -232,18 +232,21 @@ func TestProxySocketsAdminFacade(t *testing.T) { }, oauthAuthenticatorError: errors.E(errors.CodeUnauthorized), }, { - about: "connection to controller fails", - expectConnectTofail: true, + about: "connection to controller fails", expectedClientResponse: &message{ Error: "controller connection error", }, + expectedProxyError: "failed to connect to controller: controller connection error", }} for _, test := range tests { t.Run(test.about, func(t *testing.T) { + proxyError := test.expectedProxyError != "" + ctx := context.Background() ctx, cancelFunc := context.WithCancel(ctx) defer cancelFunc() + clientWebsocket := newMockWebsocketConnection(10) controllerWebsocket := newMockWebsocketConnection(10) loginSvc := &mockLoginService{ @@ -257,7 +260,7 @@ func TestProxySocketsAdminFacade(t *testing.T) { ConnClient: clientWebsocket, TokenGen: &mockTokenGenerator{}, ConnectController: func(ctx context.Context) (rpc.WebsocketConnectionWithMetadata, error) { - if test.expectConnectTofail { + if proxyError { return rpc.WebsocketConnectionWithMetadata{}, goerr.New("controller connection error") } return rpc.WebsocketConnectionWithMetadata{ @@ -275,8 +278,8 @@ func TestProxySocketsAdminFacade(t *testing.T) { go func() { defer wg.Done() err = rpc.ProxySockets(ctx, helpers) - if test.expectConnectTofail { - c.Assert(err, qt.ErrorMatches, "failed to connect to controller: controller connection error") + if proxyError { + c.Assert(err, qt.ErrorMatches, test.expectedProxyError) } else { c.Assert(err, qt.ErrorMatches, "Context cancelled") } @@ -300,7 +303,9 @@ func TestProxySocketsAdminFacade(t *testing.T) { c.Fatal("timed out waiting for response") } } - cancelFunc() + if !proxyError { + cancelFunc() + } wg.Wait() t.Logf("completed test %s", t.Name()) })