Skip to content

Commit

Permalink
wip.
Browse files Browse the repository at this point in the history
  • Loading branch information
alesstimec committed Mar 15, 2024
1 parent 0c679a6 commit c868ec5
Show file tree
Hide file tree
Showing 9 changed files with 290 additions and 201 deletions.
19 changes: 7 additions & 12 deletions internal/jimm/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ type JWTService interface {

// JWTGenerator provides the necessary state and methods to authorize a user and generate JWT tokens.
type JWTGenerator struct {
authenticator Authenticator
database JWTGeneratorDatabase
accessChecker JWTGeneratorAccessChecker
jwtService JWTService
Expand All @@ -190,9 +189,8 @@ type JWTGenerator struct {
}

// NewJWTGenerator returns a new JwtAuthorizer struct
func NewJWTGenerator(authenticator Authenticator, database JWTGeneratorDatabase, accessChecker JWTGeneratorAccessChecker, jwtService JWTService) JWTGenerator {
func NewJWTGenerator(database JWTGeneratorDatabase, accessChecker JWTGeneratorAccessChecker, jwtService JWTService) JWTGenerator {
return JWTGenerator{
authenticator: authenticator,
database: database,
accessChecker: accessChecker,
jwtService: jwtService,
Expand All @@ -216,24 +214,21 @@ func (auth *JWTGenerator) GetUser() names.UserTag {
// MakeLoginToken authorizes the user based on the provided login requests and returns
// a JWT containing claims about user's access to the controller, model (if applicable)
// and all clouds that the controller knows about.
func (auth *JWTGenerator) MakeLoginToken(ctx context.Context, req *jujuparams.LoginRequest) ([]byte, error) {
func (auth *JWTGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) {
const op = errors.Op("jimm.MakeLoginToken")

auth.mu.Lock()
defer auth.mu.Unlock()

if req == nil {
return nil, errors.E(op, "missing login request.")
if user == nil {
return nil, errors.E(op, "user not specified")
}
auth.user = user

// Recreate the accessMapCache to prevent leaking permissions across multiple login requests.
auth.accessMapCache = make(map[string]string)
var authErr error
// TODO(CSS-7331) Refactor model proxy for new login methods
auth.user, authErr = auth.authenticator.Authenticate(ctx, req)
if authErr != nil {
zapctx.Error(ctx, "authentication failed", zap.Error(authErr))
return nil, authErr
}

var modelAccess string
if auth.mt.Id() == "" {
return nil, errors.E(op, "model not set")
Expand Down
62 changes: 24 additions & 38 deletions internal/jimm/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,15 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {

tests := []struct {
about string
authenticator *testAuthenticator
username string
database *testDatabase
accessChecker *testAccessChecker
jwtService *testJWTService
expectedError string
expectedJWTParams jimmjwx.JWTParams
}{{
about: "initial login, all is well",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "initial login, all is well",
username: "[email protected]",
database: &testDatabase{
ctl: dbmodel.Controller{
CloudRegions: []dbmodel.CloudRegionControllerPriority{{
Expand Down Expand Up @@ -239,27 +237,16 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
},
},
}, {
about: "authorization fails",
authenticator: &testAuthenticator{
username: "[email protected]",
err: errors.E("a test error"),
},
expectedError: "a test error",
}, {
about: "model access check fails",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "model access check fails",
username: "[email protected]",
accessChecker: &testAccessChecker{
modelAccessCheckErr: errors.E("a test error"),
},
jwtService: &testJWTService{},
expectedError: "a test error",
}, {
about: "controller access check fails",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "controller access check fails",
username: "[email protected]",
accessChecker: &testAccessChecker{
modelAccess: map[string]string{
mt.String(): "admin",
Expand All @@ -268,10 +255,8 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
},
expectedError: "a test error",
}, {
about: "get controller from db fails",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "get controller from db fails",
username: "[email protected]",
database: &testDatabase{
err: errors.E("a test error"),
},
Expand All @@ -285,10 +270,8 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
},
expectedError: "failed to fetch controller",
}, {
about: "cloud access check fails",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "cloud access check fails",
username: "[email protected]",
database: &testDatabase{
ctl: dbmodel.Controller{
CloudRegions: []dbmodel.CloudRegionControllerPriority{{
Expand All @@ -311,10 +294,8 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
},
expectedError: "failed to check user's cloud access",
}, {
about: "jwt service errors out",
authenticator: &testAuthenticator{
username: "[email protected]",
},
about: "jwt service errors out",
username: "[email protected]",
database: &testDatabase{
ctl: dbmodel.Controller{
CloudRegions: []dbmodel.CloudRegionControllerPriority{{
Expand Down Expand Up @@ -344,10 +325,14 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) {
}}

for _, test := range tests {
generator := jimm.NewJWTGenerator(test.authenticator, test.database, test.accessChecker, test.jwtService)
generator := jimm.NewJWTGenerator(test.database, test.accessChecker, test.jwtService)
generator.SetTags(mt, ct)

_, err := generator.MakeLoginToken(context.Background(), &jujuparams.LoginRequest{})
_, err := generator.MakeLoginToken(context.Background(), &openfga.User{
Identity: &dbmodel.Identity{
Name: test.username,
},
})
if test.expectedError != "" {
c.Assert(err, qt.ErrorMatches, test.expectedError)
} else {
Expand Down Expand Up @@ -414,9 +399,6 @@ func TestJWTGeneratorMakeToken(t *testing.T) {

for _, test := range tests {
generator := jimm.NewJWTGenerator(
&testAuthenticator{
username: "[email protected]",
},
&testDatabase{
ctl: dbmodel.Controller{
CloudRegions: []dbmodel.CloudRegionControllerPriority{{
Expand Down Expand Up @@ -445,7 +427,11 @@ func TestJWTGeneratorMakeToken(t *testing.T) {
)
generator.SetTags(mt, ct)

_, err := generator.MakeLoginToken(context.Background(), &jujuparams.LoginRequest{})
_, err := generator.MakeLoginToken(context.Background(), &openfga.User{
Identity: &dbmodel.Identity{
Name: "[email protected]",
},
})
c.Assert(err, qt.IsNil)

_, err = generator.MakeToken(context.Background(), test.permissions)
Expand Down
7 changes: 0 additions & 7 deletions internal/jimm/jimm.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,6 @@ func (j *JIMM) AuthorizationClient() *openfga.OFGAClient {
return j.OpenFGAClient
}

// An Authenticator authenticates login requests.
type Authenticator interface {
// Authenticate processes the given LoginRequest and returns the user
// that has authenticated.
Authenticate(ctx context.Context, req *jujuparams.LoginRequest) (*openfga.User, error)
}

// OAuthAuthenticator is responsible for handling authentication
// via OAuth2.0 AND JWT access tokens to JIMM.
type OAuthAuthenticator interface {
Expand Down
39 changes: 5 additions & 34 deletions internal/jujuapi/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/canonical/jimm/api/params"
"github.com/canonical/jimm/internal/auth"
"github.com/canonical/jimm/internal/errors"
"github.com/canonical/jimm/internal/jimm"
"github.com/canonical/jimm/internal/openfga"
)

Expand All @@ -28,12 +29,6 @@ func unsupportedLogin() error {

var facadeInit = make(map[string]func(r *controllerRoot) []int)

// Login implements the Login method on the Admin facade.
func (r *controllerRoot) Login(ctx context.Context, req jujuparams.LoginRequest) (jujuparams.LoginResult, error) {
const op = errors.Op("jujuapi.Login")
return jujuparams.LoginResult{}, errors.E(op, "Invalid login, ensure you are using Juju 3.5+")
}

// LoginDevice starts a device login flow (typically a CLI). It will return a verification URI
// and user code that the user is expected to enter into the verification URI link.
//
Expand All @@ -42,9 +37,8 @@ func (r *controllerRoot) Login(ctx context.Context, req jujuparams.LoginRequest)
func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceResponse, error) {
const op = errors.Op("jujuapi.LoginDevice")
response := params.LoginDeviceResponse{}
authSvc := r.jimm.OAuthAuthenticationService()

deviceResponse, err := authSvc.Device(ctx)
deviceResponse, err := jimm.LoginDevice(ctx, r.jimm.OAuthAuthenticationService())
if err != nil {
return response, errors.E(op, err)
}
Expand All @@ -53,8 +47,8 @@ func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceRes
// happens on the SAME websocket.
r.deviceOAuthResponse = deviceResponse

response.VerificationURI = deviceResponse.VerificationURI
response.UserCode = deviceResponse.UserCode
response.VerificationURI = deviceResponse.VerificationURI

return response, nil
}
Expand All @@ -67,36 +61,13 @@ func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceRes
func (r *controllerRoot) GetDeviceSessionToken(ctx context.Context) (params.GetDeviceSessionTokenResponse, error) {
const op = errors.Op("jujuapi.GetDeviceSessionToken")
response := params.GetDeviceSessionTokenResponse{}
authSvc := r.jimm.OAuthAuthenticationService()

token, err := authSvc.DeviceAccessToken(ctx, r.deviceOAuthResponse)
if err != nil {
return response, errors.E(op, err)
}

idToken, err := authSvc.ExtractAndVerifyIDToken(ctx, token)
if err != nil {
return response, errors.E(op, err)
}

email, err := authSvc.Email(idToken)
if err != nil {
return response, errors.E(op, err)
}

if err := authSvc.UpdateIdentity(ctx, email, token); err != nil {
return response, errors.E(op, err)
}

// TODO(ale8k): Add vault logic to get secret key and generate one
// on start up.
encToken, err := authSvc.MintSessionToken(email, "test-secret")
token, err := jimm.GetDeviceSessionToken(ctx, r.jimm.OAuthAuthenticationService(), r.deviceOAuthResponse)
if err != nil {
return response, errors.E(op, err)
}

response.SessionToken = string(encToken)

response.SessionToken = token
return response, nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/jujuapi/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (s *adminSuite) TestLoginToController(c *gc.C) {
}, "test")
defer conn.Close()
err := conn.Login(nil, "", "", nil)
c.Assert(err, gc.ErrorMatches, "Invalid login, ensure you are using Juju 3\\.5\\+")
c.Assert(err, gc.ErrorMatches, `JIMM does not support login from old clients \(not supported\)`)
var resp jujuparams.RedirectInfoResult
err = conn.APICall("Admin", 3, "", "RedirectInfo", nil, &resp)
c.Assert(jujuparams.ErrCode(err), gc.Equals, jujuparams.CodeNotImplemented)
Expand Down
4 changes: 2 additions & 2 deletions internal/jujuapi/controllerroot.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func newControllerRoot(j JIMM, p Params) *controllerRoot {

r.AddMethod("Admin", 1, "Login", rpc.Method(unsupportedLogin))
r.AddMethod("Admin", 2, "Login", rpc.Method(unsupportedLogin))
r.AddMethod("Admin", 3, "Login", rpc.Method(r.Login))
r.AddMethod("Admin", 4, "Login", rpc.Method(r.Login))
r.AddMethod("Admin", 3, "Login", rpc.Method(unsupportedLogin))
r.AddMethod("Admin", 4, "Login", rpc.Method(unsupportedLogin))
r.AddMethod("Admin", 4, "LoginDevice", rpc.Method(r.LoginDevice))
r.AddMethod("Admin", 4, "GetDeviceSessionToken", rpc.Method(r.GetDeviceSessionToken))
r.AddMethod("Admin", 4, "LoginWithSessionToken", rpc.Method(r.LoginWithSessionToken))
Expand Down
9 changes: 4 additions & 5 deletions internal/jujuapi/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ func modelInfoFromPath(path string) (uuid string, finalPath string, err error) {

// ServeWS implements jimmhttp.WSServer.
func (s modelProxyServer) ServeWS(ctx context.Context, clientConn *websocket.Conn) {
// TODO(CSS-7331) Refactor model proxy for new login methods
jwtGenerator := jimm.NewJWTGenerator(nil, &s.jimm.Database, s.jimm, s.jimm.JWTService)
jwtGenerator := jimm.NewJWTGenerator(&s.jimm.Database, s.jimm, s.jimm.JWTService)
connectionFunc := controllerConnectionFunc(s, &jwtGenerator)
zapctx.Debug(ctx, "Starting proxier")
auditLogger := s.jimm.AddAuditLogEntry
Expand All @@ -140,14 +139,15 @@ func (s modelProxyServer) ServeWS(ctx context.Context, clientConn *websocket.Con
TokenGen: &jwtGenerator,
ConnectController: connectionFunc,
AuditLog: auditLogger,
JIMM: s.jimm,
}
jimmRPC.ProxySockets(ctx, proxyHelpers)
}

// controllerConnectionFunc returns a function that will be used to
// connect to a controller when a client makes a request.
func controllerConnectionFunc(s modelProxyServer, jwtGenerator *jimm.JWTGenerator) func(context.Context) (*websocket.Conn, string, error) {
connectToControllerFunc := func(ctx context.Context) (*websocket.Conn, string, error) {
func controllerConnectionFunc(s modelProxyServer, jwtGenerator *jimm.JWTGenerator) func(context.Context) (jimmRPC.WebsocketConnection, string, error) {
return func(ctx context.Context) (jimmRPC.WebsocketConnection, string, error) {
const op = errors.Op("proxy.controllerConnectionFunc")
path := jimmhttp.PathElementFromContext(ctx, "path")
zapctx.Debug(ctx, "grabbing model info from path", zap.String("path", path))
Expand Down Expand Up @@ -177,7 +177,6 @@ func controllerConnectionFunc(s modelProxyServer, jwtGenerator *jimm.JWTGenerato
fullModelName := m.Controller.Name + "/" + m.Name
return controllerConn, fullModelName, nil
}
return connectToControllerFunc
}

// Use a 64k frame size for the websockets while we need to deal
Expand Down
10 changes: 5 additions & 5 deletions internal/rpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ import (

qt "github.com/frankban/quicktest"
"github.com/gorilla/websocket"
"github.com/juju/juju/rpc/params"
"github.com/juju/names/v5"

"github.com/canonical/jimm/internal/dbmodel"
"github.com/canonical/jimm/internal/errors"
"github.com/canonical/jimm/internal/openfga"
"github.com/canonical/jimm/internal/rpc"
)

Expand Down Expand Up @@ -225,7 +225,7 @@ func TestClientReceiveInvalidMessage(t *testing.T) {

type testTokenGenerator struct{}

func (p *testTokenGenerator) MakeLoginToken(ctx context.Context, req *params.LoginRequest) ([]byte, error) {
func (p *testTokenGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) {
return nil, nil
}

Expand All @@ -250,7 +250,7 @@ func TestProxySockets(t *testing.T) {
errChan := make(chan error)
srvJIMM := newServer(func(connClient *websocket.Conn) error {
testTokenGen := testTokenGenerator{}
f := func(context.Context) (*websocket.Conn, string, error) {
f := func(context.Context) (rpc.WebsocketConnection, string, error) {
connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL)
c.Assert(err, qt.IsNil)
return connController, "TestName", nil
Expand Down Expand Up @@ -297,7 +297,7 @@ func TestCancelProxySockets(t *testing.T) {
errChan := make(chan error)
srvJIMM := newServer(func(connClient *websocket.Conn) error {
testTokenGen := testTokenGenerator{}
f := func(context.Context) (*websocket.Conn, string, error) {
f := func(context.Context) (rpc.WebsocketConnection, string, error) {
connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL)
c.Assert(err, qt.IsNil)
return connController, "TestName", nil
Expand Down Expand Up @@ -337,7 +337,7 @@ func TestProxySocketsAuditLogs(t *testing.T) {
errChan := make(chan error)
srvJIMM := newServer(func(connClient *websocket.Conn) error {
testTokenGen := testTokenGenerator{}
f := func(context.Context) (*websocket.Conn, string, error) {
f := func(context.Context) (rpc.WebsocketConnection, string, error) {
connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL)
c.Assert(err, qt.IsNil)
return connController, "TestModelName", nil
Expand Down
Loading

0 comments on commit c868ec5

Please sign in to comment.