diff --git a/cmd/jaas/cmd/updatecredentials_test.go b/cmd/jaas/cmd/updatecredentials_test.go index f7899e3fa..f41119c2b 100644 --- a/cmd/jaas/cmd/updatecredentials_test.go +++ b/cmd/jaas/cmd/updatecredentials_test.go @@ -76,7 +76,7 @@ func (s *updateCredentialsSuite) TestUpdateCredentialsWithLocalCredentials(c *gc models: [] `) - ofgaUser := openfga.NewUser(sa, s.JIMM.AuthorizationClient()) + ofgaUser := openfga.NewUser(sa, s.JIMM.OpenFGAClient) cloudCredentialTag := names.NewCloudCredentialTag("test-cloud/" + clientIDWithDomain + "/test-credentials") cloudCredential2, err := s.JIMM.GetCloudCredential(ctx, ofgaUser, cloudCredentialTag) c.Assert(err, gc.IsNil) diff --git a/internal/jimm/access.go b/internal/jimm/access.go index 98f809786..de85a10a8 100644 --- a/internal/jimm/access.go +++ b/internal/jimm/access.go @@ -8,7 +8,6 @@ import ( "fmt" "regexp" "strings" - "sync" "github.com/canonical/ofga" "github.com/google/uuid" @@ -20,7 +19,6 @@ import ( "github.com/canonical/jimm/v3/internal/db" "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" - "github.com/canonical/jimm/v3/internal/jimmjwx" "github.com/canonical/jimm/v3/internal/openfga" ofganames "github.com/canonical/jimm/v3/internal/openfga/names" "github.com/canonical/jimm/v3/internal/servermon" @@ -147,163 +145,6 @@ func ToOfferRelation(accessLevel string) (openfga.Relation, error) { } } -// JWTGeneratorDatabase specifies the database interface used by the -// JWT generator. -type JWTGeneratorDatabase interface { - GetController(ctx context.Context, controller *dbmodel.Controller) error -} - -// JWTGeneratorAccessChecker specifies the access checker used by the JWT -// generator to obtain user's access rights to various entities. -type JWTGeneratorAccessChecker interface { - GetUserModelAccess(context.Context, *openfga.User, names.ModelTag) (string, error) - GetUserControllerAccess(context.Context, *openfga.User, names.ControllerTag) (string, error) - GetUserCloudAccess(context.Context, *openfga.User, names.CloudTag) (string, error) - CheckPermission(context.Context, *openfga.User, map[string]string, map[string]interface{}) (map[string]string, error) -} - -// JWTService specifies the service JWT generator uses to generate JWTs. -type JWTService interface { - NewJWT(context.Context, jimmjwx.JWTParams) ([]byte, error) -} - -// JWTGenerator provides the necessary state and methods to authorize a user and generate JWT tokens. -type JWTGenerator struct { - database JWTGeneratorDatabase - accessChecker JWTGeneratorAccessChecker - jwtService JWTService - - mu sync.Mutex - accessMapCache map[string]string - mt names.ModelTag - ct names.ControllerTag - user *openfga.User - callCount int -} - -// NewJWTGenerator returns a new JwtAuthorizer struct -func NewJWTGenerator(database JWTGeneratorDatabase, accessChecker JWTGeneratorAccessChecker, jwtService JWTService) JWTGenerator { - return JWTGenerator{ - database: database, - accessChecker: accessChecker, - jwtService: jwtService, - } -} - -// SetTags implements TokenGenerator -func (auth *JWTGenerator) SetTags(mt names.ModelTag, ct names.ControllerTag) { - auth.mt = mt - auth.ct = ct -} - -// SetTags implements TokenGenerator -func (auth *JWTGenerator) GetUser() names.UserTag { - if auth.user != nil { - return auth.user.ResourceTag() - } - return names.UserTag{} -} - -// MakeLoginToken authorizes the user based on the provided login requests and returns -// a JWT containing claims about user's access to the controller, model (if applicable) -// and all clouds that the controller knows about. -func (auth *JWTGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { - const op = errors.Op("jimm.MakeLoginToken") - - auth.mu.Lock() - defer auth.mu.Unlock() - - if user == nil { - return nil, errors.E(op, "user not specified") - } - auth.user = user - - // Recreate the accessMapCache to prevent leaking permissions across multiple login requests. - auth.accessMapCache = make(map[string]string) - var authErr error - - var modelAccess string - if auth.mt.Id() == "" { - return nil, errors.E(op, "model not set") - } - modelAccess, authErr = auth.accessChecker.GetUserModelAccess(ctx, auth.user, auth.mt) - if authErr != nil { - zapctx.Error(ctx, "model access check failed", zap.Error(authErr)) - return nil, authErr - } - auth.accessMapCache[auth.mt.String()] = modelAccess - - if auth.ct.Id() == "" { - return nil, errors.E(op, "controller not set") - } - var controllerAccess string - controllerAccess, authErr = auth.accessChecker.GetUserControllerAccess(ctx, auth.user, auth.ct) - if authErr != nil { - return nil, authErr - } - auth.accessMapCache[auth.ct.String()] = controllerAccess - - var ctl dbmodel.Controller - ctl.SetTag(auth.ct) - err := auth.database.GetController(ctx, &ctl) - if err != nil { - zapctx.Error(ctx, "failed to fetch controller", zap.Error(err)) - return nil, errors.E(op, "failed to fetch controller", err) - } - clouds := make(map[names.CloudTag]bool) - for _, cloudRegion := range ctl.CloudRegions { - clouds[cloudRegion.CloudRegion.Cloud.ResourceTag()] = true - } - for cloudTag := range clouds { - accessLevel, err := auth.accessChecker.GetUserCloudAccess(ctx, auth.user, cloudTag) - if err != nil { - zapctx.Error(ctx, "cloud access check failed", zap.Error(err)) - return nil, errors.E(op, "failed to check user's cloud access", err) - } - auth.accessMapCache[cloudTag.String()] = accessLevel - } - - return auth.jwtService.NewJWT(ctx, jimmjwx.JWTParams{ - Controller: auth.ct.Id(), - User: auth.user.Tag().String(), - Access: auth.accessMapCache, - }) -} - -// MakeToken assumes MakeLoginToken has already been called and checks the permissions -// specified in the permissionMap. If the logged in user has all those permissions -// a JWT will be returned with assertions confirming all those permissions. -func (auth *JWTGenerator) MakeToken(ctx context.Context, permissionMap map[string]interface{}) ([]byte, error) { - const op = errors.Op("jimm.MakeToken") - - auth.mu.Lock() - defer auth.mu.Unlock() - - if auth.callCount >= 10 { - return nil, errors.E(op, "Permission check limit exceeded") - } - auth.callCount++ - if auth.user == nil { - return nil, errors.E(op, "User authorization missing.") - } - if permissionMap != nil { - var err error - auth.accessMapCache, err = auth.accessChecker.CheckPermission(ctx, auth.user, auth.accessMapCache, permissionMap) - if err != nil { - return nil, err - } - } - jwt, err := auth.jwtService.NewJWT(ctx, jimmjwx.JWTParams{ - Controller: auth.ct.Id(), - User: auth.user.Tag().String(), - Access: auth.accessMapCache, - }) - if err != nil { - return nil, err - } - return jwt, nil -} - // CheckPermission loops over the desired permissions in desiredPerms and adds these permissions // to cachedPerms if they exist. If the user does not have any of the desired permissions then an // error is returned. diff --git a/internal/jimm/access_test.go b/internal/jimm/access_test.go index 5efbfba74..4d9d83107 100644 --- a/internal/jimm/access_test.go +++ b/internal/jimm/access_test.go @@ -14,98 +14,13 @@ import ( "github.com/juju/names/v5" "github.com/canonical/jimm/v3/internal/dbmodel" - "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimm" - "github.com/canonical/jimm/v3/internal/jimmjwx" "github.com/canonical/jimm/v3/internal/openfga" ofganames "github.com/canonical/jimm/v3/internal/openfga/names" "github.com/canonical/jimm/v3/internal/testutils/jimmtest" jimmnames "github.com/canonical/jimm/v3/pkg/names" ) -// testDatabase is a database implementation intended for testing the token generator. -type testDatabase struct { - ctl dbmodel.Controller - err error -} - -// GetController implements the GetController method of the JWTGeneratorDatabase interface. -func (tdb *testDatabase) GetController(ctx context.Context, controller *dbmodel.Controller) error { - if tdb.err != nil { - return tdb.err - } - *controller = tdb.ctl - return nil -} - -// testAccessChecker is an access checker implementation intended for testing the -// token generator. -type testAccessChecker struct { - controllerAccess map[string]string - controllerAccessCheckErr error - modelAccess map[string]string - modelAccessCheckErr error - cloudAccess map[string]string - cloudAccessCheckErr error - permissions map[string]string - permissionCheckErr error -} - -// GetUserModelAccess implements the GetUserModelAccess method of the JWTGeneratorAccessChecker interface. -func (tac *testAccessChecker) GetUserModelAccess(ctx context.Context, user *openfga.User, mt names.ModelTag) (string, error) { - if tac.modelAccessCheckErr != nil { - return "", tac.modelAccessCheckErr - } - return tac.modelAccess[mt.String()], nil -} - -// GetUserControllerAccess implements the GetUserControllerAccess method of the JWTGeneratorAccessChecker interface. -func (tac *testAccessChecker) GetUserControllerAccess(ctx context.Context, user *openfga.User, ct names.ControllerTag) (string, error) { - if tac.controllerAccessCheckErr != nil { - return "", tac.controllerAccessCheckErr - } - return tac.controllerAccess[ct.String()], nil -} - -// GetUserCloudAccess implements the GetUserCloudAccess method of the JWTGeneratorAccessChecker interface. -func (tac *testAccessChecker) GetUserCloudAccess(ctx context.Context, user *openfga.User, ct names.CloudTag) (string, error) { - if tac.cloudAccessCheckErr != nil { - return "", tac.cloudAccessCheckErr - } - return tac.cloudAccess[ct.String()], nil -} - -// CheckPermission implements the CheckPermission methods of the JWTGeneratorAccessChecker interface. -func (tac *testAccessChecker) CheckPermission(ctx context.Context, user *openfga.User, accessMap map[string]string, permissions map[string]interface{}) (map[string]string, error) { - if tac.permissionCheckErr != nil { - return nil, tac.permissionCheckErr - } - access := make(map[string]string) - for k, v := range accessMap { - access[k] = v - } - for k, v := range tac.permissions { - access[k] = v - } - return access, nil -} - -// testJWTService is a jwt service implementation intended for testing the token generator. -type testJWTService struct { - newJWTError error - - params jimmjwx.JWTParams -} - -// NewJWT implements the NewJWT methods of the JWTService interface. -func (t *testJWTService) NewJWT(ctx context.Context, params jimmjwx.JWTParams) ([]byte, error) { - if t.newJWTError != nil { - return nil, t.newJWTError - } - t.params = params - return []byte("test jwt"), nil -} - func TestAuditLogAccess(t *testing.T) { c := qt.New(t) @@ -154,263 +69,6 @@ func TestAuditLogAccess(t *testing.T) { c.Assert(err, qt.ErrorMatches, "unauthorized") } -func TestJWTGeneratorMakeLoginToken(t *testing.T) { - c := qt.New(t) - - ct := names.NewControllerTag(uuid.New().String()) - mt := names.NewModelTag(uuid.New().String()) - - tests := []struct { - about string - username string - database *testDatabase - accessChecker *testAccessChecker - jwtService *testJWTService - expectedError string - expectedJWTParams jimmjwx.JWTParams - }{{ - about: "initial login, all is well", - username: "eve@canonical.com", - database: &testDatabase{ - ctl: dbmodel.Controller{ - CloudRegions: []dbmodel.CloudRegionControllerPriority{{ - CloudRegion: dbmodel.CloudRegion{ - Cloud: dbmodel.Cloud{ - Name: "test-cloud", - }, - }, - }}, - }, - }, - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - cloudAccess: map[string]string{ - names.NewCloudTag("test-cloud").String(): "add-model", - }, - }, - jwtService: &testJWTService{}, - expectedJWTParams: jimmjwx.JWTParams{ - Controller: ct.Id(), - User: names.NewUserTag("eve@canonical.com").String(), - Access: map[string]string{ - ct.String(): "superuser", - mt.String(): "admin", - names.NewCloudTag("test-cloud").String(): "add-model", - }, - }, - }, { - about: "model access check fails", - username: "eve@canonical.com", - accessChecker: &testAccessChecker{ - modelAccessCheckErr: errors.E("a test error"), - }, - jwtService: &testJWTService{}, - expectedError: "a test error", - }, { - about: "controller access check fails", - username: "eve@canonical.com", - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccessCheckErr: errors.E("a test error"), - }, - expectedError: "a test error", - }, { - about: "get controller from db fails", - username: "eve@canonical.com", - database: &testDatabase{ - err: errors.E("a test error"), - }, - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - }, - expectedError: "failed to fetch controller", - }, { - about: "cloud access check fails", - username: "eve@canonical.com", - database: &testDatabase{ - ctl: dbmodel.Controller{ - CloudRegions: []dbmodel.CloudRegionControllerPriority{{ - CloudRegion: dbmodel.CloudRegion{ - Cloud: dbmodel.Cloud{ - Name: "test-cloud", - }, - }, - }}, - }, - }, - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - cloudAccessCheckErr: errors.E("a test error"), - }, - expectedError: "failed to check user's cloud access", - }, { - about: "jwt service errors out", - username: "eve@canonical.com", - database: &testDatabase{ - ctl: dbmodel.Controller{ - CloudRegions: []dbmodel.CloudRegionControllerPriority{{ - CloudRegion: dbmodel.CloudRegion{ - Cloud: dbmodel.Cloud{ - Name: "test-cloud", - }, - }, - }}, - }, - }, - accessChecker: &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - cloudAccess: map[string]string{ - names.NewCloudTag("test-cloud").String(): "add-model", - }, - }, - jwtService: &testJWTService{ - newJWTError: errors.E("a test error"), - }, - expectedError: "a test error", - }} - - for _, test := range tests { - generator := jimm.NewJWTGenerator(test.database, test.accessChecker, test.jwtService) - generator.SetTags(mt, ct) - - i, err := dbmodel.NewIdentity(test.username) - c.Assert(err, qt.IsNil) - _, err = generator.MakeLoginToken(context.Background(), &openfga.User{ - Identity: i, - }) - if test.expectedError != "" { - c.Assert(err, qt.ErrorMatches, test.expectedError) - } else { - c.Assert(err, qt.IsNil) - c.Assert(test.jwtService.params, qt.DeepEquals, test.expectedJWTParams) - } - } -} - -func TestJWTGeneratorMakeToken(t *testing.T) { - c := qt.New(t) - - ct := names.NewControllerTag(uuid.New().String()) - mt := names.NewModelTag(uuid.New().String()) - - tests := []struct { - about string - checkPermissions map[string]string - checkPermissionsError error - jwtService *testJWTService - expectedError string - permissions map[string]interface{} - expectedJWTParams jimmjwx.JWTParams - }{{ - about: "all is well", - jwtService: &testJWTService{}, - expectedJWTParams: jimmjwx.JWTParams{ - Controller: ct.Id(), - User: names.NewUserTag("eve@canonical.com").String(), - Access: map[string]string{ - ct.String(): "superuser", - mt.String(): "admin", - names.NewCloudTag("test-cloud").String(): "add-model", - }, - }, - }, { - about: "check permission fails", - jwtService: &testJWTService{}, - permissions: map[string]interface{}{ - "entity1": "access_level1", - }, - checkPermissionsError: errors.E("a test error"), - expectedError: "a test error", - }, { - about: "additional permissions need checking", - jwtService: &testJWTService{}, - permissions: map[string]interface{}{ - "entity1": "access_level1", - }, - checkPermissions: map[string]string{ - "entity1": "access_level1", - }, - expectedJWTParams: jimmjwx.JWTParams{ - Controller: ct.Id(), - User: names.NewUserTag("eve@canonical.com").String(), - Access: map[string]string{ - ct.String(): "superuser", - mt.String(): "admin", - names.NewCloudTag("test-cloud").String(): "add-model", - "entity1": "access_level1", - }, - }, - }} - - for _, test := range tests { - generator := jimm.NewJWTGenerator( - &testDatabase{ - ctl: dbmodel.Controller{ - CloudRegions: []dbmodel.CloudRegionControllerPriority{{ - CloudRegion: dbmodel.CloudRegion{ - Cloud: dbmodel.Cloud{ - Name: "test-cloud", - }, - }, - }}, - }, - }, - &testAccessChecker{ - modelAccess: map[string]string{ - mt.String(): "admin", - }, - controllerAccess: map[string]string{ - ct.String(): "superuser", - }, - cloudAccess: map[string]string{ - names.NewCloudTag("test-cloud").String(): "add-model", - }, - permissions: test.checkPermissions, - permissionCheckErr: test.checkPermissionsError, - }, - test.jwtService, - ) - generator.SetTags(mt, ct) - - i, err := dbmodel.NewIdentity("eve@canonical.com") - c.Assert(err, qt.IsNil) - _, err = generator.MakeLoginToken(context.Background(), &openfga.User{ - Identity: i, - }) - c.Assert(err, qt.IsNil) - - _, err = generator.MakeToken(context.Background(), test.permissions) - if test.expectedError != "" { - c.Assert(err, qt.ErrorMatches, test.expectedError) - } else { - c.Assert(err, qt.IsNil) - c.Assert(test.jwtService.params, qt.DeepEquals, test.expectedJWTParams) - } - } -} - func TestParseAndValidateTag(t *testing.T) { c := qt.New(t) ctx := context.Background() diff --git a/internal/jimm/jimm.go b/internal/jimm/jimm.go index a749fee6f..cca0121d0 100644 --- a/internal/jimm/jimm.go +++ b/internal/jimm/jimm.go @@ -268,21 +268,11 @@ func (j *JIMM) ResourceTag() names.ControllerTag { return names.NewControllerTag(j.UUID) } -// DB returns the database used by JIMM. -func (j *JIMM) DB() *db.Database { - return j.Database -} - // PubsubHub returns the pub-sub hub used for buffering model summaries. func (j *JIMM) PubSubHub() *pubsub.Hub { return j.Pubsub } -// AuthorizationClient return the OpenFGA client used by JIMM. -func (j *JIMM) AuthorizationClient() *openfga.OFGAClient { - return j.OpenFGAClient -} - // RoleManager returns a manager that enables role management. func (j *JIMM) RoleManager() RoleManager { return j.roleManager @@ -293,11 +283,6 @@ func (j *JIMM) GroupManager() GroupManager { return j.groupManager } -// GetCredentialStore returns the credential store used by JIMM. -func (j *JIMM) GetCredentialStore() credentials.CredentialStore { - return j.CredentialStore -} - type permission struct { resource string relation string diff --git a/internal/jimm/jujuauth/jwtgenerator.go b/internal/jimm/jujuauth/jwtgenerator.go new file mode 100644 index 000000000..7924d2ce8 --- /dev/null +++ b/internal/jimm/jujuauth/jwtgenerator.go @@ -0,0 +1,180 @@ +// Copyright 2024 Canonical. + +// Package jujuauth generates JWT tokens to +// authenticate and authorize messages to Juju controllers. +// This package is more specialised than a generic +// JWT token generator as it crafts Juju specific +// permissions that are added as claims to the JWT +// and therefore exists in JIMM's business logic layer. +package jujuauth + +import ( + "context" + "sync" + + "github.com/juju/names/v5" + "github.com/juju/zaputil/zapctx" + "go.uber.org/zap" + + "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/jimmjwx" + "github.com/canonical/jimm/v3/internal/openfga" +) + +// GeneratorDatabase specifies the database interface used by the +// JWT generator. +type GeneratorDatabase interface { + GetController(ctx context.Context, controller *dbmodel.Controller) error +} + +// GeneratorAccessChecker specifies the access checker used by the JWT +// generator to obtain user's access rights to various entities. +type GeneratorAccessChecker interface { + GetUserModelAccess(context.Context, *openfga.User, names.ModelTag) (string, error) + GetUserControllerAccess(context.Context, *openfga.User, names.ControllerTag) (string, error) + GetUserCloudAccess(context.Context, *openfga.User, names.CloudTag) (string, error) + CheckPermission(context.Context, *openfga.User, map[string]string, map[string]interface{}) (map[string]string, error) +} + +// JWTService specifies the service JWT generator uses to generate JWTs. +type JWTService interface { + NewJWT(context.Context, jimmjwx.JWTParams) ([]byte, error) +} + +// TokenGenerator provides the necessary state and methods to authorize a user and generate JWT tokens. +type TokenGenerator struct { + database GeneratorDatabase + accessChecker GeneratorAccessChecker + jwtService JWTService + + mu sync.Mutex + accessMapCache map[string]string + mt names.ModelTag + ct names.ControllerTag + user *openfga.User + callCount int +} + +// New returns a new JWTGenerator. +func New(database GeneratorDatabase, accessChecker GeneratorAccessChecker, jwtService JWTService) TokenGenerator { + return TokenGenerator{ + database: database, + accessChecker: accessChecker, + jwtService: jwtService, + } +} + +// SetTags implements TokenGenerator. +func (auth *TokenGenerator) SetTags(mt names.ModelTag, ct names.ControllerTag) { + auth.mt = mt + auth.ct = ct +} + +// SetTags implements TokenGenerator. +func (auth *TokenGenerator) GetUser() names.UserTag { + if auth.user != nil { + return auth.user.ResourceTag() + } + return names.UserTag{} +} + +// MakeLoginToken authorizes the user based on the provided login requests and returns +// a JWT containing claims about user's access to the controller, model (if applicable) +// and all clouds that the controller knows about. +func (auth *TokenGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { + const op = errors.Op("jimm.MakeLoginToken") + + auth.mu.Lock() + defer auth.mu.Unlock() + + if user == nil { + return nil, errors.E(op, "user not specified") + } + auth.user = user + + // Recreate the accessMapCache to prevent leaking permissions across multiple login requests. + auth.accessMapCache = make(map[string]string) + var authErr error + + var modelAccess string + if auth.mt.Id() == "" { + return nil, errors.E(op, "model not set") + } + modelAccess, authErr = auth.accessChecker.GetUserModelAccess(ctx, auth.user, auth.mt) + if authErr != nil { + zapctx.Error(ctx, "model access check failed", zap.Error(authErr)) + return nil, authErr + } + auth.accessMapCache[auth.mt.String()] = modelAccess + + if auth.ct.Id() == "" { + return nil, errors.E(op, "controller not set") + } + var controllerAccess string + controllerAccess, authErr = auth.accessChecker.GetUserControllerAccess(ctx, auth.user, auth.ct) + if authErr != nil { + return nil, authErr + } + auth.accessMapCache[auth.ct.String()] = controllerAccess + + var ctl dbmodel.Controller + ctl.SetTag(auth.ct) + err := auth.database.GetController(ctx, &ctl) + if err != nil { + zapctx.Error(ctx, "failed to fetch controller", zap.Error(err)) + return nil, errors.E(op, "failed to fetch controller", err) + } + clouds := make(map[names.CloudTag]bool) + for _, cloudRegion := range ctl.CloudRegions { + clouds[cloudRegion.CloudRegion.Cloud.ResourceTag()] = true + } + for cloudTag := range clouds { + accessLevel, err := auth.accessChecker.GetUserCloudAccess(ctx, auth.user, cloudTag) + if err != nil { + zapctx.Error(ctx, "cloud access check failed", zap.Error(err)) + return nil, errors.E(op, "failed to check user's cloud access", err) + } + auth.accessMapCache[cloudTag.String()] = accessLevel + } + + return auth.jwtService.NewJWT(ctx, jimmjwx.JWTParams{ + Controller: auth.ct.Id(), + User: auth.user.Tag().String(), + Access: auth.accessMapCache, + }) +} + +// MakeToken assumes MakeLoginToken has already been called and checks the permissions +// specified in the permissionMap. If the logged in user has all those permissions +// a JWT will be returned with assertions confirming all those permissions. +func (auth *TokenGenerator) MakeToken(ctx context.Context, permissionMap map[string]interface{}) ([]byte, error) { + const op = errors.Op("jimm.MakeToken") + + auth.mu.Lock() + defer auth.mu.Unlock() + + if auth.callCount >= 10 { + return nil, errors.E(op, "Permission check limit exceeded") + } + auth.callCount++ + if auth.user == nil { + return nil, errors.E(op, "User authorization missing.") + } + if permissionMap != nil { + var err error + auth.accessMapCache, err = auth.accessChecker.CheckPermission(ctx, auth.user, auth.accessMapCache, permissionMap) + if err != nil { + return nil, err + } + } + jwt, err := auth.jwtService.NewJWT(ctx, jimmjwx.JWTParams{ + Controller: auth.ct.Id(), + User: auth.user.Tag().String(), + Access: auth.accessMapCache, + }) + if err != nil { + return nil, err + } + return jwt, nil +} diff --git a/internal/jimm/jujuauth/jwtgenerator_test.go b/internal/jimm/jujuauth/jwtgenerator_test.go new file mode 100644 index 000000000..f04519486 --- /dev/null +++ b/internal/jimm/jujuauth/jwtgenerator_test.go @@ -0,0 +1,358 @@ +// Copyright 2024 Canonical. + +package jujuauth_test + +import ( + "context" + "testing" + + qt "github.com/frankban/quicktest" + "github.com/google/uuid" + "github.com/juju/names/v5" + + "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/jimm/jujuauth" + "github.com/canonical/jimm/v3/internal/jimmjwx" + "github.com/canonical/jimm/v3/internal/openfga" +) + +// testDatabase is a database implementation intended for testing the token generator. +type testDatabase struct { + ctl dbmodel.Controller + err error +} + +// GetController implements the GetController method of the JWTGeneratorDatabase interface. +func (tdb *testDatabase) GetController(ctx context.Context, controller *dbmodel.Controller) error { + if tdb.err != nil { + return tdb.err + } + *controller = tdb.ctl + return nil +} + +// testAccessChecker is an access checker implementation intended for testing the +// token generator. +type testAccessChecker struct { + controllerAccess map[string]string + controllerAccessCheckErr error + modelAccess map[string]string + modelAccessCheckErr error + cloudAccess map[string]string + cloudAccessCheckErr error + permissions map[string]string + permissionCheckErr error +} + +// GetUserModelAccess implements the GetUserModelAccess method of the JWTGeneratorAccessChecker interface. +func (tac *testAccessChecker) GetUserModelAccess(ctx context.Context, user *openfga.User, mt names.ModelTag) (string, error) { + if tac.modelAccessCheckErr != nil { + return "", tac.modelAccessCheckErr + } + return tac.modelAccess[mt.String()], nil +} + +// GetUserControllerAccess implements the GetUserControllerAccess method of the JWTGeneratorAccessChecker interface. +func (tac *testAccessChecker) GetUserControllerAccess(ctx context.Context, user *openfga.User, ct names.ControllerTag) (string, error) { + if tac.controllerAccessCheckErr != nil { + return "", tac.controllerAccessCheckErr + } + return tac.controllerAccess[ct.String()], nil +} + +// GetUserCloudAccess implements the GetUserCloudAccess method of the JWTGeneratorAccessChecker interface. +func (tac *testAccessChecker) GetUserCloudAccess(ctx context.Context, user *openfga.User, ct names.CloudTag) (string, error) { + if tac.cloudAccessCheckErr != nil { + return "", tac.cloudAccessCheckErr + } + return tac.cloudAccess[ct.String()], nil +} + +// CheckPermission implements the CheckPermission methods of the JWTGeneratorAccessChecker interface. +func (tac *testAccessChecker) CheckPermission(ctx context.Context, user *openfga.User, accessMap map[string]string, permissions map[string]interface{}) (map[string]string, error) { + if tac.permissionCheckErr != nil { + return nil, tac.permissionCheckErr + } + access := make(map[string]string) + for k, v := range accessMap { + access[k] = v + } + for k, v := range tac.permissions { + access[k] = v + } + return access, nil +} + +// testJWTService is a jwt service implementation intended for testing the token generator. +type testJWTService struct { + newJWTError error + + params jimmjwx.JWTParams +} + +// NewJWT implements the NewJWT methods of the JWTService interface. +func (t *testJWTService) NewJWT(ctx context.Context, params jimmjwx.JWTParams) ([]byte, error) { + if t.newJWTError != nil { + return nil, t.newJWTError + } + t.params = params + return []byte("test jwt"), nil +} + +func TestJWTGeneratorMakeLoginToken(t *testing.T) { + c := qt.New(t) + + ct := names.NewControllerTag(uuid.New().String()) + mt := names.NewModelTag(uuid.New().String()) + + tests := []struct { + about string + username string + database *testDatabase + accessChecker *testAccessChecker + jwtService *testJWTService + expectedError string + expectedJWTParams jimmjwx.JWTParams + }{{ + about: "initial login, all is well", + username: "eve@canonical.com", + database: &testDatabase{ + ctl: dbmodel.Controller{ + CloudRegions: []dbmodel.CloudRegionControllerPriority{{ + CloudRegion: dbmodel.CloudRegion{ + Cloud: dbmodel.Cloud{ + Name: "test-cloud", + }, + }, + }}, + }, + }, + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + cloudAccess: map[string]string{ + names.NewCloudTag("test-cloud").String(): "add-model", + }, + }, + jwtService: &testJWTService{}, + expectedJWTParams: jimmjwx.JWTParams{ + Controller: ct.Id(), + User: names.NewUserTag("eve@canonical.com").String(), + Access: map[string]string{ + ct.String(): "superuser", + mt.String(): "admin", + names.NewCloudTag("test-cloud").String(): "add-model", + }, + }, + }, { + about: "model access check fails", + username: "eve@canonical.com", + accessChecker: &testAccessChecker{ + modelAccessCheckErr: errors.E("a test error"), + }, + jwtService: &testJWTService{}, + expectedError: "a test error", + }, { + about: "controller access check fails", + username: "eve@canonical.com", + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccessCheckErr: errors.E("a test error"), + }, + expectedError: "a test error", + }, { + about: "get controller from db fails", + username: "eve@canonical.com", + database: &testDatabase{ + err: errors.E("a test error"), + }, + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + }, + expectedError: "failed to fetch controller", + }, { + about: "cloud access check fails", + username: "eve@canonical.com", + database: &testDatabase{ + ctl: dbmodel.Controller{ + CloudRegions: []dbmodel.CloudRegionControllerPriority{{ + CloudRegion: dbmodel.CloudRegion{ + Cloud: dbmodel.Cloud{ + Name: "test-cloud", + }, + }, + }}, + }, + }, + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + cloudAccessCheckErr: errors.E("a test error"), + }, + expectedError: "failed to check user's cloud access", + }, { + about: "jwt service errors out", + username: "eve@canonical.com", + database: &testDatabase{ + ctl: dbmodel.Controller{ + CloudRegions: []dbmodel.CloudRegionControllerPriority{{ + CloudRegion: dbmodel.CloudRegion{ + Cloud: dbmodel.Cloud{ + Name: "test-cloud", + }, + }, + }}, + }, + }, + accessChecker: &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + cloudAccess: map[string]string{ + names.NewCloudTag("test-cloud").String(): "add-model", + }, + }, + jwtService: &testJWTService{ + newJWTError: errors.E("a test error"), + }, + expectedError: "a test error", + }} + + for _, test := range tests { + generator := jujuauth.New(test.database, test.accessChecker, test.jwtService) + generator.SetTags(mt, ct) + + i, err := dbmodel.NewIdentity(test.username) + c.Assert(err, qt.IsNil) + _, err = generator.MakeLoginToken(context.Background(), &openfga.User{ + Identity: i, + }) + if test.expectedError != "" { + c.Assert(err, qt.ErrorMatches, test.expectedError) + } else { + c.Assert(err, qt.IsNil) + c.Assert(test.jwtService.params, qt.DeepEquals, test.expectedJWTParams) + } + } +} + +func TestJWTGeneratorMakeToken(t *testing.T) { + c := qt.New(t) + + ct := names.NewControllerTag(uuid.New().String()) + mt := names.NewModelTag(uuid.New().String()) + + tests := []struct { + about string + checkPermissions map[string]string + checkPermissionsError error + jwtService *testJWTService + expectedError string + permissions map[string]interface{} + expectedJWTParams jimmjwx.JWTParams + }{{ + about: "all is well", + jwtService: &testJWTService{}, + expectedJWTParams: jimmjwx.JWTParams{ + Controller: ct.Id(), + User: names.NewUserTag("eve@canonical.com").String(), + Access: map[string]string{ + ct.String(): "superuser", + mt.String(): "admin", + names.NewCloudTag("test-cloud").String(): "add-model", + }, + }, + }, { + about: "check permission fails", + jwtService: &testJWTService{}, + permissions: map[string]interface{}{ + "entity1": "access_level1", + }, + checkPermissionsError: errors.E("a test error"), + expectedError: "a test error", + }, { + about: "additional permissions need checking", + jwtService: &testJWTService{}, + permissions: map[string]interface{}{ + "entity1": "access_level1", + }, + checkPermissions: map[string]string{ + "entity1": "access_level1", + }, + expectedJWTParams: jimmjwx.JWTParams{ + Controller: ct.Id(), + User: names.NewUserTag("eve@canonical.com").String(), + Access: map[string]string{ + ct.String(): "superuser", + mt.String(): "admin", + names.NewCloudTag("test-cloud").String(): "add-model", + "entity1": "access_level1", + }, + }, + }} + + for _, test := range tests { + generator := jujuauth.New( + &testDatabase{ + ctl: dbmodel.Controller{ + CloudRegions: []dbmodel.CloudRegionControllerPriority{{ + CloudRegion: dbmodel.CloudRegion{ + Cloud: dbmodel.Cloud{ + Name: "test-cloud", + }, + }, + }}, + }, + }, + &testAccessChecker{ + modelAccess: map[string]string{ + mt.String(): "admin", + }, + controllerAccess: map[string]string{ + ct.String(): "superuser", + }, + cloudAccess: map[string]string{ + names.NewCloudTag("test-cloud").String(): "add-model", + }, + permissions: test.checkPermissions, + permissionCheckErr: test.checkPermissionsError, + }, + test.jwtService, + ) + generator.SetTags(mt, ct) + + i, err := dbmodel.NewIdentity("eve@canonical.com") + c.Assert(err, qt.IsNil) + _, err = generator.MakeLoginToken(context.Background(), &openfga.User{ + Identity: i, + }) + c.Assert(err, qt.IsNil) + + _, err = generator.MakeToken(context.Background(), test.permissions) + if test.expectedError != "" { + c.Assert(err, qt.ErrorMatches, test.expectedError) + } else { + c.Assert(err, qt.IsNil) + c.Assert(test.jwtService.params, qt.DeepEquals, test.expectedJWTParams) + } + } +} diff --git a/internal/jimm/model.go b/internal/jimm/model.go index ea6c16d7d..482086160 100644 --- a/internal/jimm/model.go +++ b/internal/jimm/model.go @@ -1355,7 +1355,7 @@ func (j *JIMM) ListModels(ctx context.Context, user *openfga.User) ([]base.UserM } // Get the models from the database - models, err := j.DB().GetModelsByUUID(ctx, uuids) + models, err := j.Database.GetModelsByUUID(ctx, uuids) if err != nil { return nil, errors.E(op, err, "failed to get models by uuid") } diff --git a/internal/jimm/model_cleanup.go b/internal/jimm/model_cleanup.go index 8b5aafa87..2f3f04310 100644 --- a/internal/jimm/model_cleanup.go +++ b/internal/jimm/model_cleanup.go @@ -23,7 +23,7 @@ func (j *JIMM) CleanupDyingModels(ctx context.Context) (err error) { durationObserver := servermon.DurationObserver(servermon.JimmMethodsDurationHistogram, string(op)) defer durationObserver() - err = j.DB().ForEachModel(ctx, func(m *dbmodel.Model) error { + err = j.Database.ForEachModel(ctx, func(m *dbmodel.Model) error { if m.Life != state.Dying.String() { return nil } @@ -37,7 +37,7 @@ func (j *JIMM) CleanupDyingModels(ctx context.Context) (err error) { if err := api.ModelInfo(ctx, &jujuparams.ModelInfo{UUID: m.UUID.String}); err != nil { // Some versions of juju return unauthorized for models that cannot be found. if errors.ErrorCode(err) == errors.CodeNotFound || errors.ErrorCode(err) == errors.CodeUnauthorized { - if err := j.DB().DeleteModel(ctx, m); err != nil { + if err := j.Database.DeleteModel(ctx, m); err != nil { zapctx.Error(ctx, fmt.Sprintf("cannot delete model %s: %s\n", m.UUID.String, err)) } else { return nil diff --git a/internal/jimm/model_cleanup_test.go b/internal/jimm/model_cleanup_test.go index 3241a9dbd..9e8354ffd 100644 --- a/internal/jimm/model_cleanup_test.go +++ b/internal/jimm/model_cleanup_test.go @@ -135,7 +135,7 @@ func (s *modelCleanupSuite) TestPollModelsDying(c *qt.C) { Valid: true, }, } - err = s.jimm.DB().GetModel(ctx, &model) + err = s.jimm.Database.GetModel(ctx, &model) c.Assert(err, qt.ErrorMatches, "model not found") model = dbmodel.Model{ @@ -144,7 +144,7 @@ func (s *modelCleanupSuite) TestPollModelsDying(c *qt.C) { Valid: true, }, } - err = s.jimm.DB().GetModel(ctx, &model) + err = s.jimm.Database.GetModel(ctx, &model) c.Assert(err, qt.IsNil) } @@ -174,7 +174,7 @@ func (s *modelCleanupSuite) TestPollModelsDyingControllerErrors(c *qt.C) { Valid: true, }, } - err = s.jimm.DB().GetModel(ctx, &model) + err = s.jimm.Database.GetModel(ctx, &model) c.Assert(err, qt.IsNil) c.Assert(model.Life, qt.Equals, state.Dying.String()) } diff --git a/internal/jimmhttp/httpproxy_handler.go b/internal/jimmhttp/httpproxy_handler.go index 5e8070bf5..9067aa0ec 100644 --- a/internal/jimmhttp/httpproxy_handler.go +++ b/internal/jimmhttp/httpproxy_handler.go @@ -63,7 +63,7 @@ func (hph *HTTPProxyHandler) ProxyHTTP(w http.ResponseWriter, req *http.Request) writeError(ctx, w, http.StatusNotFound, err, "cannot get model") return } - u, p, err := hph.jimm.GetCredentialStore().GetControllerCredentials(ctx, model.Controller.Name) + u, p, err := hph.jimm.CredentialStore.GetControllerCredentials(ctx, model.Controller.Name) if err != nil { writeError(ctx, w, http.StatusNotFound, err, "cannot retrieve credentials") return diff --git a/internal/jimmhttp/httpproxy_handler_test.go b/internal/jimmhttp/httpproxy_handler_test.go index b84c94358..a9d07e669 100644 --- a/internal/jimmhttp/httpproxy_handler_test.go +++ b/internal/jimmhttp/httpproxy_handler_test.go @@ -66,14 +66,14 @@ func (s *httpProxySuite) SetUpTest(c *gc.C) { err := s.JIMM.Database.GetModel(ctx, model) c.Assert(err, gc.IsNil) s.model = model - err = s.JIMM.GetCredentialStore().PutControllerCredentials(ctx, model.Controller.Name, "user", "psw") + err = s.JIMM.CredentialStore.PutControllerCredentials(ctx, model.Controller.Name, "user", "psw") c.Assert(err, gc.IsNil) } func (s *httpProxySuite) TestHTTPProxyHandler(c *gc.C) { ctx := context.Background() httpProxier := jimmhttp.NewHTTPProxyHandler(s.JIMM) - expectU, expectP, err := s.JIMM.GetCredentialStore().GetControllerCredentials(ctx, s.model.Controller.Name) + expectU, expectP, err := s.JIMM.CredentialStore.GetControllerCredentials(ctx, s.model.Controller.Name) c.Assert(err, gc.IsNil) // we expect the controller to respond with TLS fakeController := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/jujuapi/access_control_test.go b/internal/jujuapi/access_control_test.go index d8e5b8491..a26006411 100644 --- a/internal/jujuapi/access_control_test.go +++ b/internal/jujuapi/access_control_test.go @@ -891,7 +891,7 @@ func (s *accessControlSuite) TestListRelationshipTuplesNoUUIDResolution(c *gc.C) c.Assert(err, jc.ErrorIsNil) groupOrange := dbmodel.GroupEntry{Name: "orange"} - err = s.JIMM.DB().GetGroup(ctx, &groupOrange) + err = s.JIMM.Database.GetGroup(ctx, &groupOrange) c.Assert(err, jc.ErrorIsNil) expected := []apiparams.RelationshipTuple{{ Object: "group-" + groupOrange.UUID + "#member", diff --git a/internal/jujuapi/interface.go b/internal/jujuapi/interface.go index 1565abf9d..e57b1341c 100644 --- a/internal/jujuapi/interface.go +++ b/internal/jujuapi/interface.go @@ -15,7 +15,6 @@ import ( "github.com/canonical/jimm/v3/internal/db" "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/jimm" - "github.com/canonical/jimm/v3/internal/jimm/credentials" "github.com/canonical/jimm/v3/internal/openfga" ofganames "github.com/canonical/jimm/v3/internal/openfga/names" "github.com/canonical/jimm/v3/internal/pubsub" @@ -45,7 +44,6 @@ type JIMM interface { GetCloud(ctx context.Context, u *openfga.User, tag names.CloudTag) (dbmodel.Cloud, error) GetCloudCredential(ctx context.Context, user *openfga.User, tag names.CloudCredentialTag) (*dbmodel.CloudCredential, error) GetCloudCredentialAttributes(ctx context.Context, u *openfga.User, cred *dbmodel.CloudCredential, hidden bool) (attrs map[string]string, redacted []string, err error) - GetCredentialStore() credentials.CredentialStore RoleManager() jimm.RoleManager GroupManager() jimm.GroupManager GetJimmControllerAccess(ctx context.Context, user *openfga.User, tag names.UserTag) (string, error) diff --git a/internal/jujuapi/websocket.go b/internal/jujuapi/websocket.go index cf8f278d8..3f8d70efa 100644 --- a/internal/jujuapi/websocket.go +++ b/internal/jujuapi/websocket.go @@ -21,6 +21,7 @@ import ( "github.com/canonical/jimm/v3/internal/dbmodel" "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimm" + "github.com/canonical/jimm/v3/internal/jimm/jujuauth" "github.com/canonical/jimm/v3/internal/jimmhttp" jimmRPC "github.com/canonical/jimm/v3/internal/rpc" ) @@ -172,7 +173,7 @@ func modelInfoFromPath(path string) (uuid string, finalPath string, err error) { // We act as a proxier, handling auth on requests before forwarding the // requests to the appropriate Juju controller. func (s apiProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) { - jwtGenerator := jimm.NewJWTGenerator(s.jimm.Database, s.jimm, s.jimm.JWTService) + jwtGenerator := jujuauth.New(s.jimm.Database, s.jimm, s.jimm.JWTService) connectionFunc := controllerConnectionFunc(s, &jwtGenerator) zapctx.Debug(ctx, "Starting proxier") auditLogger := s.jimm.AddAuditLogEntry @@ -191,7 +192,7 @@ func (s apiProxier) ServeWS(ctx context.Context, clientConn *websocket.Conn) { // controllerConnectionFunc returns a function that will be used to // connect to a controller when a client makes a request. -func controllerConnectionFunc(s apiProxier, jwtGenerator *jimm.JWTGenerator) func(context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { +func controllerConnectionFunc(s apiProxier, jwtGenerator *jujuauth.TokenGenerator) func(context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { return func(ctx context.Context) (jimmRPC.WebsocketConnectionWithMetadata, error) { const op = errors.Op("proxy.controllerConnectionFunc") path := jimmhttp.PathElementFromContext(ctx, "path")