From d53de579b3162a00d9693eb20aa508283beb0032 Mon Sep 17 00:00:00 2001 From: Kian Parvin Date: Tue, 17 Dec 2024 15:38:40 +0200 Subject: [PATCH] chore: move jwt generator Move the jwtGenerator into a separate package. --- internal/jimm/access.go | 159 -------- internal/jimm/access_test.go | 342 ----------------- internal/jimm/jwtgenerator/jwtgenerator.go | 180 +++++++++ .../jimm/jwtgenerator/jwtgenerator_test.go | 357 ++++++++++++++++++ internal/jujuapi/websocket.go | 5 +- 5 files changed, 540 insertions(+), 503 deletions(-) create mode 100644 internal/jimm/jwtgenerator/jwtgenerator.go create mode 100644 internal/jimm/jwtgenerator/jwtgenerator_test.go diff --git a/internal/jimm/access.go b/internal/jimm/access.go index 98f809786..de85a10a8 100644 --- a/internal/jimm/access.go +++ b/internal/jimm/access.go @@ -8,7 +8,6 @@ import ( "fmt" "regexp" "strings" - "sync" "github.com/canonical/ofga" "github.com/google/uuid" @@ -20,7 +19,6 @@ import ( "github.com/canonical/jimm/v3/internal/db" "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" - "github.com/canonical/jimm/v3/internal/jimmjwx" "github.com/canonical/jimm/v3/internal/openfga" ofganames "github.com/canonical/jimm/v3/internal/openfga/names" "github.com/canonical/jimm/v3/internal/servermon" @@ -147,163 +145,6 @@ func ToOfferRelation(accessLevel string) (openfga.Relation, error) { } } -// JWTGeneratorDatabase specifies the database interface used by the -// JWT generator. -type JWTGeneratorDatabase interface { - GetController(ctx context.Context, controller *dbmodel.Controller) error -} - -// JWTGeneratorAccessChecker specifies the access checker used by the JWT -// generator to obtain user's access rights to various entities. -type JWTGeneratorAccessChecker interface { - GetUserModelAccess(context.Context, *openfga.User, names.ModelTag) (string, error) - GetUserControllerAccess(context.Context, *openfga.User, names.ControllerTag) (string, error) - GetUserCloudAccess(context.Context, *openfga.User, names.CloudTag) (string, error) - CheckPermission(context.Context, *openfga.User, map[string]string, map[string]interface{}) (map[string]string, error) -} - -// JWTService specifies the service JWT generator uses to generate JWTs. -type JWTService interface { - NewJWT(context.Context, jimmjwx.JWTParams) ([]byte, error) -} - -// JWTGenerator provides the necessary state and methods to authorize a user and generate JWT tokens. -type JWTGenerator struct { - database JWTGeneratorDatabase - accessChecker JWTGeneratorAccessChecker - jwtService JWTService - - mu sync.Mutex - accessMapCache map[string]string - mt names.ModelTag - ct names.ControllerTag - user *openfga.User - callCount int -} - -// NewJWTGenerator returns a new JwtAuthorizer struct -func NewJWTGenerator(database JWTGeneratorDatabase, accessChecker JWTGeneratorAccessChecker, jwtService JWTService) JWTGenerator { - return JWTGenerator{ - database: database, - accessChecker: accessChecker, - jwtService: jwtService, - } -} - -// SetTags implements TokenGenerator -func (auth *JWTGenerator) SetTags(mt names.ModelTag, ct names.ControllerTag) { - auth.mt = mt - auth.ct = ct -} - -// SetTags implements TokenGenerator -func (auth *JWTGenerator) GetUser() names.UserTag { - if auth.user != nil { - return auth.user.ResourceTag() - } - return names.UserTag{} -} - -// MakeLoginToken authorizes the user based on the provided login requests and returns -// a JWT containing claims about user's access to the controller, model (if applicable) -// and all clouds that the controller knows about. -func (auth *JWTGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { - const op = errors.Op("jimm.MakeLoginToken") - - auth.mu.Lock() - defer auth.mu.Unlock() - - if user == nil { - return nil, errors.E(op, "user not specified") - } - auth.user = user - - // Recreate the accessMapCache to prevent leaking permissions across multiple login requests. - auth.accessMapCache = make(map[string]string) - var authErr error - - var modelAccess string - if auth.mt.Id() == "" { - return nil, errors.E(op, "model not set") - } - modelAccess, authErr = auth.accessChecker.GetUserModelAccess(ctx, auth.user, auth.mt) - if authErr != nil { - zapctx.Error(ctx, "model access check failed", zap.Error(authErr)) - return nil, authErr - } - auth.accessMapCache[auth.mt.String()] = modelAccess - - if auth.ct.Id() == "" { - return nil, errors.E(op, "controller not set") - } - var controllerAccess string - controllerAccess, authErr = auth.accessChecker.GetUserControllerAccess(ctx, auth.user, auth.ct) - if authErr != nil { - return nil, authErr - } - auth.accessMapCache[auth.ct.String()] = controllerAccess - - var ctl dbmodel.Controller - ctl.SetTag(auth.ct) - err := auth.database.GetController(ctx, &ctl) - if err != nil { - zapctx.Error(ctx, "failed to fetch controller", zap.Error(err)) - return nil, errors.E(op, "failed to fetch controller", err) - } - clouds := make(map[names.CloudTag]bool) - for _, cloudRegion := range ctl.CloudRegions { - clouds[cloudRegion.CloudRegion.Cloud.ResourceTag()] = true - } - for cloudTag := range clouds { - accessLevel, err := auth.accessChecker.GetUserCloudAccess(ctx, auth.user, cloudTag) - if err != nil { - zapctx.Error(ctx, "cloud access check failed", zap.Error(err)) - return nil, errors.E(op, "failed to check user's cloud access", err) - } - auth.accessMapCache[cloudTag.String()] = accessLevel - } - - return auth.jwtService.NewJWT(ctx, jimmjwx.JWTParams{ - Controller: auth.ct.Id(), - User: auth.user.Tag().String(), - Access: auth.accessMapCache, - }) -} - -// MakeToken assumes MakeLoginToken has already been called and checks the permissions -// specified in the permissionMap. If the logged in user has all those permissions -// a JWT will be returned with assertions confirming all those permissions. -func (auth *JWTGenerator) MakeToken(ctx context.Context, permissionMap map[string]interface{}) ([]byte, error) { - const op = errors.Op("jimm.MakeToken") - - auth.mu.Lock() - defer auth.mu.Unlock() - - if auth.callCount >= 10 { - return nil, errors.E(op, "Permission check limit exceeded") - } - auth.callCount++ - if auth.user == nil { - return nil, errors.E(op, "User authorization missing.") - } - if permissionMap != nil { - var err error - auth.accessMapCache, err = auth.accessChecker.CheckPermission(ctx, auth.user, auth.accessMapCache, permissionMap) - if err != nil { - return nil, err - } - } - jwt, err := auth.jwtService.NewJWT(ctx, jimmjwx.JWTParams{ - Controller: auth.ct.Id(), - User: auth.user.Tag().String(), - Access: auth.accessMapCache, - }) - if err != nil { - return nil, err - } - return jwt, nil -} - // CheckPermission loops over the desired permissions in desiredPerms and adds these permissions // to cachedPerms if they exist. If the user does not have any of the desired permissions then an // error is returned. diff --git a/internal/jimm/access_test.go b/internal/jimm/access_test.go index 5efbfba74..4d9d83107 100644 --- a/internal/jimm/access_test.go +++ b/internal/jimm/access_test.go @@ -14,98 +14,13 @@ import ( "github.com/juju/names/v5" "github.com/canonical/jimm/v3/internal/dbmodel" - "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimm" - "github.com/canonical/jimm/v3/internal/jimmjwx" "github.com/canonical/jimm/v3/internal/openfga" ofganames "github.com/canonical/jimm/v3/internal/openfga/names" "github.com/canonical/jimm/v3/internal/testutils/jimmtest" jimmnames "github.com/canonical/jimm/v3/pkg/names" ) -// testDatabase is a database implementation intended for testing the token generator. -type testDatabase struct { - ctl dbmodel.Controller - err error -} - -// GetController implements the GetController method of the JWTGeneratorDatabase interface. -func (tdb *testDatabase) GetController(ctx context.Context, controller *dbmodel.Controller) error { - if tdb.err != nil { - return tdb.err - } - *controller = tdb.ctl - return nil -} - -// testAccessChecker is an access checker implementation intended for testing the -// token generator. -type testAccessChecker struct { - controllerAccess map[string]string - controllerAccessCheckErr error - modelAccess map[string]string - modelAccessCheckErr error - cloudAccess map[string]string - cloudAccessCheckErr error - permissions map[string]string - permissionCheckErr error -} - -// GetUserModelAccess implements the GetUserModelAccess method of the JWTGeneratorAccessChecker interface. -func (tac *testAccessChecker) GetUserModelAccess(ctx context.Context, user *openfga.User, mt names.ModelTag) (string, error) { - if tac.modelAccessCheckErr != nil { - return "", tac.modelAccessCheckErr - } - return tac.modelAccess[mt.String()], nil -} - -// GetUserControllerAccess implements the GetUserControllerAccess method of the JWTGeneratorAccessChecker interface. -func (tac *testAccessChecker) GetUserControllerAccess(ctx context.Context, user *openfga.User, ct names.ControllerTag) (string, error) { - if tac.controllerAccessCheckErr != nil { - return "", tac.controllerAccessCheckErr - } - return tac.controllerAccess[ct.String()], nil -} - -// GetUserCloudAccess implements the GetUserCloudAccess method of the JWTGeneratorAccessChecker interface. -func (tac *testAccessChecker) GetUserCloudAccess(ctx context.Context, user *openfga.User, ct names.CloudTag) (string, error) { - if tac.cloudAccessCheckErr != nil { - return "", tac.cloudAccessCheckErr - } - return tac.cloudAccess[ct.String()], nil -} - -// CheckPermission implements the CheckPermission methods of the JWTGeneratorAccessChecker interface. -func (tac *testAccessChecker) CheckPermission(ctx context.Context, user *openfga.User, accessMap map[string]string, permissions map[string]interface{}) (map[string]string, error) { - if tac.permissionCheckErr != nil { - return nil, tac.permissionCheckErr - } - access := make(map[string]string) - for k, v := range accessMap { - access[k] = v - } - for k, v := range tac.permissions { - access[k] = v - } - return access, nil -} - -// testJWTService is a jwt service implementation intended for testing the token generator. -type testJWTService struct { - newJWTError error - - params jimmjwx.JWTParams -} - -// NewJWT implements the NewJWT methods of the JWTService interface. -func (t *testJWTService) NewJWT(ctx context.Context, params jimmjwx.JWTParams) ([]byte, error) { - if t.newJWTError != nil { - return nil, t.newJWTError - } - t.params = params - return []byte("test jwt"), nil -} - func TestAuditLogAccess(t *testing.T) { c := qt.New(t) @@ -154,263 +69,6 @@ func TestAuditLogAccess(t *testing.T) { c.Assert(err, qt.ErrorMatches, "unauthorized") } -func TestJWTGeneratorMakeLoginToken(t *testing.T) { - c := qt.New(t) - - ct := names.NewControllerTag(uuid.New().String()) - mt := names.NewModelTag(uuid.New().String()) - - tests := []struct { - about string - username string - database *testDatabase - accessChecker *testAccessChecker - jwtService *testJWTService - expectedError string - expectedJWTParams jimmjwx.JWTParams - }{{ - about: "initial login, all is well", - username: "eve@canonical.com", - database: &testDatabase{ - ctl: dbmodel.Controller{ - CloudRegions: []dbmodel.CloudRegionControllerPriority{{ - CloudRegion: dbmodel.CloudRegion{ - Cloud: dbmodel.Cloud{ - Name: "test-cloud", - }, - }, - }}, - }, - }, - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - cloudAccess: map[string]string{ - names.NewCloudTag("test-cloud").String(): "add-model", - }, - }, - jwtService: &testJWTService{}, - expectedJWTParams: jimmjwx.JWTParams{ - Controller: ct.Id(), - User: names.NewUserTag("eve@canonical.com").String(), - Access: map[string]string{ - ct.String(): "superuser", - mt.String(): "admin", - names.NewCloudTag("test-cloud").String(): "add-model", - }, - }, - }, { - about: "model access check fails", - username: "eve@canonical.com", - accessChecker: &testAccessChecker{ - modelAccessCheckErr: errors.E("a test error"), - }, - jwtService: &testJWTService{}, - expectedError: "a test error", - }, { - about: "controller access check fails", - username: "eve@canonical.com", - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccessCheckErr: errors.E("a test error"), - }, - expectedError: "a test error", - }, { - about: "get controller from db fails", - username: "eve@canonical.com", - database: &testDatabase{ - err: errors.E("a test error"), - }, - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - }, - expectedError: "failed to fetch controller", - }, { - about: "cloud access check fails", - username: "eve@canonical.com", - database: &testDatabase{ - ctl: dbmodel.Controller{ - CloudRegions: []dbmodel.CloudRegionControllerPriority{{ - CloudRegion: dbmodel.CloudRegion{ - Cloud: dbmodel.Cloud{ - Name: "test-cloud", - }, - }, - }}, - }, - }, - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - cloudAccessCheckErr: errors.E("a test error"), - }, - expectedError: "failed to check user's cloud access", - }, { - about: "jwt service errors out", - username: "eve@canonical.com", - database: &testDatabase{ - ctl: dbmodel.Controller{ - CloudRegions: []dbmodel.CloudRegionControllerPriority{{ - CloudRegion: dbmodel.CloudRegion{ - Cloud: dbmodel.Cloud{ - Name: "test-cloud", - }, - }, - }}, - }, - }, - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - cloudAccess: map[string]string{ - names.NewCloudTag("test-cloud").String(): "add-model", - }, - }, - jwtService: &testJWTService{ - newJWTError: errors.E("a test error"), - }, - expectedError: "a test error", - }} - - for _, test := range tests { - generator := jimm.NewJWTGenerator(test.database, test.accessChecker, test.jwtService) - generator.SetTags(mt, ct) - - i, err := dbmodel.NewIdentity(test.username) - c.Assert(err, qt.IsNil) - _, err = generator.MakeLoginToken(context.Background(), &openfga.User{ - Identity: i, - }) - if test.expectedError != "" { - c.Assert(err, qt.ErrorMatches, test.expectedError) - } else { - c.Assert(err, qt.IsNil) - c.Assert(test.jwtService.params, qt.DeepEquals, test.expectedJWTParams) - } - } -} - -func TestJWTGeneratorMakeToken(t *testing.T) { - c := qt.New(t) - - ct := names.NewControllerTag(uuid.New().String()) - mt := names.NewModelTag(uuid.New().String()) - - tests := []struct { - about string - checkPermissions map[string]string - checkPermissionsError error - jwtService *testJWTService - expectedError string - permissions map[string]interface{} - expectedJWTParams jimmjwx.JWTParams - }{{ - about: "all is well", - jwtService: &testJWTService{}, - expectedJWTParams: jimmjwx.JWTParams{ - Controller: ct.Id(), - User: names.NewUserTag("eve@canonical.com").String(), - Access: map[string]string{ - ct.String(): "superuser", - mt.String(): "admin", - names.NewCloudTag("test-cloud").String(): "add-model", - }, - }, - }, { - about: "check permission fails", - jwtService: &testJWTService{}, - permissions: map[string]interface{}{ - "entity1": "access_level1", - }, - checkPermissionsError: errors.E("a test error"), - expectedError: "a test error", - }, { - about: "additional permissions need checking", - jwtService: &testJWTService{}, - permissions: map[string]interface{}{ - "entity1": "access_level1", - }, - checkPermissions: map[string]string{ - "entity1": "access_level1", - }, - expectedJWTParams: jimmjwx.JWTParams{ - Controller: ct.Id(), - User: names.NewUserTag("eve@canonical.com").String(), - Access: map[string]string{ - ct.String(): "superuser", - mt.String(): "admin", - names.NewCloudTag("test-cloud").String(): "add-model", - "entity1": "access_level1", - }, - }, - }} - - for _, test := range tests { - generator := jimm.NewJWTGenerator( - &testDatabase{ - ctl: dbmodel.Controller{ - CloudRegions: []dbmodel.CloudRegionControllerPriority{{ - CloudRegion: dbmodel.CloudRegion{ - Cloud: dbmodel.Cloud{ - Name: "test-cloud", - }, - }, - }}, - }, - }, - &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - cloudAccess: map[string]string{ - names.NewCloudTag("test-cloud").String(): "add-model", - }, - permissions: test.checkPermissions, - permissionCheckErr: test.checkPermissionsError, - }, - test.jwtService, - ) - generator.SetTags(mt, ct) - - i, err := dbmodel.NewIdentity("eve@canonical.com") - c.Assert(err, qt.IsNil) - _, err = generator.MakeLoginToken(context.Background(), &openfga.User{ - Identity: i, - }) - c.Assert(err, qt.IsNil) - - _, err = generator.MakeToken(context.Background(), test.permissions) - if test.expectedError != "" { - c.Assert(err, qt.ErrorMatches, test.expectedError) - } else { - c.Assert(err, qt.IsNil) - c.Assert(test.jwtService.params, qt.DeepEquals, test.expectedJWTParams) - } - } -} - func TestParseAndValidateTag(t *testing.T) { c := qt.New(t) ctx := context.Background() diff --git a/internal/jimm/jwtgenerator/jwtgenerator.go b/internal/jimm/jwtgenerator/jwtgenerator.go new file mode 100644 index 000000000..33ed2a3d2 --- /dev/null +++ b/internal/jimm/jwtgenerator/jwtgenerator.go @@ -0,0 +1,180 @@ +// Copyright 2024 Canonical. + +// jwtgenerator generates JWT tokens to authenticate +// and authorize messages to Juju controllers. +// This package is more specialised than a generic +// JWT token generator as it crafts Juju specific +// permissions that are added as claims to the JWT +// and therefore exists in JIMM's business logic layer. +package jwtgenerator + +import ( + "context" + "sync" + + "github.com/juju/names/v5" + "github.com/juju/zaputil/zapctx" + "go.uber.org/zap" + + "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/jimmjwx" + "github.com/canonical/jimm/v3/internal/openfga" +) + +// jwtGeneratorDatabase specifies the database interface used by the +// JWT generator. +type jwtGeneratorDatabase interface { + GetController(ctx context.Context, controller *dbmodel.Controller) error +} + +// jwtGeneratorAccessChecker specifies the access checker used by the JWT +// generator to obtain user's access rights to various entities. +type jwtGeneratorAccessChecker interface { + GetUserModelAccess(context.Context, *openfga.User, names.ModelTag) (string, error) + GetUserControllerAccess(context.Context, *openfga.User, names.ControllerTag) (string, error) + GetUserCloudAccess(context.Context, *openfga.User, names.CloudTag) (string, error) + CheckPermission(context.Context, *openfga.User, map[string]string, map[string]interface{}) (map[string]string, error) +} + +// jwtService specifies the service JWT generator uses to generate JWTs. +type jwtService interface { + NewJWT(context.Context, jimmjwx.JWTParams) ([]byte, error) +} + +// JWTGenerator provides the necessary state and methods to authorize a user and generate JWT tokens. +type JWTGenerator struct { + database jwtGeneratorDatabase + accessChecker jwtGeneratorAccessChecker + jwtService jwtService + + mu sync.Mutex + accessMapCache map[string]string + mt names.ModelTag + ct names.ControllerTag + user *openfga.User + callCount int +} + +// New returns a new JWTGenerator. +func New(database jwtGeneratorDatabase, accessChecker jwtGeneratorAccessChecker, jwtService jwtService) JWTGenerator { + return JWTGenerator{ + database: database, + accessChecker: accessChecker, + jwtService: jwtService, + } +} + +// SetTags implements TokenGenerator. +func (auth *JWTGenerator) SetTags(mt names.ModelTag, ct names.ControllerTag) { + auth.mt = mt + auth.ct = ct +} + +// SetTags implements TokenGenerator. +func (auth *JWTGenerator) GetUser() names.UserTag { + if auth.user != nil { + return auth.user.ResourceTag() + } + return names.UserTag{} +} + +// MakeLoginToken authorizes the user based on the provided login requests and returns +// a JWT containing claims about user's access to the controller, model (if applicable) +// and all clouds that the controller knows about. +func (auth *JWTGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { + const op = errors.Op("jimm.MakeLoginToken") + + auth.mu.Lock() + defer auth.mu.Unlock() + + if user == nil { + return nil, errors.E(op, "user not specified") + } + auth.user = user + + // Recreate the accessMapCache to prevent leaking permissions across multiple login requests. + auth.accessMapCache = make(map[string]string) + var authErr error + + var modelAccess string + if auth.mt.Id() == "" { + return nil, errors.E(op, "model not set") + } + modelAccess, authErr = auth.accessChecker.GetUserModelAccess(ctx, auth.user, auth.mt) + if authErr != nil { + zapctx.Error(ctx, "model access check failed", zap.Error(authErr)) + return nil, authErr + } + auth.accessMapCache[auth.mt.String()] = modelAccess + + if auth.ct.Id() == "" { + return nil, errors.E(op, "controller not set") + } + var controllerAccess string + controllerAccess, authErr = auth.accessChecker.GetUserControllerAccess(ctx, auth.user, auth.ct) + if authErr != nil { + return nil, authErr + } + auth.accessMapCache[auth.ct.String()] = controllerAccess + + var ctl dbmodel.Controller + ctl.SetTag(auth.ct) + err := auth.database.GetController(ctx, &ctl) + if err != nil { + zapctx.Error(ctx, "failed to fetch controller", zap.Error(err)) + return nil, errors.E(op, "failed to fetch controller", err) + } + clouds := make(map[names.CloudTag]bool) + for _, cloudRegion := range ctl.CloudRegions { + clouds[cloudRegion.CloudRegion.Cloud.ResourceTag()] = true + } + for cloudTag := range clouds { + accessLevel, err := auth.accessChecker.GetUserCloudAccess(ctx, auth.user, cloudTag) + if err != nil { + zapctx.Error(ctx, "cloud access check failed", zap.Error(err)) + return nil, errors.E(op, "failed to check user's cloud access", err) + } + auth.accessMapCache[cloudTag.String()] = accessLevel + } + + return auth.jwtService.NewJWT(ctx, jimmjwx.JWTParams{ + Controller: auth.ct.Id(), + User: auth.user.Tag().String(), + Access: auth.accessMapCache, + }) +} + +// MakeToken assumes MakeLoginToken has already been called and checks the permissions +// specified in the permissionMap. If the logged in user has all those permissions +// a JWT will be returned with assertions confirming all those permissions. +func (auth *JWTGenerator) MakeToken(ctx context.Context, permissionMap map[string]interface{}) ([]byte, error) { + const op = errors.Op("jimm.MakeToken") + + auth.mu.Lock() + defer auth.mu.Unlock() + + if auth.callCount >= 10 { + return nil, errors.E(op, "Permission check limit exceeded") + } + auth.callCount++ + if auth.user == nil { + return nil, errors.E(op, "User authorization missing.") + } + if permissionMap != nil { + var err error + auth.accessMapCache, err = auth.accessChecker.CheckPermission(ctx, auth.user, auth.accessMapCache, permissionMap) + if err != nil { + return nil, err + } + } + jwt, err := auth.jwtService.NewJWT(ctx, jimmjwx.JWTParams{ + Controller: auth.ct.Id(), + User: auth.user.Tag().String(), + Access: auth.accessMapCache, + }) + if err != nil { + return nil, err + } + return jwt, nil +} diff --git a/internal/jimm/jwtgenerator/jwtgenerator_test.go b/internal/jimm/jwtgenerator/jwtgenerator_test.go new file mode 100644 index 000000000..048feff5b --- /dev/null +++ b/internal/jimm/jwtgenerator/jwtgenerator_test.go @@ -0,0 +1,357 @@ +// Copyright 2024 Canonical. +package jwtgenerator_test + +import ( + "context" + "testing" + + qt "github.com/frankban/quicktest" + "github.com/google/uuid" + "github.com/juju/names/v5" + + "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/jimm/jwtgenerator" + "github.com/canonical/jimm/v3/internal/jimmjwx" + "github.com/canonical/jimm/v3/internal/openfga" +) + +// testDatabase is a database implementation intended for testing the token generator. +type testDatabase struct { + ctl dbmodel.Controller + err error +} + +// GetController implements the GetController method of the JWTGeneratorDatabase interface. +func (tdb *testDatabase) GetController(ctx context.Context, controller *dbmodel.Controller) error { + if tdb.err != nil { + return tdb.err + } + *controller = tdb.ctl + return nil +} + +// testAccessChecker is an access checker implementation intended for testing the +// token generator. +type testAccessChecker struct { + controllerAccess map[string]string + controllerAccessCheckErr error + modelAccess map[string]string + modelAccessCheckErr error + cloudAccess map[string]string + cloudAccessCheckErr error + permissions map[string]string + permissionCheckErr error +} + +// GetUserModelAccess implements the GetUserModelAccess method of the JWTGeneratorAccessChecker interface. +func (tac *testAccessChecker) GetUserModelAccess(ctx context.Context, user *openfga.User, mt names.ModelTag) (string, error) { + if tac.modelAccessCheckErr != nil { + return "", tac.modelAccessCheckErr + } + return tac.modelAccess[mt.String()], nil +} + +// GetUserControllerAccess implements the GetUserControllerAccess method of the JWTGeneratorAccessChecker interface. +func (tac *testAccessChecker) GetUserControllerAccess(ctx context.Context, user *openfga.User, ct names.ControllerTag) (string, error) { + if tac.controllerAccessCheckErr != nil { + return "", tac.controllerAccessCheckErr + } + return tac.controllerAccess[ct.String()], nil +} + +// GetUserCloudAccess implements the GetUserCloudAccess method of the JWTGeneratorAccessChecker interface. +func (tac *testAccessChecker) GetUserCloudAccess(ctx context.Context, user *openfga.User, ct names.CloudTag) (string, error) { + if tac.cloudAccessCheckErr != nil { + return "", tac.cloudAccessCheckErr + } + return tac.cloudAccess[ct.String()], nil +} + +// CheckPermission implements the CheckPermission methods of the JWTGeneratorAccessChecker interface. +func (tac *testAccessChecker) CheckPermission(ctx context.Context, user *openfga.User, accessMap map[string]string, permissions map[string]interface{}) (map[string]string, error) { + if tac.permissionCheckErr != nil { + return nil, tac.permissionCheckErr + } + access := make(map[string]string) + for k, v := range accessMap { + access[k] = v + } + for k, v := range tac.permissions { + access[k] = v + } + return access, nil +} + +// testJWTService is a jwt service implementation intended for testing the token generator. +type testJWTService struct { + newJWTError error + + params jimmjwx.JWTParams +} + +// NewJWT implements the NewJWT methods of the JWTService interface. +func (t *testJWTService) NewJWT(ctx context.Context, params jimmjwx.JWTParams) ([]byte, error) { + if t.newJWTError != nil { + return nil, t.newJWTError + } + t.params = params + return []byte("test jwt"), nil +} + +func TestJWTGeneratorMakeLoginToken(t *testing.T) { + c := qt.New(t) + + ct := names.NewControllerTag(uuid.New().String()) + mt := names.NewModelTag(uuid.New().String()) + + tests := []struct { + about string + username string + database *testDatabase + accessChecker *testAccessChecker + jwtService *testJWTService + expectedError string + expectedJWTParams jimmjwx.JWTParams + }{{ + about: "initial login, all is well", + username: "eve@canonical.com", + database: &testDatabase{ + ctl: dbmodel.Controller{ + CloudRegions: []dbmodel.CloudRegionControllerPriority{{ + CloudRegion: dbmodel.CloudRegion{ + Cloud: dbmodel.Cloud{ + Name: "test-cloud", + }, + }, + }}, + }, + }, + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + cloudAccess: map[string]string{ + names.NewCloudTag("test-cloud").String(): "add-model", + }, + }, + jwtService: &testJWTService{}, + expectedJWTParams: jimmjwx.JWTParams{ + Controller: ct.Id(), + User: names.NewUserTag("eve@canonical.com").String(), + Access: map[string]string{ + ct.String(): "superuser", + mt.String(): "admin", + names.NewCloudTag("test-cloud").String(): "add-model", + }, + }, + }, { + about: "model access check fails", + username: "eve@canonical.com", + accessChecker: &testAccessChecker{ + modelAccessCheckErr: errors.E("a test error"), + }, + jwtService: &testJWTService{}, + expectedError: "a test error", + }, { + about: "controller access check fails", + username: "eve@canonical.com", + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccessCheckErr: errors.E("a test error"), + }, + expectedError: "a test error", + }, { + about: "get controller from db fails", + username: "eve@canonical.com", + database: &testDatabase{ + err: errors.E("a test error"), + }, + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + }, + expectedError: "failed to fetch controller", + }, { + about: "cloud access check fails", + username: "eve@canonical.com", + database: &testDatabase{ + ctl: dbmodel.Controller{ + CloudRegions: []dbmodel.CloudRegionControllerPriority{{ + CloudRegion: dbmodel.CloudRegion{ + Cloud: dbmodel.Cloud{ + Name: "test-cloud", + }, + }, + }}, + }, + }, + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + cloudAccessCheckErr: errors.E("a test error"), + }, + expectedError: "failed to check user's cloud access", + }, { + about: "jwt service errors out", + username: "eve@canonical.com", + database: &testDatabase{ + ctl: dbmodel.Controller{ + CloudRegions: []dbmodel.CloudRegionControllerPriority{{ + CloudRegion: dbmodel.CloudRegion{ + Cloud: dbmodel.Cloud{ + Name: "test-cloud", + }, + }, + }}, + }, + }, + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + cloudAccess: map[string]string{ + names.NewCloudTag("test-cloud").String(): "add-model", + }, + }, + jwtService: &testJWTService{ + newJWTError: errors.E("a test error"), + }, + expectedError: "a test error", + }} + + for _, test := range tests { + generator := jwtgenerator.New(test.database, test.accessChecker, test.jwtService) + generator.SetTags(mt, ct) + + i, err := dbmodel.NewIdentity(test.username) + c.Assert(err, qt.IsNil) + _, err = generator.MakeLoginToken(context.Background(), &openfga.User{ + Identity: i, + }) + if test.expectedError != "" { + c.Assert(err, qt.ErrorMatches, test.expectedError) + } else { + c.Assert(err, qt.IsNil) + c.Assert(test.jwtService.params, qt.DeepEquals, test.expectedJWTParams) + } + } +} + +func TestJWTGeneratorMakeToken(t *testing.T) { + c := qt.New(t) + + ct := names.NewControllerTag(uuid.New().String()) + mt := names.NewModelTag(uuid.New().String()) + + tests := []struct { + about string + checkPermissions map[string]string + checkPermissionsError error + jwtService *testJWTService + expectedError string + permissions map[string]interface{} + expectedJWTParams jimmjwx.JWTParams + }{{ + about: "all is well", + jwtService: &testJWTService{}, + expectedJWTParams: jimmjwx.JWTParams{ + Controller: ct.Id(), + User: names.NewUserTag("eve@canonical.com").String(), + Access: map[string]string{ + ct.String(): "superuser", + mt.String(): "admin", + names.NewCloudTag("test-cloud").String(): "add-model", + }, + }, + }, { + about: "check permission fails", + jwtService: &testJWTService{}, + permissions: map[string]interface{}{ + "entity1": "access_level1", + }, + checkPermissionsError: errors.E("a test error"), + expectedError: "a test error", + }, { + about: "additional permissions need checking", + jwtService: &testJWTService{}, + permissions: map[string]interface{}{ + "entity1": "access_level1", + }, + checkPermissions: map[string]string{ + "entity1": "access_level1", + }, + expectedJWTParams: jimmjwx.JWTParams{ + Controller: ct.Id(), + User: names.NewUserTag("eve@canonical.com").String(), + Access: map[string]string{ + ct.String(): "superuser", + mt.String(): "admin", + names.NewCloudTag("test-cloud").String(): "add-model", + "entity1": "access_level1", + }, + }, + }} + + for _, test := range tests { + generator := jwtgenerator.New( + &testDatabase{ + ctl: dbmodel.Controller{ + CloudRegions: []dbmodel.CloudRegionControllerPriority{{ + CloudRegion: dbmodel.CloudRegion{ + Cloud: dbmodel.Cloud{ + Name: "test-cloud", + }, + }, + }}, + }, + }, + &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + cloudAccess: map[string]string{ + names.NewCloudTag("test-cloud").String(): "add-model", + }, + permissions: test.checkPermissions, + permissionCheckErr: test.checkPermissionsError, + }, + test.jwtService, + ) + generator.SetTags(mt, ct) + + i, err := dbmodel.NewIdentity("eve@canonical.com") + c.Assert(err, qt.IsNil) + _, err = generator.MakeLoginToken(context.Background(), &openfga.User{ + Identity: i, + }) + c.Assert(err, qt.IsNil) + + _, err = generator.MakeToken(context.Background(), test.permissions) + if test.expectedError != "" { + c.Assert(err, qt.ErrorMatches, test.expectedError) + } else { + c.Assert(err, qt.IsNil) + c.Assert(test.jwtService.params, qt.DeepEquals, test.expectedJWTParams) + } + } +} diff --git a/internal/jujuapi/websocket.go b/internal/jujuapi/websocket.go index cf8f278d8..95461c8cd 100644 --- a/internal/jujuapi/websocket.go +++ b/internal/jujuapi/websocket.go @@ -21,6 +21,7 @@ import ( "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimm" + "github.com/canonical/jimm/v3/internal/jimm/jwtgenerator" "github.com/canonical/jimm/v3/internal/jimmhttp" jimmRPC "github.com/canonical/jimm/v3/internal/rpc" ) @@ -172,7 +173,7 @@ func modelInfoFromPath(path string) (uuid string, finalPath string, err error) { // We act as a proxier, handling auth on requests before forwarding the // requests to the appropriate Juju controller. func (s apiProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) { - jwtGenerator := jimm.NewJWTGenerator(s.jimm.Database, s.jimm, s.jimm.JWTService) + jwtGenerator := jwtgenerator.New(s.jimm.Database, s.jimm, s.jimm.JWTService) connectionFunc := controllerConnectionFunc(s, &jwtGenerator) zapctx.Debug(ctx, "Starting proxier") auditLogger := s.jimm.AddAuditLogEntry @@ -191,7 +192,7 @@ func (s apiProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) { // controllerConnectionFunc returns a function that will be used to // connect to a controller when a client makes a request. -func controllerConnectionFunc(s apiProxier, jwtGenerator *jimm.JWTGenerator) func(context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { +func controllerConnectionFunc(s apiProxier, jwtGenerator *jwtgenerator.JWTGenerator) func(context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { return func(ctx context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { const op = errors.Op("proxy.controllerConnectionFunc") path := jimmhttp.PathElementFromContext(ctx, "path")