Skip to content

Commit

Permalink
Adds new login method handling to the model proxy.
Browse files Browse the repository at this point in the history
  • Loading branch information
alesstimec committed Mar 18, 2024
1 parent 12c20db commit 4627413
Show file tree
Hide file tree
Showing 12 changed files with 768 additions and 214 deletions.
19 changes: 7 additions & 12 deletions internal/jimm/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ type JWTService interface {

// JWTGenerator provides the necessary state and methods to authorize a user and generate JWT tokens.
type JWTGenerator struct {
authenticator Authenticator
database JWTGeneratorDatabase
accessChecker JWTGeneratorAccessChecker
jwtService JWTService
Expand All @@ -190,9 +189,8 @@ type JWTGenerator struct {
}

// NewJWTGenerator returns a new JwtAuthorizer struct
func NewJWTGenerator(authenticator Authenticator, database JWTGeneratorDatabase, accessChecker JWTGeneratorAccessChecker, jwtService JWTService) JWTGenerator {
func NewJWTGenerator(database JWTGeneratorDatabase, accessChecker JWTGeneratorAccessChecker, jwtService JWTService) JWTGenerator {
return JWTGenerator{
authenticator: authenticator,
database: database,
accessChecker: accessChecker,
jwtService: jwtService,
Expand All @@ -216,24 +214,21 @@ func (auth *JWTGenerator) GetUser() names.UserTag {
// MakeLoginToken authorizes the user based on the provided login requests and returns
// a JWT containing claims about user's access to the controller, model (if applicable)
// and all clouds that the controller knows about.
func (auth *JWTGenerator) MakeLoginToken(ctx context.Context, req *jujuparams.LoginRequest) ([]byte, error) {
func (auth *JWTGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) {
const op = errors.Op("jimm.MakeLoginToken")

auth.mu.Lock()
defer auth.mu.Unlock()

if req == nil {
return nil, errors.E(op, "missing login request.")
if user == nil {
return nil, errors.E(op, "user not specified")
}
auth.user = user

// Recreate the accessMapCache to prevent leaking permissions across multiple login requests.
auth.accessMapCache = make(map[string]string)
var authErr error
// TODO(CSS-7331) Refactor model proxy for new login methods
auth.user, authErr = auth.authenticator.Authenticate(ctx, req)
if authErr != nil {
zapctx.Error(ctx, "authentication failed", zap.Error(authErr))
return nil, authErr
}

var modelAccess string
if auth.mt.Id() == "" {
return nil, errors.E(op, "model not set")
Expand Down
62 changes: 24 additions & 38 deletions internal/jimm/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,15 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {

tests := []struct {
about string
authenticator *testAuthenticator
username string
database *testDatabase
accessChecker *testAccessChecker
jwtService *testJWTService
expectedError string
expectedJWTParams jimmjwx.JWTParams
}{{
about: "initial login, all is well",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "initial login, all is well",
username: "[email protected]",
database: &testDatabase{
ctl: dbmodel.Controller{
CloudRegions: []dbmodel.CloudRegionControllerPriority{{
Expand Down Expand Up @@ -239,27 +237,16 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
},
},
}, {
about: "authorization fails",
authenticator: &testAuthenticator{
username: "[email protected]",
err: errors.E("a test error"),
},
expectedError: "a test error",
}, {
about: "model access check fails",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "model access check fails",
username: "[email protected]",
accessChecker: &testAccessChecker{
modelAccessCheckErr: errors.E("a test error"),
},
jwtService: &testJWTService{},
expectedError: "a test error",
}, {
about: "controller access check fails",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "controller access check fails",
username: "[email protected]",
accessChecker: &testAccessChecker{
modelAccess: map[string]string{
mt.String(): "admin",
Expand All @@ -268,10 +255,8 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
},
expectedError: "a test error",
}, {
about: "get controller from db fails",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "get controller from db fails",
username: "[email protected]",
database: &testDatabase{
err: errors.E("a test error"),
},
Expand All @@ -285,10 +270,8 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
},
expectedError: "failed to fetch controller",
}, {
about: "cloud access check fails",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "cloud access check fails",
username: "[email protected]",
database: &testDatabase{
ctl: dbmodel.Controller{
CloudRegions: []dbmodel.CloudRegionControllerPriority{{
Expand All @@ -311,10 +294,8 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
},
expectedError: "failed to check user's cloud access",
}, {
about: "jwt service errors out",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "jwt service errors out",
username: "[email protected]",
database: &testDatabase{
ctl: dbmodel.Controller{
CloudRegions: []dbmodel.CloudRegionControllerPriority{{
Expand Down Expand Up @@ -344,10 +325,14 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
}}

for _, test := range tests {
generator := jimm.NewJWTGenerator(test.authenticator, test.database, test.accessChecker, test.jwtService)
generator := jimm.NewJWTGenerator(test.database, test.accessChecker, test.jwtService)
generator.SetTags(mt, ct)

_, err := generator.MakeLoginToken(context.Background(), &jujuparams.LoginRequest{})
_, err := generator.MakeLoginToken(context.Background(), &openfga.User{
Identity: &dbmodel.Identity{
Name: test.username,
},
})
if test.expectedError != "" {
c.Assert(err, qt.ErrorMatches, test.expectedError)
} else {
Expand Down Expand Up @@ -414,9 +399,6 @@ func TestJWTGeneratorMakeToken(t *testing.T) {

for _, test := range tests {
generator := jimm.NewJWTGenerator(
&testAuthenticator{
username: "[email protected]",
},
&testDatabase{
ctl: dbmodel.Controller{
CloudRegions: []dbmodel.CloudRegionControllerPriority{{
Expand Down Expand Up @@ -445,7 +427,11 @@ func TestJWTGeneratorMakeToken(t *testing.T) {
)
generator.SetTags(mt, ct)

_, err := generator.MakeLoginToken(context.Background(), &jujuparams.LoginRequest{})
_, err := generator.MakeLoginToken(context.Background(), &openfga.User{
Identity: &dbmodel.Identity{
Name: "[email protected]",
},
})
c.Assert(err, qt.IsNil)

_, err = generator.MakeToken(context.Background(), test.permissions)
Expand Down
55 changes: 55 additions & 0 deletions internal/jimm/admin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2024 Canonical Ltd.

package jimm

import (
"context"

"golang.org/x/oauth2"

"github.com/canonical/jimm/internal/errors"
)

// LoginDevice starts the device login flow.
func LoginDevice(ctx context.Context, authenticator OAuthAuthenticator) (*oauth2.DeviceAuthResponse, error) {
const op = errors.Op("jujuapi.LoginDevice")

deviceResponse, err := authenticator.Device(ctx)
if err != nil {
return nil, errors.E(op, err)
}

return deviceResponse, nil
}

func GetDeviceSessionToken(ctx context.Context, authenticator OAuthAuthenticator, deviceOAuthResponse *oauth2.DeviceAuthResponse) (string, error) {
const op = errors.Op("jujuapi.GetDeviceSessionToken")

token, err := authenticator.DeviceAccessToken(ctx, deviceOAuthResponse)
if err != nil {
return "", errors.E(op, err)
}

idToken, err := authenticator.ExtractAndVerifyIDToken(ctx, token)
if err != nil {
return "", errors.E(op, err)
}

email, err := authenticator.Email(idToken)
if err != nil {
return "", errors.E(op, err)
}

if err := authenticator.UpdateIdentity(ctx, email, token); err != nil {
return "", errors.E(op, err)
}

// TODO(ale8k): Add vault logic to get secret key and generate one
// on start up.
encToken, err := authenticator.MintSessionToken(email, "test-secret")
if err != nil {
return "", errors.E(op, err)
}

return string(encToken), nil
}
7 changes: 0 additions & 7 deletions internal/jimm/jimm.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,6 @@ func (j *JIMM) AuthorizationClient() *openfga.OFGAClient {
return j.OpenFGAClient
}

// An Authenticator authenticates login requests.
type Authenticator interface {
// Authenticate processes the given LoginRequest and returns the user
// that has authenticated.
Authenticate(ctx context.Context, req *jujuparams.LoginRequest) (*openfga.User, error)
}

// OAuthAuthenticator is responsible for handling authentication
// via OAuth2.0 AND JWT access tokens to JIMM.
type OAuthAuthenticator interface {
Expand Down
3 changes: 2 additions & 1 deletion internal/jimmtest/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func PostgresDB(t Tester, nowFunc func() time.Time) *gorm.DB {
}

suggestedName := "jimm_test_" + t.Name()
t.Logf("suggested db name: %s", suggestedName)
_, dsn, err := createDatabaseFromTemplate(suggestedName, templateDatabaseName)
if err != nil {
t.Fatalf("error creating database (%s): %s", suggestedName, err)
Expand Down Expand Up @@ -197,7 +198,7 @@ func createDatabaseFromTemplate(suggestedName string, templateName string) (stri

dropDatabaseCommand := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, databaseName)
if err := gdb.Exec(dropDatabaseCommand).Error; err != nil {
return "", "", errors.E(err, fmt.Sprintf("error dropping existing database (maybe there's an active connection like psql client): %s", databaseName))
return "", "", errors.E(err, fmt.Sprintf("error dropping existing database (maybe there's an active connection like psql client): %s [%v]", databaseName, err))
}

createDatabaseCommand := fmt.Sprintf(`CREATE DATABASE "%s" TEMPLATE "%s"`, databaseName, templateName)
Expand Down
39 changes: 5 additions & 34 deletions internal/jujuapi/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/canonical/jimm/api/params"
"github.com/canonical/jimm/internal/auth"
"github.com/canonical/jimm/internal/errors"
"github.com/canonical/jimm/internal/jimm"
"github.com/canonical/jimm/internal/openfga"
)

Expand All @@ -28,12 +29,6 @@ func unsupportedLogin() error {

var facadeInit = make(map[string]func(r *controllerRoot) []int)

// Login implements the Login method on the Admin facade.
func (r *controllerRoot) Login(ctx context.Context, req jujuparams.LoginRequest) (jujuparams.LoginResult, error) {
const op = errors.Op("jujuapi.Login")
return jujuparams.LoginResult{}, errors.E(op, "Invalid login, ensure you are using Juju 3.5+")
}

// LoginDevice starts a device login flow (typically a CLI). It will return a verification URI
// and user code that the user is expected to enter into the verification URI link.
//
Expand All @@ -42,9 +37,8 @@ func (r *controllerRoot) Login(ctx context.Context, req jujuparams.LoginRequest)
func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceResponse, error) {
const op = errors.Op("jujuapi.LoginDevice")
response := params.LoginDeviceResponse{}
authSvc := r.jimm.OAuthAuthenticationService()

deviceResponse, err := authSvc.Device(ctx)
deviceResponse, err := jimm.LoginDevice(ctx, r.jimm.OAuthAuthenticationService())
if err != nil {
return response, errors.E(op, err)
}
Expand All @@ -53,8 +47,8 @@ func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceRes
// happens on the SAME websocket.
r.deviceOAuthResponse = deviceResponse

response.VerificationURI = deviceResponse.VerificationURI
response.UserCode = deviceResponse.UserCode
response.VerificationURI = deviceResponse.VerificationURI

return response, nil
}
Expand All @@ -67,36 +61,13 @@ func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceRes
func (r *controllerRoot) GetDeviceSessionToken(ctx context.Context) (params.GetDeviceSessionTokenResponse, error) {
const op = errors.Op("jujuapi.GetDeviceSessionToken")
response := params.GetDeviceSessionTokenResponse{}
authSvc := r.jimm.OAuthAuthenticationService()

token, err := authSvc.DeviceAccessToken(ctx, r.deviceOAuthResponse)
if err != nil {
return response, errors.E(op, err)
}

idToken, err := authSvc.ExtractAndVerifyIDToken(ctx, token)
if err != nil {
return response, errors.E(op, err)
}

email, err := authSvc.Email(idToken)
if err != nil {
return response, errors.E(op, err)
}

if err := authSvc.UpdateIdentity(ctx, email, token); err != nil {
return response, errors.E(op, err)
}

// TODO(ale8k): Add vault logic to get secret key and generate one
// on start up.
encToken, err := authSvc.MintSessionToken(email, "test-secret")
token, err := jimm.GetDeviceSessionToken(ctx, r.jimm.OAuthAuthenticationService(), r.deviceOAuthResponse)
if err != nil {
return response, errors.E(op, err)
}

response.SessionToken = string(encToken)

response.SessionToken = token
return response, nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/jujuapi/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (s *adminSuite) TestLoginToController(c *gc.C) {
}, "test")
defer conn.Close()
err := conn.Login(nil, "", "", nil)
c.Assert(err, gc.ErrorMatches, "Invalid login, ensure you are using Juju 3\\.5\\+")
c.Assert(err, gc.ErrorMatches, `JIMM does not support login from old clients \(not supported\)`)
var resp jujuparams.RedirectInfoResult
err = conn.APICall("Admin", 3, "", "RedirectInfo", nil, &resp)
c.Assert(jujuparams.ErrCode(err), gc.Equals, jujuparams.CodeNotImplemented)
Expand Down
4 changes: 2 additions & 2 deletions internal/jujuapi/controllerroot.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func newControllerRoot(j JIMM, p Params) *controllerRoot {

r.AddMethod("Admin", 1, "Login", rpc.Method(unsupportedLogin))
r.AddMethod("Admin", 2, "Login", rpc.Method(unsupportedLogin))
r.AddMethod("Admin", 3, "Login", rpc.Method(r.Login))
r.AddMethod("Admin", 4, "Login", rpc.Method(r.Login))
r.AddMethod("Admin", 3, "Login", rpc.Method(unsupportedLogin))
r.AddMethod("Admin", 4, "Login", rpc.Method(unsupportedLogin))
r.AddMethod("Admin", 4, "LoginDevice", rpc.Method(r.LoginDevice))
r.AddMethod("Admin", 4, "GetDeviceSessionToken", rpc.Method(r.GetDeviceSessionToken))
r.AddMethod("Admin", 4, "LoginWithSessionToken", rpc.Method(r.LoginWithSessionToken))
Expand Down
Loading

0 comments on commit 4627413

Please sign in to comment.