diff --git a/internal/jimm/access.go b/internal/jimm/access.go index da937474e..eb7268f38 100644 --- a/internal/jimm/access.go +++ b/internal/jimm/access.go @@ -176,7 +176,6 @@ type JWTService interface { // JWTGenerator provides the necessary state and methods to authorize a user and generate JWT tokens. type JWTGenerator struct { - authenticator Authenticator database JWTGeneratorDatabase accessChecker JWTGeneratorAccessChecker jwtService JWTService @@ -190,9 +189,8 @@ type JWTGenerator struct { } // NewJWTGenerator returns a new JwtAuthorizer struct -func NewJWTGenerator(authenticator Authenticator, database JWTGeneratorDatabase, accessChecker JWTGeneratorAccessChecker, jwtService JWTService) JWTGenerator { +func NewJWTGenerator(database JWTGeneratorDatabase, accessChecker JWTGeneratorAccessChecker, jwtService JWTService) JWTGenerator { return JWTGenerator{ - authenticator: authenticator, database: database, accessChecker: accessChecker, jwtService: jwtService, @@ -216,24 +214,21 @@ func (auth *JWTGenerator) GetUser() names.UserTag { // MakeLoginToken authorizes the user based on the provided login requests and returns // a JWT containing claims about user's access to the controller, model (if applicable) // and all clouds that the controller knows about. -func (auth *JWTGenerator) MakeLoginToken(ctx context.Context, req *jujuparams.LoginRequest) ([]byte, error) { +func (auth *JWTGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { const op = errors.Op("jimm.MakeLoginToken") auth.mu.Lock() defer auth.mu.Unlock() - if req == nil { - return nil, errors.E(op, "missing login request.") + if user == nil { + return nil, errors.E(op, "user not specified") } + auth.user = user + // Recreate the accessMapCache to prevent leaking permissions across multiple login requests. auth.accessMapCache = make(map[string]string) var authErr error - // TODO(CSS-7331) Refactor model proxy for new login methods - auth.user, authErr = auth.authenticator.Authenticate(ctx, req) - if authErr != nil { - zapctx.Error(ctx, "authentication failed", zap.Error(authErr)) - return nil, authErr - } + var modelAccess string if auth.mt.Id() == "" { return nil, errors.E(op, "model not set") diff --git a/internal/jimm/access_test.go b/internal/jimm/access_test.go index 42ebee6a5..7c61174e0 100644 --- a/internal/jimm/access_test.go +++ b/internal/jimm/access_test.go @@ -195,17 +195,15 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) { tests := []struct { about string - authenticator *testAuthenticator + username string database *testDatabase accessChecker *testAccessChecker jwtService *testJWTService expectedError string expectedJWTParams jimmjwx.JWTParams }{{ - about: "initial login, all is well", - authenticator: &testAuthenticator{ - username: "eve@canonical.com", - }, + about: "initial login, all is well", + username: "eve@canonical.com", database: &testDatabase{ ctl: dbmodel.Controller{ CloudRegions: []dbmodel.CloudRegionControllerPriority{{ @@ -239,27 +237,16 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) { }, }, }, { - about: "authorization fails", - authenticator: &testAuthenticator{ - username: "eve@canonical.com", - err: errors.E("a test error"), - }, - expectedError: "a test error", - }, { - about: "model access check fails", - authenticator: &testAuthenticator{ - username: "eve@canonical.com", - }, + about: "model access check fails", + username: "eve@canonical.com", accessChecker: &testAccessChecker{ modelAccessCheckErr: errors.E("a test error"), }, jwtService: &testJWTService{}, expectedError: "a test error", }, { - about: "controller access check fails", - authenticator: &testAuthenticator{ - username: "eve@canonical.com", - }, + about: "controller access check fails", + username: "eve@canonical.com", accessChecker: &testAccessChecker{ modelAccess: map[string]string{ mt.String(): "admin", @@ -268,10 +255,8 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) { }, expectedError: "a test error", }, { - about: "get controller from db fails", - authenticator: &testAuthenticator{ - username: "eve@canonical.com", - }, + about: "get controller from db fails", + username: "eve@canonical.com", database: &testDatabase{ err: errors.E("a test error"), }, @@ -285,10 +270,8 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) { }, expectedError: "failed to fetch controller", }, { - about: "cloud access check fails", - authenticator: &testAuthenticator{ - username: "eve@canonical.com", - }, + about: "cloud access check fails", + username: "eve@canonical.com", database: &testDatabase{ ctl: dbmodel.Controller{ CloudRegions: []dbmodel.CloudRegionControllerPriority{{ @@ -311,10 +294,8 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) { }, expectedError: "failed to check user's cloud access", }, { - about: "jwt service errors out", - authenticator: &testAuthenticator{ - username: "eve@canonical.com", - }, + about: "jwt service errors out", + username: "eve@canonical.com", database: &testDatabase{ ctl: dbmodel.Controller{ CloudRegions: []dbmodel.CloudRegionControllerPriority{{ @@ -344,10 +325,14 @@ func TestJWTGeneratorMakeLoginToken(t *testing.T) { }} for _, test := range tests { - generator := jimm.NewJWTGenerator(test.authenticator, test.database, test.accessChecker, test.jwtService) + generator := jimm.NewJWTGenerator(test.database, test.accessChecker, test.jwtService) generator.SetTags(mt, ct) - _, err := generator.MakeLoginToken(context.Background(), &jujuparams.LoginRequest{}) + _, err := generator.MakeLoginToken(context.Background(), &openfga.User{ + Identity: &dbmodel.Identity{ + Name: test.username, + }, + }) if test.expectedError != "" { c.Assert(err, qt.ErrorMatches, test.expectedError) } else { @@ -414,9 +399,6 @@ func TestJWTGeneratorMakeToken(t *testing.T) { for _, test := range tests { generator := jimm.NewJWTGenerator( - &testAuthenticator{ - username: "eve@canonical.com", - }, &testDatabase{ ctl: dbmodel.Controller{ CloudRegions: []dbmodel.CloudRegionControllerPriority{{ @@ -445,7 +427,11 @@ func TestJWTGeneratorMakeToken(t *testing.T) { ) generator.SetTags(mt, ct) - _, err := generator.MakeLoginToken(context.Background(), &jujuparams.LoginRequest{}) + _, err := generator.MakeLoginToken(context.Background(), &openfga.User{ + Identity: &dbmodel.Identity{ + Name: "eve@canonical.com", + }, + }) c.Assert(err, qt.IsNil) _, err = generator.MakeToken(context.Background(), test.permissions) diff --git a/internal/jimm/jimm.go b/internal/jimm/jimm.go index af6ce42f3..ba74814ce 100644 --- a/internal/jimm/jimm.go +++ b/internal/jimm/jimm.go @@ -115,13 +115,6 @@ func (j *JIMM) AuthorizationClient() *openfga.OFGAClient { return j.OpenFGAClient } -// An Authenticator authenticates login requests. -type Authenticator interface { - // Authenticate processes the given LoginRequest and returns the user - // that has authenticated. - Authenticate(ctx context.Context, req *jujuparams.LoginRequest) (*openfga.User, error) -} - // OAuthAuthenticator is responsible for handling authentication // via OAuth2.0 AND JWT access tokens to JIMM. type OAuthAuthenticator interface { diff --git a/internal/jujuapi/admin.go b/internal/jujuapi/admin.go index 80868248b..e473310a9 100644 --- a/internal/jujuapi/admin.go +++ b/internal/jujuapi/admin.go @@ -14,6 +14,7 @@ import ( "github.com/canonical/jimm/api/params" "github.com/canonical/jimm/internal/auth" "github.com/canonical/jimm/internal/errors" + "github.com/canonical/jimm/internal/jimm" "github.com/canonical/jimm/internal/openfga" ) @@ -28,12 +29,6 @@ func unsupportedLogin() error { var facadeInit = make(map[string]func(r *controllerRoot) []int) -// Login implements the Login method on the Admin facade. -func (r *controllerRoot) Login(ctx context.Context, req jujuparams.LoginRequest) (jujuparams.LoginResult, error) { - const op = errors.Op("jujuapi.Login") - return jujuparams.LoginResult{}, errors.E(op, "Invalid login, ensure you are using Juju 3.5+") -} - // LoginDevice starts a device login flow (typically a CLI). It will return a verification URI // and user code that the user is expected to enter into the verification URI link. // @@ -42,9 +37,8 @@ func (r *controllerRoot) Login(ctx context.Context, req jujuparams.LoginRequest) func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceResponse, error) { const op = errors.Op("jujuapi.LoginDevice") response := params.LoginDeviceResponse{} - authSvc := r.jimm.OAuthAuthenticationService() - deviceResponse, err := authSvc.Device(ctx) + deviceResponse, err := jimm.LoginDevice(ctx, r.jimm.OAuthAuthenticationService()) if err != nil { return response, errors.E(op, err) } @@ -53,8 +47,8 @@ func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceRes // happens on the SAME websocket. r.deviceOAuthResponse = deviceResponse - response.VerificationURI = deviceResponse.VerificationURI response.UserCode = deviceResponse.UserCode + response.VerificationURI = deviceResponse.VerificationURI return response, nil } @@ -67,36 +61,13 @@ func (r *controllerRoot) LoginDevice(ctx context.Context) (params.LoginDeviceRes func (r *controllerRoot) GetDeviceSessionToken(ctx context.Context) (params.GetDeviceSessionTokenResponse, error) { const op = errors.Op("jujuapi.GetDeviceSessionToken") response := params.GetDeviceSessionTokenResponse{} - authSvc := r.jimm.OAuthAuthenticationService() - - token, err := authSvc.DeviceAccessToken(ctx, r.deviceOAuthResponse) - if err != nil { - return response, errors.E(op, err) - } - - idToken, err := authSvc.ExtractAndVerifyIDToken(ctx, token) - if err != nil { - return response, errors.E(op, err) - } - - email, err := authSvc.Email(idToken) - if err != nil { - return response, errors.E(op, err) - } - if err := authSvc.UpdateIdentity(ctx, email, token); err != nil { - return response, errors.E(op, err) - } - - // TODO(ale8k): Add vault logic to get secret key and generate one - // on start up. - encToken, err := authSvc.MintSessionToken(email, "test-secret") + token, err := jimm.GetDeviceSessionToken(ctx, r.jimm.OAuthAuthenticationService(), r.deviceOAuthResponse) if err != nil { return response, errors.E(op, err) } - response.SessionToken = string(encToken) - + response.SessionToken = token return response, nil } diff --git a/internal/jujuapi/admin_test.go b/internal/jujuapi/admin_test.go index 4bce1d761..6a704f4ce 100644 --- a/internal/jujuapi/admin_test.go +++ b/internal/jujuapi/admin_test.go @@ -56,7 +56,7 @@ func (s *adminSuite) TestLoginToController(c *gc.C) { }, "test") defer conn.Close() err := conn.Login(nil, "", "", nil) - c.Assert(err, gc.ErrorMatches, "Invalid login, ensure you are using Juju 3\\.5\\+") + c.Assert(err, gc.ErrorMatches, `JIMM does not support login from old clients \(not supported\)`) var resp jujuparams.RedirectInfoResult err = conn.APICall("Admin", 3, "", "RedirectInfo", nil, &resp) c.Assert(jujuparams.ErrCode(err), gc.Equals, jujuparams.CodeNotImplemented) diff --git a/internal/jujuapi/controllerroot.go b/internal/jujuapi/controllerroot.go index 883cd9881..638e68f3c 100644 --- a/internal/jujuapi/controllerroot.go +++ b/internal/jujuapi/controllerroot.go @@ -145,8 +145,8 @@ func newControllerRoot(j JIMM, p Params) *controllerRoot { r.AddMethod("Admin", 1, "Login", rpc.Method(unsupportedLogin)) r.AddMethod("Admin", 2, "Login", rpc.Method(unsupportedLogin)) - r.AddMethod("Admin", 3, "Login", rpc.Method(r.Login)) - r.AddMethod("Admin", 4, "Login", rpc.Method(r.Login)) + r.AddMethod("Admin", 3, "Login", rpc.Method(unsupportedLogin)) + r.AddMethod("Admin", 4, "Login", rpc.Method(unsupportedLogin)) r.AddMethod("Admin", 4, "LoginDevice", rpc.Method(r.LoginDevice)) r.AddMethod("Admin", 4, "GetDeviceSessionToken", rpc.Method(r.GetDeviceSessionToken)) r.AddMethod("Admin", 4, "LoginWithSessionToken", rpc.Method(r.LoginWithSessionToken)) diff --git a/internal/jujuapi/websocket.go b/internal/jujuapi/websocket.go index 0d92260a5..fd2fc9a84 100644 --- a/internal/jujuapi/websocket.go +++ b/internal/jujuapi/websocket.go @@ -130,8 +130,7 @@ func modelInfoFromPath(path string) (uuid string, finalPath string, err error) { // ServeWS implements jimmhttp.WSServer. func (s modelProxyServer) ServeWS(ctx context.Context, clientConn *websocket.Conn) { - // TODO(CSS-7331) Refactor model proxy for new login methods - jwtGenerator := jimm.NewJWTGenerator(nil, &s.jimm.Database, s.jimm, s.jimm.JWTService) + jwtGenerator := jimm.NewJWTGenerator(&s.jimm.Database, s.jimm, s.jimm.JWTService) connectionFunc := controllerConnectionFunc(s, &jwtGenerator) zapctx.Debug(ctx, "Starting proxier") auditLogger := s.jimm.AddAuditLogEntry @@ -140,14 +139,15 @@ func (s modelProxyServer) ServeWS(ctx context.Context, clientConn *websocket.Con TokenGen: &jwtGenerator, ConnectController: connectionFunc, AuditLog: auditLogger, + JIMM: s.jimm, } jimmRPC.ProxySockets(ctx, proxyHelpers) } // controllerConnectionFunc returns a function that will be used to // connect to a controller when a client makes a request. -func controllerConnectionFunc(s modelProxyServer, jwtGenerator *jimm.JWTGenerator) func(context.Context) (*websocket.Conn, string, error) { - connectToControllerFunc := func(ctx context.Context) (*websocket.Conn, string, error) { +func controllerConnectionFunc(s modelProxyServer, jwtGenerator *jimm.JWTGenerator) func(context.Context) (jimmRPC.WebsocketConnection, string, error) { + return func(ctx context.Context) (jimmRPC.WebsocketConnection, string, error) { const op = errors.Op("proxy.controllerConnectionFunc") path := jimmhttp.PathElementFromContext(ctx, "path") zapctx.Debug(ctx, "grabbing model info from path", zap.String("path", path)) @@ -177,7 +177,6 @@ func controllerConnectionFunc(s modelProxyServer, jwtGenerator *jimm.JWTGenerato fullModelName := m.Controller.Name + "/" + m.Name return controllerConn, fullModelName, nil } - return connectToControllerFunc } // Use a 64k frame size for the websockets while we need to deal diff --git a/internal/rpc/client_test.go b/internal/rpc/client_test.go index 063c9b8df..e284853c0 100644 --- a/internal/rpc/client_test.go +++ b/internal/rpc/client_test.go @@ -15,11 +15,11 @@ import ( qt "github.com/frankban/quicktest" "github.com/gorilla/websocket" - "github.com/juju/juju/rpc/params" "github.com/juju/names/v5" "github.com/canonical/jimm/internal/dbmodel" "github.com/canonical/jimm/internal/errors" + "github.com/canonical/jimm/internal/openfga" "github.com/canonical/jimm/internal/rpc" ) @@ -225,7 +225,7 @@ func TestClientReceiveInvalidMessage(t *testing.T) { type testTokenGenerator struct{} -func (p *testTokenGenerator) MakeLoginToken(ctx context.Context, req *params.LoginRequest) ([]byte, error) { +func (p *testTokenGenerator) MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) { return nil, nil } @@ -250,7 +250,7 @@ func TestProxySockets(t *testing.T) { errChan := make(chan error) srvJIMM := newServer(func(connClient *websocket.Conn) error { testTokenGen := testTokenGenerator{} - f := func(context.Context) (*websocket.Conn, string, error) { + f := func(context.Context) (rpc.WebsocketConnection, string, error) { connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL) c.Assert(err, qt.IsNil) return connController, "TestName", nil @@ -297,7 +297,7 @@ func TestCancelProxySockets(t *testing.T) { errChan := make(chan error) srvJIMM := newServer(func(connClient *websocket.Conn) error { testTokenGen := testTokenGenerator{} - f := func(context.Context) (*websocket.Conn, string, error) { + f := func(context.Context) (rpc.WebsocketConnection, string, error) { connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL) c.Assert(err, qt.IsNil) return connController, "TestName", nil @@ -337,7 +337,7 @@ func TestProxySocketsAuditLogs(t *testing.T) { errChan := make(chan error) srvJIMM := newServer(func(connClient *websocket.Conn) error { testTokenGen := testTokenGenerator{} - f := func(context.Context) (*websocket.Conn, string, error) { + f := func(context.Context) (rpc.WebsocketConnection, string, error) { connController, err := srvController.dialer.DialWebsocket(ctx, srvController.URL) c.Assert(err, qt.IsNil) return connController, "TestModelName", nil diff --git a/internal/rpc/proxy.go b/internal/rpc/proxy.go index 9b208e7bc..08b562b65 100644 --- a/internal/rpc/proxy.go +++ b/internal/rpc/proxy.go @@ -8,15 +8,18 @@ import ( "sync" "time" - "github.com/gorilla/websocket" "github.com/juju/juju/rpc/params" "github.com/juju/names/v5" "github.com/juju/zaputil/zapctx" "go.uber.org/zap" + "golang.org/x/oauth2" + apiparams "github.com/canonical/jimm/api/params" "github.com/canonical/jimm/internal/auth" "github.com/canonical/jimm/internal/dbmodel" "github.com/canonical/jimm/internal/errors" + "github.com/canonical/jimm/internal/jimm" + "github.com/canonical/jimm/internal/openfga" "github.com/canonical/jimm/internal/utils" ) @@ -26,10 +29,10 @@ const ( // TokenGenerator authenticates a user and generates a JWT token. type TokenGenerator interface { - // MakeLoginToken authorizes the user based on the provided login requests and returns - // a JWT containing claims about user's access to the controller, model (if applicable) - // and all clouds that the controller knows about. - MakeLoginToken(ctx context.Context, req *params.LoginRequest) ([]byte, error) + // MakeLoginToken returns a JWT containing claims about user's access + // to the controller, model (if applicable) and all clouds that the + // controller knows about. + MakeLoginToken(ctx context.Context, user *openfga.User) ([]byte, error) // MakeToken assumes MakeLoginToken has already been called and checks the permissions // specified in the permissionMap. If the logged in user has all those permissions // a JWT will be returned with assertions confirming all those permissions. @@ -40,10 +43,90 @@ type TokenGenerator interface { GetUser() names.UserTag } +// WebsocketConnection represents the websocket connection interface used by the proxy. +type WebsocketConnection interface { + ReadJSON(v interface{}) error + WriteJSON(v interface{}) error + Close() error +} + +// JIMM represents the JIMM interface used by the proxy. +type JIMM interface { + GetOpenFGAUserAndAuthorise(ctx context.Context, email string) (*openfga.User, error) + OAuthAuthenticationService() jimm.OAuthAuthenticator +} + +// ProxyHelpers contains all the necessary helpers for proxying a Juju client +// connection to a model. +type ProxyHelpers struct { + ConnClient WebsocketConnection + TokenGen TokenGenerator + ConnectController func(context.Context) (WebsocketConnection, string, error) + AuditLog func(*dbmodel.AuditLogEntry) + JIMM JIMM +} + +// ProxySockets will proxy requests from a client connection through to a controller +// tokenGen is used to authenticate the user and generate JWT token. +// connectController provides the function to return a connection to the desired controller endpoint. +func ProxySockets(ctx context.Context, helpers ProxyHelpers) error { + const op = errors.Op("rpc.ProxySockets") + if helpers.ConnectController == nil { + zapctx.Error(ctx, "Missing controller connect function") + return errors.E(op, "Missing controller connect function") + } + if helpers.AuditLog == nil { + zapctx.Error(ctx, "Missing audit log function") + return errors.E(op, "Missing audit log function") + } + errChan := make(chan error, 2) + msgInFlight := inflightMsgs{messages: make(map[uint64]*message)} + client := writeLockConn{conn: helpers.ConnClient} + // Note that the clProxy start method will create the connection to the desired controller only + // after the first message has been received so that any errors can be properly sent back to the client. + clProxy := clientProxy{ + modelProxy: modelProxy{ + src: &client, + msgs: &msgInFlight, + tokenGen: helpers.TokenGen, + auditLog: helpers.AuditLog, + conversationId: utils.NewConversationID(), + jimm: helpers.JIMM, + }, + errChan: errChan, + createControllerConn: helpers.ConnectController, + } + clProxy.wg.Add(1) + go func() { + defer clProxy.wg.Done() + errChan <- clProxy.start(ctx) + }() + var err error + select { + // No cleanup is needed on error, when the client closes the connection + // all go routines will proceed to error and exit. + case err = <-errChan: + zapctx.Debug(ctx, "Proxy error", zap.Error(err)) + case <-ctx.Done(): + err = errors.E(op, "Context cancelled") + zapctx.Debug(ctx, "Context cancelled") + helpers.ConnClient.Close() + clProxy.mu.Lock() + clProxy.closed = true + // TODO(Kian): Test removing close on dst below. The client connection should do it. + if clProxy.dst != nil { + clProxy.dst.conn.Close() + } + clProxy.mu.Unlock() + } + clProxy.wg.Wait() + return err +} + // writeLockConn provides a websocket connection that is safe for concurrent writes. type writeLockConn struct { mu sync.Mutex - conn *websocket.Conn + conn WebsocketConnection } // readJson allows for non-concurrent reads on the websocket. @@ -58,10 +141,18 @@ func (c *writeLockConn) writeJson(v interface{}) error { return c.conn.WriteJSON(v) } -func (c *writeLockConn) sendMessage(responseData json.RawMessage, request *message) { +func (c *writeLockConn) sendMessage(responseObject any, request *message) { msg := new(message) msg.RequestID = request.RequestID - msg.Response = responseData + msg.Response = request.Response + if responseObject != nil { + responseData, err := json.Marshal(responseObject) + if err != nil { + errorMsg := createErrResponse(err, request) + c.writeJson(errorMsg) + } + msg.Response = responseData + } c.writeJson(msg) } @@ -73,7 +164,10 @@ type inflightMsgs struct { func (msgs *inflightMsgs) addMessage(msg *message) { msgs.mu.Lock() defer msgs.mu.Unlock() + // Putting the login request on ID 0 to persist it. + // Note (alesstimec) It's a bit confusing that we automagically add "login" message + // as the first message. We should revisit this. if msg.Type == "Admin" && msg.Request == "Login" { msgs.messages[0] = msg } else { @@ -103,8 +197,11 @@ type modelProxy struct { msgs *inflightMsgs auditLog func(*dbmodel.AuditLogEntry) tokenGen TokenGenerator + jimm JIMM modelName string conversationId string + + deviceOAuthResponse *oauth2.DeviceAuthResponse } func (p *modelProxy) sendError(socket *writeLockConn, req *message, err error) { @@ -169,7 +266,7 @@ type clientProxy struct { modelProxy wg sync.WaitGroup errChan chan error - createControllerConn func(context.Context) (*websocket.Conn, string, error) + createControllerConn func(context.Context) (WebsocketConnection, string, error) // mu synchronises changes to closed and modelproxy.dst, dst is is only created // at some unspecified point in the future after a client request. mu sync.Mutex @@ -179,7 +276,6 @@ type clientProxy struct { // start begins the client->controller proxier. func (p *clientProxy) start(ctx context.Context) error { const op = errors.Op("rpc.clientProxy.start") - const initialLogin = true defer func() { if p.dst != nil { p.dst.conn.Close() @@ -202,10 +298,10 @@ func (p *clientProxy) start(ctx context.Context) error { p.auditLogMessage(msg, false) // All requests should be proxied as transparently as possible through to the controller // except for auth related requests like Login because JIMM is auth gateway. - if msg.Type == "Admin" && msg.Request == "Login" { - zapctx.Debug(ctx, "Login request found, adding JWT") - if err := addJWT(ctx, initialLogin, msg, nil, p.tokenGen); err != nil { - zapctx.Error(ctx, "Failed to add JWT", zap.Error(err)) + if msg.Type == "Admin" { + zapctx.Debug(ctx, "Found an Admin facade call") + toClient, toController, err := p.handleAdminFacade(ctx, msg) + if err != nil { var aerr *auth.AuthenticationError if stderrors.As(err, &aerr) { res, err := json.Marshal(aerr.LoginResult) @@ -219,6 +315,11 @@ func (p *clientProxy) start(ctx context.Context) error { p.sendError(p.src, msg, err) continue } + if toClient != nil { + p.src.sendMessage(nil, toClient) + } else if toController != nil { + msg = toController + } } if msg.RequestID == 0 { zapctx.Error(ctx, "Invalid request ID 0") @@ -381,7 +482,6 @@ func checkPermissionsRequired(ctx context.Context, msg *message) (map[string]any func (p *controllerProxy) redoLogin(ctx context.Context, permissions map[string]any) error { const op = errors.Op("rpc.redoLogin") - const initialLogin = false var loginMsg *message if msg, ok := p.msgs.messages[0]; ok { loginMsg = msg @@ -389,7 +489,7 @@ func (p *controllerProxy) redoLogin(ctx context.Context, permissions map[string] if loginMsg == nil { return errors.E(op, errors.CodeUnauthorized, "Haven't received login yet") } - err := addJWT(ctx, initialLogin, loginMsg, permissions, p.tokenGen) + err := addJWT(ctx, loginMsg, permissions, p.tokenGen) if err != nil { return err } @@ -401,8 +501,7 @@ func (p *controllerProxy) redoLogin(ctx context.Context, permissions map[string] } // addJWT adds a JWT token to the the provided message. -// If initialLogin is set the user will be authenticated. -func addJWT(ctx context.Context, initialLogin bool, msg *message, permissions map[string]interface{}, tokenGen TokenGenerator) error { +func addJWT(ctx context.Context, msg *message, permissions map[string]interface{}, tokenGen TokenGenerator) error { const op = errors.Op("rpc.addJWT") // First we unmarshal the existing LoginRequest. if msg == nil { @@ -412,21 +511,13 @@ func addJWT(ctx context.Context, initialLogin bool, msg *message, permissions ma if err := json.Unmarshal(msg.Params, &lr); err != nil { return errors.E(op, err) } - var jwt []byte - var err error - if initialLogin { - jwt, err = tokenGen.MakeLoginToken(ctx, &lr) - if err != nil { - zapctx.Error(ctx, "failed to make token", zap.Error(err)) - return errors.E(op, err) - } - } else { - jwt, err = tokenGen.MakeToken(ctx, permissions) - if err != nil { - zapctx.Error(ctx, "failed to make token", zap.Error(err)) - return errors.E(op, err) - } + + jwt, err := tokenGen.MakeToken(ctx, permissions) + if err != nil { + zapctx.Error(ctx, "failed to make token", zap.Error(err)) + return errors.E(op, err) } + jwtString := base64.StdEncoding.EncodeToString(jwt) // Add the JWT as base64 encoded string. lr.Token = jwtString @@ -448,71 +539,6 @@ func createErrResponse(err error, req *message) *message { return errMsg } -// ProxyHelpers contains all the necessary helpers for proxying a Juju client -// connection to a model. -type ProxyHelpers struct { - ConnClient *websocket.Conn - TokenGen TokenGenerator - ConnectController func(context.Context) (*websocket.Conn, string, error) - AuditLog func(*dbmodel.AuditLogEntry) -} - -// ProxySockets will proxy requests from a client connection through to a controller -// tokenGen is used to authenticate the user and generate JWT token. -// connectController provides the function to return a connection to the desired controller endpoint. -func ProxySockets(ctx context.Context, helpers ProxyHelpers) error { - const op = errors.Op("rpc.ProxySockets") - if helpers.ConnectController == nil { - zapctx.Error(ctx, "Missing controller connect function") - return errors.E(op, "Missing controller connect function") - } - if helpers.AuditLog == nil { - zapctx.Error(ctx, "Missing audit log function") - return errors.E(op, "Missing audit log function") - } - errChan := make(chan error, 2) - msgInFlight := inflightMsgs{messages: make(map[uint64]*message)} - client := writeLockConn{conn: helpers.ConnClient} - // Note that the clProxy start method will create the connection to the desired controller only - // after the first message has been received so that any errors can be properly sent back to the client. - clProxy := clientProxy{ - modelProxy: modelProxy{ - src: &client, - msgs: &msgInFlight, - tokenGen: helpers.TokenGen, - auditLog: helpers.AuditLog, - conversationId: utils.NewConversationID(), - }, - errChan: errChan, - createControllerConn: helpers.ConnectController, - } - clProxy.wg.Add(1) - go func() { - defer clProxy.wg.Done() - errChan <- clProxy.start(ctx) - }() - var err error - select { - // No cleanup is needed on error, when the client closes the connection - // all go routines will proceed to error and exit. - case err = <-errChan: - zapctx.Debug(ctx, "Proxy error", zap.Error(err)) - case <-ctx.Done(): - err = errors.E(op, "Context cancelled") - zapctx.Debug(ctx, "Context cancelled") - helpers.ConnClient.Close() - clProxy.mu.Lock() - clProxy.closed = true - // TODO(Kian): Test removing close on dst below. The client connection should do it. - if clProxy.dst != nil { - clProxy.dst.conn.Close() - } - clProxy.mu.Unlock() - } - clProxy.wg.Wait() - return err -} - func modifyControllerResponse(msg *message) error { var response map[string]interface{} err := json.Unmarshal(msg.Response, &response) @@ -528,3 +554,122 @@ func modifyControllerResponse(msg *message) error { msg.Response = newResp return nil } + +// handleAdminFacade processes the admin facade call and returns: +// a message to be returned to the source +// a message to be sent to the destination +// an error +func (p *clientProxy) handleAdminFacade(ctx context.Context, msg *message) (*message, *message, error) { + errorFnc := func(err error) (*message, *message, error) { + return nil, nil, err + } + switch msg.Request { + case "LoginDevice": + deviceResponse, err := jimm.LoginDevice(ctx, p.jimm.OAuthAuthenticationService()) + if err != nil { + return nil, nil, errors.E(err) + } + p.deviceOAuthResponse = deviceResponse + + data, err := json.Marshal(apiparams.LoginDeviceResponse{ + VerificationURI: deviceResponse.VerificationURI, + UserCode: deviceResponse.UserCode, + }) + if err != nil { + return errorFnc(err) + } + msg.Response = data + return msg, nil, nil + case "GetDeviceSessionToken": + sessionToken, err := jimm.GetDeviceSessionToken(ctx, p.jimm.OAuthAuthenticationService(), p.deviceOAuthResponse) + if err != nil { + return errorFnc(err) + } + data, err := json.Marshal(apiparams.GetDeviceSessionTokenResponse{ + SessionToken: sessionToken, + }) + if err != nil { + return errorFnc(err) + } + msg.Response = data + return msg, nil, nil + case "LoginWithSessionToken": + var request apiparams.LoginWithSessionTokenRequest + err := json.Unmarshal(msg.Params, &request) + if err != nil { + return errorFnc(err) + } + + // Verify the session token + // TODO(CSS-7081): Ensure for tests that the secret key can be configured. + // Or configure cmd tests to use the configured secret. + token, err := p.jimm.OAuthAuthenticationService().VerifySessionToken(request.SessionToken, "test-secret") + if err != nil { + return errorFnc(err) + } + email := token.Subject() + + user, err := p.jimm.GetOpenFGAUserAndAuthorise(ctx, email) + if err != nil { + return errorFnc(err) + } + + jwt, err := p.tokenGen.MakeLoginToken(ctx, user) + if err != nil { + return errorFnc(err) + } + data, err := json.Marshal(params.LoginRequest{ + AuthTag: names.NewUserTag(email).String(), + Token: base64.StdEncoding.EncodeToString(jwt), + }) + if err != nil { + return errorFnc(err) + } + m := *msg + m.Type = "Admin" + m.Request = "Login" + m.Version = 3 + m.Params = data + return nil, &m, nil + + case "LoginWithClientCredentials": + var request apiparams.LoginWithClientCredentialsRequest + err := json.Unmarshal(msg.Params, &request) + if err != nil { + return errorFnc(err) + } + err = p.jimm.OAuthAuthenticationService().VerifyClientCredentials(ctx, request.ClientID, request.ClientSecret) + if err != nil { + return errorFnc(err) + } + + user, err := p.jimm.GetOpenFGAUserAndAuthorise(ctx, request.ClientID) + if err != nil { + return errorFnc(err) + } + + jwt, err := p.tokenGen.MakeLoginToken(ctx, user) + if err != nil { + return errorFnc(err) + } + data, err := json.Marshal(params.LoginRequest{ + AuthTag: names.NewUserTag(request.ClientID).String(), + Token: base64.StdEncoding.EncodeToString(jwt), + }) + if err != nil { + return errorFnc(err) + } + m := *msg + m.Type = "Admin" + m.Request = "Login" + m.Version = 3 + m.Params = data + return nil, &m, nil + case "LoginWithCookie": + return errorFnc(errors.E(errors.CodeNotImplemented)) + case "Login": + return errorFnc(errors.E("Invalid login. Ensure you are using Juju 3.5+", errors.CodeNotSupported)) + default: + return nil, nil, nil + } +}