From efbf21c1fbe1d149b52510fd51ad847e143c010f Mon Sep 17 00:00:00 2001 From: Alexander <42068202+ale8k@users.noreply.github.com> Date: Fri, 22 Mar 2024 08:41:25 +0000 Subject: [PATCH] Browser cookie sessions (#1178) Implements browser sessions via cookies and persistent session storage. --- cmd/jimmsrv/main.go | 22 +- docker-compose.yaml | 2 +- go.mod | 3 +- go.sum | 4 + internal/auth/oauth2.go | 199 +++++++++++++++++- internal/auth/oauth2_test.go | 252 +++++++++++++++++++++-- internal/cmdtest/jimmsuite.go | 9 +- internal/dbmodel/identity.go | 7 + internal/dbmodel/sql/postgres/1_6.sql | 2 + internal/jimm/jimm.go | 10 +- internal/jimm/user_test.go | 19 +- internal/jimmhttp/auth_handler.go | 43 ++-- internal/jimmhttp/auth_handler_test.go | 121 +---------- internal/jimmhttp/websocket.go | 89 +++++++- internal/jimmhttp/websocket_test.go | 43 ++++ internal/jimmjwx/utils_test.go | 9 +- internal/jimmtest/auth.go | 154 ++++++++++++++ internal/jimmtest/suite.go | 4 +- internal/jujuapi/admin.go | 39 ++++ internal/jujuapi/admin_test.go | 149 +++++++++++++- internal/jujuapi/controllerroot.go | 7 +- internal/jujuapi/export_test.go | 2 +- internal/jujuapi/pinger_internal_test.go | 2 +- internal/jujuapi/websocket.go | 16 +- internal/jujuapi/websocket_test.go | 32 ++- service.go | 22 +- service_test.go | 72 ++++--- 27 files changed, 1075 insertions(+), 258 deletions(-) diff --git a/cmd/jimmsrv/main.go b/cmd/jimmsrv/main.go index e9918c682..33d2acd3f 100644 --- a/cmd/jimmsrv/main.go +++ b/cmd/jimmsrv/main.go @@ -125,13 +125,13 @@ func start(ctx context.Context, s *service.Service) error { secureSessionCookies = true } - sessionCookieExpiry := os.Getenv("JIMM_SESSION_COOKIE_EXPIRY") - sessionCookieExpiryInt, err := strconv.Atoi(sessionCookieExpiry) + sessionCookieMaxAge := os.Getenv("JIMM_SESSION_COOKIE_MAX_AGE") + sessionCookieMaxAgeInt, err := strconv.Atoi(sessionCookieMaxAge) if err != nil { - return errors.E("unable to parse jimm session cookie expiry") + return errors.E("unable to parse jimm session cookie max age") } - if sessionCookieExpiryInt < 0 { - return errors.E("jimm session cookie expiry cannot be less than 0") + if sessionCookieMaxAgeInt < 0 { + return errors.E("jimm session cookie max age cannot be less than 0") } jimmsvc, err := jimm.NewService(ctx, jimm.Params{ @@ -159,15 +159,15 @@ func start(ctx context.Context, s *service.Service) error { JWTExpiryDuration: jwtExpiryDuration, InsecureSecretStorage: insecureSecretStorage, OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: issuerURL, - ClientID: clientID, - ClientSecret: clientSecret, - Scopes: scopesParsed, - SessionTokenExpiry: sessionTokenExpiryDuration, + IssuerURL: issuerURL, + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: scopesParsed, + SessionTokenExpiry: sessionTokenExpiryDuration, + SessionCookieMaxAge: sessionCookieMaxAgeInt, }, DashboardFinalRedirectURL: os.Getenv("JIMM_DASHBOARD_FINAL_REDIRECT_URL"), SecureSessionCookies: secureSessionCookies, - SessionCookieExpiry: sessionCookieExpiryInt, }) if err != nil { return err diff --git a/docker-compose.yaml b/docker-compose.yaml index 9a52e2648..c7cf27977 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -78,7 +78,7 @@ services: JIMM_DASHBOARD_FINAL_REDIRECT_URL: "https://my-dashboard.com/final-callback" # Example URL JIMM_ACCESS_TOKEN_EXPIRY_DURATION: 1h JIMM_SECURE_SESSION_COOKIES: false - JIMM_SESSION_COOKIE_EXPIRY: 86400 + JIMM_SESSION_COOKIE_MAX_AGE: 86400 volumes: - ./:/jimm/ - ./local/vault/approle.json:/vault/approle.json:rw diff --git a/go.mod b/go.mod index a27924975..d61ff3809 100644 --- a/go.mod +++ b/go.mod @@ -132,7 +132,7 @@ require ( github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/gofrs/flock v0.8.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v5 v5.0.0 // indirect + github.com/golang-jwt/jwt/v5 v5.2.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/gnostic-models v0.6.8 // indirect @@ -251,6 +251,7 @@ require ( github.com/muhlemmer/gu v0.3.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect + github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/oracle/oci-go-sdk/v65 v65.55.0 // indirect github.com/packethost/packngo v0.28.1 // indirect diff --git a/go.sum b/go.sum index d82c9d791..d00d9fbe1 100644 --- a/go.sum +++ b/go.sum @@ -296,6 +296,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= +github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -881,6 +883,8 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d/go.mod h1:YUTz3bUH2ZwIWBy3CJBeOBEugqcmXREj14T+iG/4k4U= +github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25 h1:9bCMuD3TcnjeqjPT2gSlha4asp8NvgcFRYExCaikCxk= +github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25/go.mod h1:eDjgYHYDJbPLBLsyZ6qRaugP0mX8vePOhZ5id1fdzJw= github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU= github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= diff --git a/internal/auth/oauth2.go b/internal/auth/oauth2.go index c98b7e802..3dc47a86a 100644 --- a/internal/auth/oauth2.go +++ b/internal/auth/oauth2.go @@ -12,11 +12,14 @@ import ( "context" "encoding/base64" stderrors "errors" + "fmt" + "net/http" "net/mail" "strings" "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/gorilla/sessions" "github.com/juju/zaputil/zapctx" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" @@ -28,6 +31,31 @@ import ( "github.com/canonical/jimm/internal/errors" ) +const ( + // SessionName is the name of the gorilla session and is used to retrieve + // the session object from the database. + SessionName = "jimm-browser-session" + + // SessionIdentityKey is the key for the identity value stored within the + // session. + SessionIdentityKey = "identity-id" +) + +type sessionIdentityContextKey struct{} + +func contextWithSessionIdentity(ctx context.Context, sessionIdentityId any) context.Context { + return context.WithValue(ctx, sessionIdentityContextKey{}, sessionIdentityId) +} + +// SessionIdentityFromContext returns the session identity key from the context. +func SessionIdentityFromContext(ctx context.Context) string { + v := ctx.Value(sessionIdentityContextKey{}) + if v == nil { + return "" + } + return v.(string) +} + // AuthenticationService handles authentication within JIMM. type AuthenticationService struct { oauthConfig oauth2.Config @@ -37,7 +65,12 @@ type AuthenticationService struct { // sessionTokenExpiry holds the expiry time for JIMM minted session tokens (JWTs). sessionTokenExpiry time.Duration + // sessionCookieMaxAge holds the max age for session cookies. + sessionCookieMaxAge int + db IdentityStore + + sessionStore sessions.Store } // Identity store holds the necessary methods to get and update an identity @@ -62,6 +95,8 @@ type AuthenticationServiceParams struct { Scopes []string // SessionTokenExpiry holds the expiry time of minted JIMM session tokens (JWTs). SessionTokenExpiry time.Duration + // SessionCookieMaxAge holds the max age for session cookies. + SessionCookieMaxAge int // RedirectURL is the URL for handling the exchange of authorisation // codes into access tokens (and id tokens), for JIMM, this is expected // to be the servers own callback endpoint registered under /auth/callback. @@ -71,6 +106,9 @@ type AuthenticationServiceParams struct { // to fetch and update identities. I.e., their access tokens, refresh tokens, // display name, etc. Store IdentityStore + + // SessionStore holds the store for creating, getting and saving gorrila sessions. + SessionStore sessions.Store } // NewAuthenticationService returns a new authentication service for handling @@ -93,8 +131,10 @@ func NewAuthenticationService(ctx context.Context, params AuthenticationServiceP Scopes: params.Scopes, RedirectURL: params.RedirectURL, }, - sessionTokenExpiry: params.SessionTokenExpiry, - db: params.Store, + sessionTokenExpiry: params.SessionTokenExpiry, + db: params.Store, + sessionStore: params.SessionStore, + sessionCookieMaxAge: params.SessionCookieMaxAge, }, nil } @@ -277,6 +317,8 @@ func (as *AuthenticationService) UpdateIdentity(ctx context.Context, email strin u.AccessToken = token.AccessToken u.RefreshToken = token.RefreshToken + u.AccessTokenExpiry = token.Expiry + u.AccessTokenType = token.TokenType if err := db.UpdateIdentity(ctx, u); err != nil { return errors.E(op, err) } @@ -335,3 +377,156 @@ func (as *AuthenticationService) VerifyClientCredentials(ctx context.Context, cl } return nil } + +// CreateBrowserSession creates a session and updates the cookie for a browser +// login callback. +func (as *AuthenticationService) CreateBrowserSession( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + secureCookies bool, + email string, +) error { + const op = errors.Op("auth.AuthenticationService.CreateBrowserSession") + + session, err := as.sessionStore.Get(r, SessionName) + if err != nil { + return errors.E(op, err) + } + + session.IsNew = true // Sets cookie to a fresh new cookie + session.Options.MaxAge = as.sessionCookieMaxAge // Expiry in seconds + session.Options.Secure = secureCookies // Ensures only sent with HTTPS + session.Options.HttpOnly = false // Allow Javascript to read it + + session.Values[SessionIdentityKey] = email + if err = session.Save(r, w); err != nil { + return errors.E(op, err) + } + return nil +} + +// AuthenticateBrowserSession updates the session for a browser, additionally +// retrieving new access tokens upon expiry. If this cannot be done, the cookie +// is deleted and an error is returned. +func (as *AuthenticationService) AuthenticateBrowserSession(ctx context.Context, w http.ResponseWriter, req *http.Request) (context.Context, error) { + const op = errors.Op("auth.AuthenticationService.AuthenticateBrowserSession") + + session, err := as.sessionStore.Get(req, SessionName) + if err != nil { + return ctx, errors.E(op, err, "failed to retrieve session") + } + + identityId, ok := session.Values[SessionIdentityKey] + if !ok { + return ctx, errors.E(op, "session is missing identity key") + } + + err = as.validateAndUpdateAccessToken(ctx, identityId) + if err != nil { + if err := as.deleteSession(session, w, req); err != nil { + return ctx, errors.E(op, err) + } + return ctx, errors.E(op, err) + } + + ctx = contextWithSessionIdentity(ctx, identityId) + + if err := as.extendSession(session, w, req); err != nil { + return ctx, errors.E(op, err) + } + + return ctx, nil +} + +// validateAndUpdateAccessToken validates the access tokens expiry, and if it cannot, then +// it attempts to refresh the access token. +func (as *AuthenticationService) validateAndUpdateAccessToken(ctx context.Context, email any) error { + const op = errors.Op("auth.AuthenticationService.validateAndUpdateAccessToken") + + emailStr, ok := email.(string) + if !ok { + return errors.E(op, fmt.Sprintf("failed to cast email: got %T, expected %T", email, emailStr)) + } + + db := as.db + u := &dbmodel.Identity{ + Name: emailStr, + } + if err := db.GetIdentity(ctx, u); err != nil { + return errors.E(op, err) + } + + t := &oauth2.Token{ + AccessToken: u.AccessToken, + RefreshToken: u.RefreshToken, + Expiry: u.AccessTokenExpiry, + TokenType: u.AccessTokenType, + } + + // Valid simply checks the expiry, if the token isn't valid, + // we attempt to refresh the identities tokens and update them. + if t.Valid() { + return nil + } + + if err := as.refreshIdentitiesToken(ctx, emailStr, t); err != nil { + return errors.E(op, err) + } + + return nil +} + +// refreshIdentitiesToken creates a token source based on the expired token and performs +// a manual token refresh, updating the identity afterwards. +// +// This is to be called only when a token is expired. +func (as *AuthenticationService) refreshIdentitiesToken(ctx context.Context, email string, t *oauth2.Token) error { + const op = errors.Op("auth.AuthenticationService.refreshIdentitiesToken") + + tSrc := as.oauthConfig.TokenSource(ctx, t) + + // Get a new access and refresh token (token source only has Token()) + newToken, err := tSrc.Token() + if err != nil { + return errors.E(op, err, "failed to refresh token") + } + + if err := as.UpdateIdentity(ctx, email, newToken); err != nil { + return errors.E(op, err, "failed to update identity") + } + + return nil +} + +func (as *AuthenticationService) deleteSession(session *sessions.Session, w http.ResponseWriter, req *http.Request) error { + const op = errors.Op("auth.AuthenticationService.deleteSession") + + if err := as.modifySession(session, w, req, -1); err != nil { + return errors.E(op, err) + } + + return nil +} + +func (as *AuthenticationService) extendSession(session *sessions.Session, w http.ResponseWriter, req *http.Request) error { + const op = errors.Op("auth.AuthenticationService.extendSession") + + if err := as.modifySession(session, w, req, as.sessionCookieMaxAge); err != nil { + return errors.E(op, err) + } + + return nil +} + +func (as *AuthenticationService) modifySession(session *sessions.Session, w http.ResponseWriter, req *http.Request, maxAge int) error { + const op = errors.Op("auth.AuthenticationService.modifySession") + + session.Options.MaxAge = maxAge + + if err := session.Save(req, w); err != nil { + return errors.E(op, err) + } + + return nil +} diff --git a/internal/auth/oauth2_test.go b/internal/auth/oauth2_test.go index d07e9e6e6..0107dce17 100644 --- a/internal/auth/oauth2_test.go +++ b/internal/auth/oauth2_test.go @@ -4,41 +4,53 @@ package auth_test import ( "context" + "encoding/base64" "fmt" "io" "net/http" "net/http/cookiejar" + "net/http/httptest" "net/url" "regexp" "testing" "time" + "github.com/antonlindstrom/pgstore" "github.com/canonical/jimm/internal/auth" "github.com/canonical/jimm/internal/db" "github.com/canonical/jimm/internal/dbmodel" "github.com/canonical/jimm/internal/jimmtest" "github.com/coreos/go-oidc/v3/oidc" qt "github.com/frankban/quicktest" + "github.com/gorilla/sessions" ) -func setupTestAuthSvc(ctx context.Context, c *qt.C, expiry time.Duration) (*auth.AuthenticationService, *db.Database) { +func setupTestAuthSvc(ctx context.Context, c *qt.C, expiry time.Duration) (*auth.AuthenticationService, *db.Database, sessions.Store) { db := &db.Database{ DB: jimmtest.PostgresDB(c, func() time.Time { return time.Now() }), } c.Assert(db.Migrate(ctx, false), qt.IsNil) + sqldb, err := db.DB.DB() + c.Assert(err, qt.IsNil) + + sessionStore, err := pgstore.NewPGStoreFromPool(sqldb, []byte("secretsecretdigletts")) + c.Assert(err, qt.IsNil) + authSvc, err := auth.NewAuthenticationService(ctx, auth.AuthenticationServiceParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - ClientSecret: "SwjDofnbDzJDm9iyfUhEp67FfUFMY8L4", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: expiry, - RedirectURL: "http://localhost:8080/auth/callback", - Store: db, + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + ClientSecret: "SwjDofnbDzJDm9iyfUhEp67FfUFMY8L4", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: expiry, + RedirectURL: "http://localhost:8080/auth/callback", + Store: db, + SessionStore: sessionStore, + SessionCookieMaxAge: 60, }) c.Assert(err, qt.IsNil) - return authSvc, db + return authSvc, db, sessionStore } // This test requires the local docker compose to be running and keycloak @@ -49,7 +61,7 @@ func TestAuthCodeURL(t *testing.T) { c := qt.New(t) ctx := context.Background() - authSvc, _ := setupTestAuthSvc(ctx, c, time.Hour) + authSvc, _, _ := setupTestAuthSvc(ctx, c, time.Hour) url := authSvc.AuthCodeURL() c.Assert( @@ -75,7 +87,7 @@ func TestDevice(t *testing.T) { ctx := context.Background() - authSvc, db := setupTestAuthSvc(ctx, c, time.Hour) + authSvc, db, _ := setupTestAuthSvc(ctx, c, time.Hour) res, err := authSvc.Device(ctx) c.Assert(err, qt.IsNil) @@ -166,7 +178,7 @@ func TestSessionTokens(t *testing.T) { ctx := context.Background() - authSvc, _ := setupTestAuthSvc(ctx, c, time.Hour) + authSvc, _, _ := setupTestAuthSvc(ctx, c, time.Hour) secretKey := "secret-key" token, err := authSvc.MintSessionToken("jimm-test@canonical.com", secretKey) @@ -183,7 +195,7 @@ func TestSessionTokenRejectsWrongSecretKey(t *testing.T) { ctx := context.Background() - authSvc, _ := setupTestAuthSvc(ctx, c, time.Hour) + authSvc, _, _ := setupTestAuthSvc(ctx, c, time.Hour) secretKey := "secret-key" token, err := authSvc.MintSessionToken("jimm-test@canonical.com", secretKey) @@ -200,7 +212,7 @@ func TestSessionTokenRejectsExpiredToken(t *testing.T) { ctx := context.Background() noDuration := time.Duration(0) - authSvc, _ := setupTestAuthSvc(ctx, c, noDuration) + authSvc, _, _ := setupTestAuthSvc(ctx, c, noDuration) secretKey := "secret-key" token, err := authSvc.MintSessionToken("jimm-test@canonical.com", secretKey) @@ -216,7 +228,7 @@ func TestSessionTokenValidatesEmail(t *testing.T) { ctx := context.Background() - authSvc, _ := setupTestAuthSvc(ctx, c, time.Hour) + authSvc, _, _ := setupTestAuthSvc(ctx, c, time.Hour) secretKey := "secret-key" token, err := authSvc.MintSessionToken("", secretKey) @@ -237,7 +249,7 @@ func TestVerifyClientCredentials(t *testing.T) { validClientSecret = "2M2blFbO4GX4zfggQpivQSxwWX1XGgNf" ) - authSvc, _ := setupTestAuthSvc(ctx, c, time.Hour) + authSvc, _, _ := setupTestAuthSvc(ctx, c, time.Hour) err := authSvc.VerifyClientCredentials(ctx, validClientID, validClientSecret) c.Assert(err, qt.IsNil) @@ -245,3 +257,211 @@ func TestVerifyClientCredentials(t *testing.T) { err = authSvc.VerifyClientCredentials(ctx, "invalid-client-id", validClientSecret) c.Assert(err, qt.ErrorMatches, "invalid client credentials") } + +func assertSetCookiesIsCorrect(c *qt.C, rec *httptest.ResponseRecorder, parsedCookies []*http.Cookie) { + assertHasCookie := func(name string, cookies []*http.Cookie) { + found := false + for _, v := range cookies { + if v.Name == name { + found = true + } + } + c.Assert(found, qt.IsTrue) + } + assertHasCookie(auth.SessionName, parsedCookies) + assertHasCookie("Path", parsedCookies) + assertHasCookie("Expires", parsedCookies) + assertHasCookie("Max-Age", parsedCookies) +} + +func TestCreateBrowserSession(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + authSvc, _, sessionStore := setupTestAuthSvc(ctx, c, time.Hour) + + rec := httptest.NewRecorder() + req, err := http.NewRequest("GET", "", nil) + c.Assert(err, qt.IsNil) + + err = authSvc.CreateBrowserSession(ctx, rec, req, false, "jimm-test@canonical.com") + c.Assert(err, qt.IsNil) + + cookies := rec.Header().Get("Set-Cookie") + parsedCookies := jimmtest.ParseCookies(cookies) + assertSetCookiesIsCorrect(c, rec, parsedCookies) + + req.AddCookie(&http.Cookie{ + Name: auth.SessionName, + Value: parsedCookies[0].Value, + }) + + session, err := sessionStore.Get(req, auth.SessionName) + c.Assert(err, qt.IsNil) + c.Assert(session.Values[auth.SessionIdentityKey], qt.Equals, "jimm-test@canonical.com") +} + +func TestAuthenticateBrowserSession(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + authSvc, db, sessionStore := setupTestAuthSvc(ctx, c, time.Hour) + + cookie, err := jimmtest.RunBrowserLogin(db, sessionStore) + c.Assert(err, qt.IsNil) + + rec := httptest.NewRecorder() + req, err := http.NewRequest("GET", "", nil) + c.Assert(err, qt.IsNil) + + cookies := jimmtest.ParseCookies(cookie) + + req.AddCookie(cookies[0]) + + ctx, err = authSvc.AuthenticateBrowserSession(ctx, rec, req) + c.Assert(err, qt.IsNil) + + // Check identity added + identityId := auth.SessionIdentityFromContext(ctx) + c.Assert(identityId, qt.Equals, "jimm-test@canonical.com") + + // Assert Set-Cookie present + setCookieCookies := rec.Header().Get("Set-Cookie") + parsedCookies := jimmtest.ParseCookies(setCookieCookies) + assertSetCookiesIsCorrect(c, rec, parsedCookies) +} + +func TestAuthenticateBrowserSessionRejectsNoneDecryptableOrDecodableCookies(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + authSvc, db, sessionStore := setupTestAuthSvc(ctx, c, time.Hour) + + _, err := jimmtest.RunBrowserLogin(db, sessionStore) + c.Assert(err, qt.IsNil) + + // Failure case 1: Bad base64 decoding + req, err := http.NewRequest("GET", "", nil) + c.Assert(err, qt.IsNil) + req.AddCookie(&http.Cookie{ + Name: auth.SessionName, + Value: "bad cookie, very naughty, bad bad cookie", + }) + + rec := httptest.NewRecorder() + + // The underlying error is a failed base64 decode + _, err = authSvc.AuthenticateBrowserSession(ctx, rec, req) + c.Assert(err, qt.ErrorMatches, "failed to retrieve session") + + // Failure case 2: Value isn't valid but is base64 decoded + req, err = http.NewRequest("GET", "", nil) + c.Assert(err, qt.IsNil) + req.AddCookie(&http.Cookie{ + Name: auth.SessionName, + Value: base64.StdEncoding.EncodeToString([]byte("bad cookie, very naughty, bad bad cookie")), + }) + + rec = httptest.NewRecorder() + + _, err = authSvc.AuthenticateBrowserSession(ctx, rec, req) + c.Assert(err, qt.ErrorMatches, "failed to retrieve session") +} + +func TestAuthenticateBrowserSessionHandlesExpiredAccessTokens(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + authSvc, db, sessionStore := setupTestAuthSvc(ctx, c, time.Hour) + + cookie, err := jimmtest.RunBrowserLogin(db, sessionStore) + c.Assert(err, qt.IsNil) + + rec := httptest.NewRecorder() + req, err := http.NewRequest("GET", "", nil) + c.Assert(err, qt.IsNil) + + cookies := jimmtest.ParseCookies(cookie) + + req.AddCookie(cookies[0]) + + // User exists from run browser login, but we're gonna + // artificially expire their access token + u := dbmodel.Identity{ + Name: "jimm-test@canonical.com", + } + err = db.GetIdentity(ctx, &u) + c.Assert(err, qt.IsNil) + + previousToken := u.AccessToken + + u.AccessTokenExpiry = time.Now() + db.UpdateIdentity(ctx, &u) + + ctx, err = authSvc.AuthenticateBrowserSession(ctx, rec, req) + c.Assert(err, qt.IsNil) + + // Check identity added + identityId := auth.SessionIdentityFromContext(ctx) + c.Assert(identityId, qt.Equals, "jimm-test@canonical.com") + + // Get identity again with new access token expiry and access token + err = db.GetIdentity(ctx, &u) + c.Assert(err, qt.IsNil) + + // Assert new access token is valid for at least 4 minutes(our setup is 5 minutes) + c.Assert(u.AccessTokenExpiry.After(time.Now().Add(time.Minute*4)), qt.IsTrue) + // Assert its not the same token as previous token + c.Assert(u.AccessToken, qt.Not(qt.Equals), previousToken) + // Assert Set-Cookie present + setCookieCookies := rec.Header().Get("Set-Cookie") + parsedCookies := jimmtest.ParseCookies(setCookieCookies) + assertSetCookiesIsCorrect(c, rec, parsedCookies) +} + +func TestAuthenticateBrowserSessionHandlesMissingOrExpiredRefreshTokens(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + authSvc, db, sessionStore := setupTestAuthSvc(ctx, c, time.Hour) + + cookie, err := jimmtest.RunBrowserLogin(db, sessionStore) + c.Assert(err, qt.IsNil) + + rec := httptest.NewRecorder() + req, err := http.NewRequest("GET", "", nil) + c.Assert(err, qt.IsNil) + + cookies := jimmtest.ParseCookies(cookie) + + req.AddCookie(cookies[0]) + + // User exists from run browser login, but we're gonna + // artificially expire their access token + u := dbmodel.Identity{ + Name: "jimm-test@canonical.com", + } + err = db.GetIdentity(ctx, &u) + c.Assert(err, qt.IsNil) + + // As our access token has "expired" + u.AccessTokenExpiry = time.Now() + // And we're missing a refresh token (the same case would apply for an expired refresh token + // or any scenario where the token source cannot refresh the access token) + u.RefreshToken = "" + db.UpdateIdentity(ctx, &u) + + // AuthenticateBrowserSession should fail to refresh the users session and delete + // the current session, giving us the same cookie back with a max-age of -1. + _, err = authSvc.AuthenticateBrowserSession(ctx, rec, req) + c.Assert(err, qt.ErrorMatches, ".*failed to refresh token.*") + + // Assert that the header to delete the session is set correctly based + // on a failed access token refresh due to refresh token issues. + setCookieCookies := rec.Header().Get("Set-Cookie") + c.Assert( + setCookieCookies, + qt.Equals, + "jimm-browser-session=; Path=/; Expires=Thu, 01 Jan 1970 00:00:01 GMT; Max-Age=0", + ) +} diff --git a/internal/cmdtest/jimmsuite.go b/internal/cmdtest/jimmsuite.go index 50de3118f..c35b1eb5b 100644 --- a/internal/cmdtest/jimmsuite.go +++ b/internal/cmdtest/jimmsuite.go @@ -83,10 +83,11 @@ func (s *JimmCmdSuite) SetUpTest(c *gc.C) { JWTExpiryDuration: time.Minute, InsecureSecretStorage: true, OAuthAuthenticatorParams: service.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", } diff --git a/internal/dbmodel/identity.go b/internal/dbmodel/identity.go index 75607fc65..1888af930 100644 --- a/internal/dbmodel/identity.go +++ b/internal/dbmodel/identity.go @@ -4,6 +4,7 @@ package dbmodel import ( "database/sql" + "time" jujuparams "github.com/juju/juju/rpc/params" "github.com/juju/names/v5" @@ -44,6 +45,12 @@ type Identity struct { // from the browser or device flow, and as such is updated on every successful // login. RefreshToken string + + // AccessTokenExpiry is the expiration date for this access token. + AccessTokenExpiry time.Time + + // AccessTokenType is the type for the token, typically bearer. + AccessTokenType string } // Tag returns a names.Tag for the identity. diff --git a/internal/dbmodel/sql/postgres/1_6.sql b/internal/dbmodel/sql/postgres/1_6.sql index d5ba10f6d..5f380c483 100644 --- a/internal/dbmodel/sql/postgres/1_6.sql +++ b/internal/dbmodel/sql/postgres/1_6.sql @@ -2,6 +2,8 @@ -- and is a migration that renames `user` to `identity`. ALTER TABLE users ADD COLUMN access_token TEXT; ALTER TABLE users ADD COLUMN refresh_token TEXT; +ALTER TABLE users ADD COLUMN access_token_expiry TIMESTAMP; +ALTER TABLE users ADD COLUMN access_token_type TEXT; -- Note that we don't need to rename underlying indexes/constraints. As Postgres -- docs states: diff --git a/internal/jimm/jimm.go b/internal/jimm/jimm.go index 388690fd1..fbf22e097 100644 --- a/internal/jimm/jimm.go +++ b/internal/jimm/jimm.go @@ -7,10 +7,10 @@ package jimm import ( "context" "database/sql" + "net/http" "strings" "time" - "github.com/antonlindstrom/pgstore" "github.com/coreos/go-oidc/v3/oidc" "github.com/go-macaroon-bakery/macaroon-bakery/v3/bakery" "github.com/juju/juju/api/base" @@ -85,9 +85,6 @@ type JIMM struct { // OAuthAuthenticator is responsible for handling authentication // via OAuth2.0 AND JWT access tokens to JIMM. OAuthAuthenticator OAuthAuthenticator - - // CookieSessionStore is respnsible for handling cookie based sessions. - CookieSessionStore *pgstore.PGStore } // OAuthAuthenticationService returns the JIMM's authentication service. @@ -164,6 +161,11 @@ type OAuthAuthenticator interface { // VerifyClientCredentials verifies the provided client ID and client secret. VerifyClientCredentials(ctx context.Context, clientID string, clientSecret string) error + + // AuthenticateBrowserSession updates the session for a browser, additionally + // retrieving new access tokens upon expiry. If this cannot be done, the cookie + // is deleted and an error is returned. + AuthenticateBrowserSession(ctx context.Context, w http.ResponseWriter, req *http.Request) (context.Context, error) } // GetCredentialStore returns the credential store used by JIMM. diff --git a/internal/jimm/user_test.go b/internal/jimm/user_test.go index 1e55c516d..bae5f7b52 100644 --- a/internal/jimm/user_test.go +++ b/internal/jimm/user_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/antonlindstrom/pgstore" qt "github.com/frankban/quicktest" "github.com/juju/names/v5" @@ -29,13 +30,19 @@ func TestGetOpenFGAUser(t *testing.T) { db := &db.Database{ DB: jimmtest.PostgresDB(c, func() time.Time { return time.Now() }), } - // TODO(ale8k): Mock this + sqldb, err := db.DB.DB() + c.Assert(err, qt.IsNil) + + sessionStore, err := pgstore.NewPGStoreFromPool(sqldb, []byte("secretsecretdigletts")) + c.Assert(err, qt.IsNil) authSvc, err := auth.NewAuthenticationService(ctx, auth.AuthenticationServiceParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{"openid", "profile", "email"}, - SessionTokenExpiry: time.Hour, - Store: db, + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{"openid", "profile", "email"}, + SessionTokenExpiry: time.Hour, + Store: db, + SessionStore: sessionStore, + SessionCookieMaxAge: 60, }) c.Assert(err, qt.IsNil) diff --git a/internal/jimmhttp/auth_handler.go b/internal/jimmhttp/auth_handler.go index ac33cc379..43eb15156 100644 --- a/internal/jimmhttp/auth_handler.go +++ b/internal/jimmhttp/auth_handler.go @@ -4,7 +4,6 @@ import ( "context" "net/http" - "github.com/antonlindstrom/pgstore" "github.com/coreos/go-oidc/v3/oidc" "github.com/go-chi/chi/v5" "github.com/juju/zaputil/zapctx" @@ -20,9 +19,7 @@ type OAuthHandler struct { Router *chi.Mux authenticator BrowserOAuthAuthenticator dashboardFinalRedirectURL string - sessionStore *pgstore.PGStore secureCookies bool - cookieExpiry int } // OAuthHandlerParams holds the parameters to configure the OAuthHandler. @@ -34,15 +31,9 @@ type OAuthHandlerParams struct { // upon completing the authorisation code flow. DashboardFinalRedirectURL string - // SessionStore is the cookie session store. - SessionStore *pgstore.PGStore - // SessionCookies determines if HTTPS must be enabled in order for JIMM // to set cookies when creating browser based sessions. SecureCookies bool - - // CookieExpiry is how long the cookie will be valid before expiring in seconds. - CookieExpiry int } // BrowserOAuthAuthenticator handles authorisation code authentication within JIMM @@ -53,6 +44,13 @@ type BrowserOAuthAuthenticator interface { ExtractAndVerifyIDToken(ctx context.Context, oauth2Token *oauth2.Token) (*oidc.IDToken, error) Email(idToken *oidc.IDToken) (string, error) UpdateIdentity(ctx context.Context, email string, token *oauth2.Token) error + CreateBrowserSession( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + secureCookies bool, + email string, + ) error } // NewOAuthHandler returns a new OAuth handler. @@ -63,16 +61,11 @@ func NewOAuthHandler(p OAuthHandlerParams) (*OAuthHandler, error) { if p.DashboardFinalRedirectURL == "" { return nil, errors.E("final redirect url not specified") } - if p.SessionStore == nil { - return nil, errors.E("nil session store") - } return &OAuthHandler{ Router: chi.NewRouter(), authenticator: p.Authenticator, dashboardFinalRedirectURL: p.DashboardFinalRedirectURL, - sessionStore: p.SessionStore, secureCookies: p.SecureCookies, - cookieExpiry: p.CookieExpiry, }, nil } @@ -129,22 +122,16 @@ func (oah *OAuthHandler) Callback(w http.ResponseWriter, r *http.Request) { return } - // If the session is empty, it'll just be an empty session, we only check - // errors for bad decoding etc. - session, err := oah.sessionStore.Get(r, "jimm-browser-session") - if err != nil { - writeError(ctx, w, http.StatusBadRequest, err, "failed to get session") + if err := oah.authenticator.CreateBrowserSession( + ctx, + w, + r, + oah.secureCookies, + email, + ); err != nil { + writeError(ctx, w, http.StatusBadRequest, err, "failed to setup session") } - session.IsNew = true // Sets cookie to a fresh new cookie - session.Options.MaxAge = oah.cookieExpiry // Expiry in seconds - session.Options.Secure = oah.secureCookies // Ensures only sent with HTTPS - session.Options.HttpOnly = false // Allow Javascript to read it - - session.Values["jimm-session"] = email - if err = session.Save(r, w); err != nil { - writeError(ctx, w, http.StatusBadRequest, err, "failed to save session") - } http.Redirect(w, r, oah.dashboardFinalRedirectURL, http.StatusPermanentRedirect) } diff --git a/internal/jimmhttp/auth_handler_test.go b/internal/jimmhttp/auth_handler_test.go index 81cdc95ef..190da8701 100644 --- a/internal/jimmhttp/auth_handler_test.go +++ b/internal/jimmhttp/auth_handler_test.go @@ -2,28 +2,20 @@ package jimmhttp_test import ( "context" - "fmt" "io" - "net" "net/http" - "net/http/cookiejar" - "net/http/httptest" - "net/url" - "regexp" "testing" "time" "github.com/antonlindstrom/pgstore" - "github.com/coreos/go-oidc/v3/oidc" qt "github.com/frankban/quicktest" + "github.com/gorilla/sessions" - "github.com/canonical/jimm/internal/auth" "github.com/canonical/jimm/internal/db" - "github.com/canonical/jimm/internal/jimmhttp" "github.com/canonical/jimm/internal/jimmtest" ) -func setupDbAndSessionStore(c *qt.C) (*db.Database, *pgstore.PGStore) { +func setupDbAndSessionStore(c *qt.C) (*db.Database, sessions.Store) { // Setup db ahead of time so we have access to session store db := &db.Database{ DB: jimmtest.PostgresDB(c, func() time.Time { return time.Now() }), @@ -39,49 +31,6 @@ func setupDbAndSessionStore(c *qt.C) (*db.Database, *pgstore.PGStore) { return db, store } -func setupTestServer(c *qt.C, dashboardURL string, db *db.Database, sessionStore *pgstore.PGStore) *httptest.Server { - // Find a random free TCP port. - listener, err := net.Listen("tcp", "127.0.0.1:0") - c.Assert(err, qt.IsNil) - port := fmt.Sprintf("%d", listener.Addr().(*net.TCPAddr).Port) - - // Create unstarted server to enable auth service - s := httptest.NewUnstartedServer(nil) - s.Listener = listener - - // Remember redirect url to check it matches after test server starts - redirectURL := "http://127.0.0.1:" + port + "/callback" - authSvc, err := auth.NewAuthenticationService(context.Background(), auth.AuthenticationServiceParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - ClientSecret: "SwjDofnbDzJDm9iyfUhEp67FfUFMY8L4", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Hour, - // Now we know the port the test server is running on - RedirectURL: redirectURL, - Store: db, - }) - c.Assert(err, qt.IsNil) - - h, err := jimmhttp.NewOAuthHandler(jimmhttp.OAuthHandlerParams{ - Authenticator: authSvc, - DashboardFinalRedirectURL: dashboardURL, - SessionStore: sessionStore, - SecureCookies: false, - CookieExpiry: 86400, - }) - c.Assert(err, qt.IsNil) - - s.Config.Handler = h.Routes() - - s.Start() - - // Ensure redirectURL is matching port on listener - c.Assert(s.URL+"/callback", qt.Equals, redirectURL) - - return s -} - // TestBrowserAuth goes through the flow of a browser logging in, simulating // the cookie state and handling the callbacks are as expected. Additionally handling // the final callback to the dashboard emulating an endpoint. See setupTestServer @@ -91,72 +40,17 @@ func TestBrowserAuth(t *testing.T) { c := qt.New(t) db, sessionStore := setupDbAndSessionStore(c) - - // Setup final test redirect url server, to emulate - // the dashboard receiving the final piece of the flow - dashboardResponse := "dashboard received final callback" - dashboard := httptest.NewServer( - http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, dashboardResponse) - sessionCookie, _ := r.Cookie("jimm-browser-session") - c.Assert(sessionCookie.Name, qt.Equals, "jimm-browser-session") - c.Assert(sessionCookie.Value, qt.Not(qt.Equals), "") - // Check the session exist in db - session, err := sessionStore.Get(r, "jimm-browser-session") - c.Assert(err, qt.IsNil) - c.Assert(session.Values["jimm-session"], qt.Equals, "jimm-test@canonical.com") - }, - ), - ) - defer dashboard.Close() - - s := setupTestServer(c, dashboard.URL, db, sessionStore) - defer s.Close() - - jar, err := cookiejar.New(nil) - c.Assert(err, qt.IsNil) - - client := &http.Client{ - Jar: jar, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - fmt.Println("redirected to", req.URL) - return nil - }, - } - - res, err := client.Get(s.URL + "/login") - c.Assert(err, qt.IsNil) - c.Assert(res.StatusCode, qt.Equals, http.StatusOK) - - defer res.Body.Close() - b, err := io.ReadAll(res.Body) + cookie, err := jimmtest.RunBrowserLogin(db, sessionStore) c.Assert(err, qt.IsNil) - - re := regexp.MustCompile(`action="(.*?)" method=`) - match := re.FindStringSubmatch(string(b)) - loginFormUrl := match[1] - - v := url.Values{} - v.Add("username", "jimm-test") - v.Add("password", "password") - loginResp, err := client.PostForm(loginFormUrl, v) - c.Assert(err, qt.IsNil) - - b, err = io.ReadAll(loginResp.Body) - c.Assert(err, qt.IsNil) - - c.Assert(string(b), qt.Equals, dashboardResponse) - c.Assert(loginResp.StatusCode, qt.Equals, 200) - - defer loginResp.Body.Close() + c.Assert(cookie, qt.Not(qt.Equals), "") } func TestCallbackFailsNoCodePresent(t *testing.T) { c := qt.New(t) db, sessionStore := setupDbAndSessionStore(c) - s := setupTestServer(c, "", db, sessionStore) + s, err := jimmtest.SetupTestDashboardCallbackHandler("", db, sessionStore) + c.Assert(err, qt.IsNil) defer s.Close() // Test with no code present at all @@ -174,7 +68,8 @@ func TestCallbackFailsExchange(t *testing.T) { c := qt.New(t) db, sessionStore := setupDbAndSessionStore(c) - s := setupTestServer(c, "", db, sessionStore) + s, err := jimmtest.SetupTestDashboardCallbackHandler("", db, sessionStore) + c.Assert(err, qt.IsNil) defer s.Close() // Test with no code present at all diff --git a/internal/jimmhttp/websocket.go b/internal/jimmhttp/websocket.go index 178e22494..821056130 100644 --- a/internal/jimmhttp/websocket.go +++ b/internal/jimmhttp/websocket.go @@ -12,6 +12,8 @@ import ( "github.com/juju/zaputil/zapctx" "go.uber.org/zap" + "github.com/canonical/jimm/internal/auth" + "github.com/canonical/jimm/internal/jimm" "github.com/canonical/jimm/internal/servermon" ) @@ -33,6 +35,17 @@ type WSHandler struct { // been started. func (h *WSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { ctx := req.Context() + var authErr error + + if h.Server != nil && h.Server.GetAuthenticationService() != nil { + ctx, authErr = handleBrowserAuthentication( + ctx, + h.Server.GetAuthenticationService(), + w, + req, + ) + } + ctx = context.WithValue(ctx, contextPathKey("path"), req.URL.EscapedPath()) conn, err := h.Upgrader.Upgrade(w, req, nil) if err != nil { @@ -41,28 +54,85 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { zapctx.Error(ctx, "cannot upgrade websocket", zap.Error(err)) return } + servermon.ConcurrentWebsocketConnections.Inc() defer conn.Close() defer servermon.ConcurrentWebsocketConnections.Dec() defer func() { if err := recover(); err != nil { zapctx.Error(ctx, "websocket panic", zap.Any("err", err), zap.Stack("stack")) - data := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, fmt.Sprintf("%v", err)) - if err := conn.WriteControl(websocket.CloseMessage, data, time.Time{}); err != nil { - zapctx.Error(ctx, "cannot write close message", zap.Error(err)) - } + writeInternalServerErrorClosure(ctx, conn, err) } }() + + if authErr != nil { + zapctx.Error(ctx, "browser authentication error", zap.Any("err", authErr), zap.Stack("stack")) + writeInternalServerErrorClosure(ctx, conn, authErr) + return + } + if h.Server == nil { - data := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") - if err := conn.WriteControl(websocket.CloseMessage, data, time.Time{}); err != nil { - zapctx.Error(ctx, "cannot write close message", zap.Error(err)) - } + writeNormalClosure(ctx, conn) return } + h.Server.ServeWS(ctx, conn) } +func writeNormalClosure(ctx context.Context, conn *websocket.Conn) { + data := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + if err := conn.WriteControl(websocket.CloseMessage, data, time.Time{}); err != nil { + zapctx.Error(ctx, "cannot write close message", zap.Error(err)) + } +} + +func writeInternalServerErrorClosure(ctx context.Context, conn *websocket.Conn, err any) { + data := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, fmt.Sprintf("%v", err)) + if err := conn.WriteControl(websocket.CloseMessage, data, time.Time{}); err != nil { + zapctx.Error(ctx, "cannot write close message", zap.Error(err)) + } +} + +// handleBrowserAuthentication handles browser authentication when a session cookie +// is present, ultimately placing the identity resolved from the cookie within the +// passed context. +// +// It updates the response header on authentication errors with a InternalServerError, +// and as such is safe to return from your handler upon error without updating +// the response statuses. +func handleBrowserAuthentication(ctx context.Context, authSvc jimm.OAuthAuthenticator, w http.ResponseWriter, req *http.Request) (context.Context, error) { + // We perform cookie authentication at the HTTP layer instead of WS + // due to limitations of setting and retrieving cookies in the WS layer. + // + // If no cookie is present, we expect 1 of 3 scenarios: + // 1. It's a device session token login. + // 2. It's a client credential login. + // 3. It's an "expired" cookie login, and as such no cookie + // has been sent with the request. The handling of this is within + // LoginWithSessionCookie, in which, due to no identityId being present + // we know the cookie expired or a request with no cookie was made. + _, err := req.Cookie(auth.SessionName) + + // Now we know a cookie is present, so let's try perform a cookie login / logic + // as presumably a cookie of this name should only ever be present in the case + // the browser performs a connection. + if err == nil { + ctx, err = authSvc.AuthenticateBrowserSession( + ctx, w, req, + ) + if err != nil { + zapctx.Error(ctx, "authenticate browser session failed", zap.Error(err)) + // Something went wrong when trying to perform the authentication + // of the cookie. + return ctx, err + } + } + + // If there's an error due to failure to find the cookie, just return the context + // and move on presuming it's a device or client credentials login. + return ctx, nil +} + // A WSServer is a websocket server. // // ServeWS should handle all messaging on the websocket connection and @@ -70,4 +140,7 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // the websocket connection, but not send any control messages. type WSServer interface { ServeWS(context.Context, *websocket.Conn) + + // GetAuthenticationService returns JIMM's authentication services. + GetAuthenticationService() jimm.OAuthAuthenticator } diff --git a/internal/jimmhttp/websocket_test.go b/internal/jimmhttp/websocket_test.go index 8542d2483..8079a0b5b 100644 --- a/internal/jimmhttp/websocket_test.go +++ b/internal/jimmhttp/websocket_test.go @@ -4,6 +4,7 @@ package jimmhttp_test import ( "context" + "net/http" "net/http/httptest" "strings" "testing" @@ -12,7 +13,10 @@ import ( qt "github.com/frankban/quicktest" "github.com/gorilla/websocket" + "github.com/canonical/jimm/internal/auth" + "github.com/canonical/jimm/internal/jimm" "github.com/canonical/jimm/internal/jimmhttp" + "github.com/canonical/jimm/internal/jimmtest" ) func TestWSHandler(t *testing.T) { @@ -57,6 +61,11 @@ func (s echoServer) ServeWS(ctx context.Context, conn *websocket.Conn) { } } +// GetAuthenticationService returns JIMM's oauth authentication service. +func (s echoServer) GetAuthenticationService() jimm.OAuthAuthenticator { + return nil +} + func TestWSHandlerPanic(t *testing.T) { c := qt.New(t) @@ -77,6 +86,11 @@ func TestWSHandlerPanic(t *testing.T) { type panicServer struct{} +// GetAuthenticationService returns JIMM's oauth authentication service. +func (s panicServer) GetAuthenticationService() jimm.OAuthAuthenticator { + return nil +} + func (s panicServer) ServeWS(ctx context.Context, conn *websocket.Conn) { panic("test") } @@ -96,3 +110,32 @@ func TestWSHandlerNilServer(t *testing.T) { _, _, err = conn.ReadMessage() c.Assert(err, qt.ErrorMatches, `websocket: close 1000 \(normal\)`) } + +type authFailServer struct{} + +// GetAuthenticationService returns JIMM's oauth authentication service. +func (s authFailServer) GetAuthenticationService() jimm.OAuthAuthenticator { + return jimmtest.NewMockOAuthAuthenticator("") +} + +func (s authFailServer) ServeWS(ctx context.Context, conn *websocket.Conn) {} + +func TestWSHandlerAuthFailsServer(t *testing.T) { + c := qt.New(t) + + hnd := &jimmhttp.WSHandler{ + Server: authFailServer{}, + } + + srv := httptest.NewServer(hnd) + c.Cleanup(srv.Close) + + var d websocket.Dialer + conn, _, err := d.Dial("ws"+strings.TrimPrefix(srv.URL, "http"), http.Header{ + "Cookie": []string{auth.SessionName + "=naughty_cookie"}, + }) + c.Assert(err, qt.IsNil) + + _, _, err = conn.ReadMessage() + c.Assert(err, qt.ErrorMatches, `websocket: close 1011 \(internal server error\): authentication failed`) +} diff --git a/internal/jimmjwx/utils_test.go b/internal/jimmjwx/utils_test.go index 642a776b7..3e31b7537 100644 --- a/internal/jimmjwx/utils_test.go +++ b/internal/jimmjwx/utils_test.go @@ -109,10 +109,11 @@ func setupService(ctx context.Context, c *qt.C) (*jimm.Service, *httptest.Server AuthModel: cofgaParams.AuthModelID, }, OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", }) diff --git a/internal/jimmtest/auth.go b/internal/jimmtest/auth.go index 19212b0f1..b46e1ee08 100644 --- a/internal/jimmtest/auth.go +++ b/internal/jimmtest/auth.go @@ -5,16 +5,30 @@ package jimmtest import ( "context" "encoding/base64" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/url" + "regexp" + "strconv" "strings" "time" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/gorilla/sessions" "github.com/juju/juju/api" jujuparams "github.com/juju/juju/rpc/params" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/canonical/jimm/internal/auth" + "github.com/canonical/jimm/internal/db" "github.com/canonical/jimm/internal/jimm" + "github.com/canonical/jimm/internal/jimmhttp" "github.com/canonical/jimm/internal/openfga" ) @@ -57,6 +71,10 @@ func (m MockOAuthAuthenticator) VerifySessionToken(token string, secretKey strin return auth.VerifySessionToken(token, m.secretKey) } +func (m MockOAuthAuthenticator) AuthenticateBrowserSession(ctx context.Context, w http.ResponseWriter, req *http.Request) (context.Context, error) { + return ctx, errors.New("authentication failed") +} + // NewUserSessionLogin returns a login provider than be used with Juju Dial Opts // to define how login will take place. In this case we login using a session token // that the JIMM server should verify with the same test secret. @@ -86,3 +104,139 @@ func convertUsernameToEmail(username string) string { } return username } + +func SetupTestDashboardCallbackHandler(browserURL string, db *db.Database, sessionStore sessions.Store) (*httptest.Server, error) { + // Find a random free TCP port. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + port := strconv.Itoa(listener.Addr().(*net.TCPAddr).Port) + + // Create unstarted server to enable auth service + s := httptest.NewUnstartedServer(nil) + s.Listener = listener + + // Remember redirect url to check it matches after test server starts + redirectURL := "http://127.0.0.1:" + port + "/callback" + authSvc, err := auth.NewAuthenticationService(context.Background(), auth.AuthenticationServiceParams{ + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + ClientSecret: "SwjDofnbDzJDm9iyfUhEp67FfUFMY8L4", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Hour, + // Now we know the port the test server is running on + RedirectURL: redirectURL, + Store: db, + SessionStore: sessionStore, + SessionCookieMaxAge: 60, + }) + if err != nil { + return nil, err + } + + h, err := jimmhttp.NewOAuthHandler(jimmhttp.OAuthHandlerParams{ + Authenticator: authSvc, + DashboardFinalRedirectURL: browserURL, + SecureCookies: false, + }) + if err != nil { + return nil, err + } + + s.Config.Handler = h.Routes() + + s.Start() + + // Ensure redirectURL is matching port on listener + if s.URL+"/callback" != redirectURL { + return s, errors.New("server callback does not match redirectURL") + } + + return s, nil +} + +func RunBrowserLogin(db *db.Database, sessionStore sessions.Store) (string, error) { + var cookieString string + + // Setup final test redirect url server, to emulate + // the dashboard receiving the final piece of the flow + dashboardResponse := "dashboard received final callback" + browser := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cookieString = r.Header.Get("Cookie") + w.Write([]byte(dashboardResponse)) + }, + ), + ) + defer browser.Close() + + s, err := SetupTestDashboardCallbackHandler(browser.URL, db, sessionStore) + if err != nil { + return cookieString, err + } + defer s.Close() + + jar, err := cookiejar.New(nil) + if err != nil { + return cookieString, err + } + + client := &http.Client{ + Jar: jar, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + fmt.Println("redirected to", req.URL) + return nil + }, + } + + res, err := client.Get(s.URL + "/login") + if err != nil { + return cookieString, err + } + + if res.StatusCode != http.StatusOK { + return cookieString, errors.New("status code not ok") + } + + defer res.Body.Close() + b, err := io.ReadAll(res.Body) + if err != nil { + return cookieString, err + } + + re := regexp.MustCompile(`action="(.*?)" method=`) + match := re.FindStringSubmatch(string(b)) + loginFormUrl := match[1] + + v := url.Values{} + v.Add("username", "jimm-test") + v.Add("password", "password") + loginResp, err := client.PostForm(loginFormUrl, v) + if err != nil { + return cookieString, err + } + + b, err = io.ReadAll(loginResp.Body) + if err != nil { + return cookieString, err + } + + if string(b) != dashboardResponse { + return cookieString, errors.New("dashboard response not equal") + } + if loginResp.StatusCode != http.StatusOK { + return cookieString, errors.New("status code not ok") + } + + loginResp.Body.Close() + return cookieString, nil +} + +func ParseCookies(cookies string) []*http.Cookie { + header := http.Header{} + header.Add("Cookie", cookies) + request := http.Request{Header: header} + return request.Cookies() +} diff --git a/internal/jimmtest/suite.go b/internal/jimmtest/suite.go index ff8a465e1..49d872095 100644 --- a/internal/jimmtest/suite.go +++ b/internal/jimmtest/suite.go @@ -69,10 +69,12 @@ func (s *JIMMSuite) SetUpTest(c *gc.C) { s.OFGAClient, s.COFGAClient, s.COFGAParams, err = SetupTestOFGAClient(c.TestName()) c.Assert(err, gc.IsNil) + pgdb := PostgresDB(GocheckTester{c}, nil) + // Setup OpenFGA. s.JIMM = &jimm.JIMM{ Database: db.Database{ - DB: PostgresDB(GocheckTester{c}, nil), + DB: pgdb, }, CredentialStore: NewInMemoryCredentialStore(), Pubsub: &pubsub.Hub{MaxConcurrency: 10}, diff --git a/internal/jujuapi/admin.go b/internal/jujuapi/admin.go index e72965137..81a143f4e 100644 --- a/internal/jujuapi/admin.go +++ b/internal/jujuapi/admin.go @@ -71,6 +71,45 @@ func (r *controllerRoot) GetDeviceSessionToken(ctx context.Context) (params.GetD return response, nil } +// LoginWithSessionCookie is a facade call which has the cookie intercepted at the http layer, +// in which it is then placed on the controller root under "identityId", this identityId is used +// to perform a user lookup and authorise the login call. +// +// It may be misleading in that it does not interact with cookies at all, but this will only ever +// be successful upon the http layer login being successful. +func (r *controllerRoot) LoginWithSessionCookie(ctx context.Context) (jujuparams.LoginResult, error) { + const op = errors.Op("jujuapi.LoginWithSessionCookie") + + // If no identity ID has come through, then no cookie was present + // and as such authentication has failed. + if r.identityId == "" { + return jujuparams.LoginResult{}, errors.E(op, &auth.AuthenticationError{}) + } + + user, err := r.jimm.GetOpenFGAUserAndAuthorise(ctx, r.identityId) + if err != nil { + return jujuparams.LoginResult{}, errors.E(op, err) + } + + r.mu.Lock() + r.user = user + r.mu.Unlock() + + // Get server version for LoginResult + srvVersion, err := r.jimm.EarliestControllerVersion(ctx) + if err != nil { + return jujuparams.LoginResult{}, errors.E(op, err) + } + + return jujuparams.LoginResult{ + PublicDNSName: r.params.PublicDNSName, + UserInfo: setupAuthUserInfo(ctx, r, user), + ControllerTag: setupControllerTag(r), + Facades: setupFacades(r), + ServerVersion: srvVersion.String(), + }, nil +} + // LoginWithSessionToken handles logging into the JIMM via a session token that JIMM has // minted itself, this session token is simply a JWT containing the users email // at which point the email is used to perform a lookup for the user, authorise diff --git a/internal/jujuapi/admin_test.go b/internal/jujuapi/admin_test.go index 6a704f4ce..239c32f75 100644 --- a/internal/jujuapi/admin_test.go +++ b/internal/jujuapi/admin_test.go @@ -4,9 +4,11 @@ package jujuapi_test import ( "context" + "crypto/tls" "encoding/base64" "fmt" "io" + "net" "net/http" "net/http/cookiejar" "net/url" @@ -14,14 +16,19 @@ import ( "strings" "time" + "github.com/antonlindstrom/pgstore" "github.com/canonical/jimm/api/params" "github.com/canonical/jimm/internal/auth" "github.com/canonical/jimm/internal/dbmodel" "github.com/canonical/jimm/internal/jimmtest" + "github.com/gorilla/websocket" "github.com/coreos/go-oidc/v3/oidc" + "github.com/juju/errors" "github.com/juju/juju/api" + "github.com/juju/juju/rpc/jsoncodec" jujuparams "github.com/juju/juju/rpc/params" + "github.com/juju/juju/utils/proxy" "github.com/juju/names/v4" gc "gopkg.in/check.v1" ) @@ -34,15 +41,23 @@ func (s *adminSuite) SetUpTest(c *gc.C) { s.websocketSuite.SetUpTest(c) ctx := context.Background() + sqldb, err := s.JIMM.Database.DB.DB() + c.Assert(err, gc.IsNil) + + sessionStore, err := pgstore.NewPGStoreFromPool(sqldb, []byte("secretsecretdigletts")) + c.Assert(err, gc.IsNil) + // Replace JIMM's mock authenticator with a real one here // for testing the login flows. authSvc, err := auth.NewAuthenticationService(ctx, auth.AuthenticationServiceParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - ClientSecret: "SwjDofnbDzJDm9iyfUhEp67FfUFMY8L4", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Hour, - Store: &s.JIMM.Database, + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + ClientSecret: "SwjDofnbDzJDm9iyfUhEp67FfUFMY8L4", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Hour, + Store: &s.JIMM.Database, + SessionStore: sessionStore, + SessionCookieMaxAge: 60, }) c.Assert(err, gc.Equals, nil) s.JIMM.OAuthAuthenticator = authSvc @@ -62,6 +77,80 @@ func (s *adminSuite) TestLoginToController(c *gc.C) { c.Assert(jujuparams.ErrCode(err), gc.Equals, jujuparams.CodeNotImplemented) } +// TestBrowserLogin takes a test user through the flow of logging into jimm +// via the correct facades. All are done in a single test to see the flow end-2-end. +// +// Within the test are clear comments explaining what is happening when and why. +// Please refer to these comments for further details. +// +// We only test happy path here due to having tested edge cases and failure cases +// within the auth service itself such as invalid cookies, expired access tokens and +// missing/expired/revoked refresh tokens. + +func (s *adminSuite) TestBrowserLogin(c *gc.C) { + // The setup runs a browser login with callback, ultimately retrieving + // a logged in user by cookie. + sqldb, err := s.JIMM.DB().DB.DB() + c.Assert(err, gc.IsNil) + + sessionStore, err := pgstore.NewPGStoreFromPool(sqldb, []byte("secretsecretdigletts")) + c.Assert(err, gc.IsNil) + + cookie, err := jimmtest.RunBrowserLogin(s.JIMM.DB(), sessionStore) + c.Assert(err, gc.IsNil) + c.Assert(cookie, gc.Not(gc.Equals), "") + + cookies := jimmtest.ParseCookies(cookie) + c.Assert(cookies, gc.HasLen, 1) + + jar, err := cookiejar.New(nil) + c.Assert(err, gc.IsNil) + + // Now we move this cookie to the JIMM server on the admin suite and + // set the cookie on the jimm test server url so that the cookie can be + // sent on WS calls. + jimmURL, err := url.Parse(s.Server.URL) + c.Assert(err, gc.IsNil) + jar.SetCookies(jimmURL, cookies) + + conn := s.openWithDialWebsocket( + c, + &api.Info{ + SkipLogin: true, + }, + "test", + getDialWebsocketWithCustomCookieJar(jar), + ) + defer conn.Close() + + lr := &jujuparams.LoginResult{} + err = conn.APICall("Admin", 4, "", "LoginWithSessionCookie", nil, lr) + c.Assert(err, gc.IsNil) + + c.Assert(lr.UserInfo.Identity, gc.Equals, "user-jimm-test@canonical.com") + c.Assert(lr.UserInfo.DisplayName, gc.Equals, "jimm-test") +} + +// TestBrowserLoginNoCookie attempts to login without a cookie. +func (s *adminSuite) TestBrowserLoginNoCookie(c *gc.C) { + conn := s.open( + c, + &api.Info{ + SkipLogin: true, + }, + "test", + ) + defer conn.Close() + + lr := &jujuparams.LoginResult{} + err := conn.APICall("Admin", 4, "", "LoginWithSessionCookie", nil, lr) + c.Assert( + err, + gc.ErrorMatches, + "authentication failed", + ) +} + // TestDeviceLogin takes a test user through the flow of logging into jimm // via the correct facades. All are done in a single test to see the flow end-2-end. // @@ -235,3 +324,51 @@ func (s *adminSuite) TestLoginWithClientCredentials(c *gc.C) { }, &loginResult) c.Assert(err, gc.ErrorMatches, `invalid client credentials \(unauthorized access\)`) } + +// getDialWebsocketWithCustomCookieJar is mostly the default dialer configuration exception +// we need a dial websocket for juju containing a custom cookie jar to send cookies to +// a new server url when testing LoginWithSessionCookie. As such this closure simply +// passes the jar through. +func getDialWebsocketWithCustomCookieJar(jar *cookiejar.Jar) func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { + // Copied from github.com/juju/juju@v0.0.0-20240304110523-55fb5d03683b/api/apiclient.go + dialWebsocket := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { + url, err := url.Parse(urlStr) + if err != nil { + return nil, errors.Trace(err) + } + + netDialer := net.Dialer{} + dialer := &websocket.Dialer{ + NetDial: func(netw, addr string) (net.Conn, error) { + if addr == url.Host { + addr = ipAddr + } + return netDialer.DialContext(ctx, netw, addr) + }, + Proxy: proxy.DefaultConfig.GetProxy, + HandshakeTimeout: 45 * time.Second, + TLSClientConfig: tlsConfig, + // We update the jar so that the cookies retrieved from RunBrowserLogin + // can be sent in the LoginWithSessionCookie call. + Jar: jar, + } + + c, resp, err := dialer.Dial(urlStr, nil) + if err != nil { + if err == websocket.ErrBadHandshake { + defer resp.Body.Close() + body, readErr := io.ReadAll(resp.Body) + if readErr == nil { + err = errors.Errorf( + "%s (%s)", + strings.TrimSpace(string(body)), + http.StatusText(resp.StatusCode), + ) + } + } + return nil, errors.Trace(err) + } + return jsoncodec.NewWebsocketConn(c), nil + } + return dialWebsocket +} diff --git a/internal/jujuapi/controllerroot.go b/internal/jujuapi/controllerroot.go index 9bcd708b9..6ba3a55be 100644 --- a/internal/jujuapi/controllerroot.go +++ b/internal/jujuapi/controllerroot.go @@ -131,9 +131,12 @@ type controllerRoot struct { // is created per WS, it is EXPECTED that the subsequent call to GetDeviceSessionToken // happens on the SAME websocket. deviceOAuthResponse *oauth2.DeviceAuthResponse + + // identityId is the id of the identity attempting to login via a session cookie. + identityId string } -func newControllerRoot(j JIMM, p Params) *controllerRoot { +func newControllerRoot(j JIMM, p Params, identityId string) *controllerRoot { watcherRegistry := &watcherRegistry{ watchers: make(map[string]*modelSummaryWatcher), } @@ -143,6 +146,7 @@ func newControllerRoot(j JIMM, p Params) *controllerRoot { watchers: watcherRegistry, pingF: func() {}, controllerUUIDMasking: true, + identityId: identityId, } r.AddMethod("Admin", 1, "Login", rpc.Method(unsupportedLogin)) @@ -152,6 +156,7 @@ func newControllerRoot(j JIMM, p Params) *controllerRoot { r.AddMethod("Admin", 4, "LoginDevice", rpc.Method(r.LoginDevice)) r.AddMethod("Admin", 4, "GetDeviceSessionToken", rpc.Method(r.GetDeviceSessionToken)) r.AddMethod("Admin", 4, "LoginWithSessionToken", rpc.Method(r.LoginWithSessionToken)) + r.AddMethod("Admin", 4, "LoginWithSessionCookie", rpc.Method(r.LoginWithSessionCookie)) r.AddMethod("Admin", 4, "LoginWithClientCredentials", rpc.Method(r.LoginWithClientCredentials)) r.AddMethod("Pinger", 1, "Ping", rpc.Method(r.Ping)) return r diff --git a/internal/jujuapi/export_test.go b/internal/jujuapi/export_test.go index e5ac3a80c..754435102 100644 --- a/internal/jujuapi/export_test.go +++ b/internal/jujuapi/export_test.go @@ -46,7 +46,7 @@ func ToJAASTag(db db.Database, tag *ofganames.Tag) (string, error) { } func NewControllerRoot(j JIMM, p Params) *controllerRoot { - return newControllerRoot(j, p) + return newControllerRoot(j, p, "") } func (r *controllerRoot) GetServiceAccount(ctx context.Context, clientID string) (*openfga.User, error) { diff --git a/internal/jujuapi/pinger_internal_test.go b/internal/jujuapi/pinger_internal_test.go index 5495156dd..8fbe84ece 100644 --- a/internal/jujuapi/pinger_internal_test.go +++ b/internal/jujuapi/pinger_internal_test.go @@ -14,7 +14,7 @@ import ( func TestControllerPing(t *testing.T) { c := qt.New(t) - r := newControllerRoot(nil, Params{}) + r := newControllerRoot(nil, Params{}, "") defer r.cleanup() var calls uint32 r.setPingF(func() { atomic.AddUint32(&calls, 1) }) diff --git a/internal/jujuapi/websocket.go b/internal/jujuapi/websocket.go index fd2fc9a84..823d7163f 100644 --- a/internal/jujuapi/websocket.go +++ b/internal/jujuapi/websocket.go @@ -17,6 +17,7 @@ import ( "github.com/juju/zaputil/zapctx" "go.uber.org/zap" + "github.com/canonical/jimm/internal/auth" "github.com/canonical/jimm/internal/dbmodel" "github.com/canonical/jimm/internal/errors" "github.com/canonical/jimm/internal/jimm" @@ -43,9 +44,15 @@ type apiServer struct { params Params } +// GetAuthenticationService returns JIMM's oauth authentication service. +func (s *apiServer) GetAuthenticationService() jimm.OAuthAuthenticator { + return s.jimm.OAuthAuthenticator +} + // ServeWS implements jimmhttp.WSServer. -func (s *apiServer) ServeWS(_ context.Context, conn *websocket.Conn) { - controllerRoot := newControllerRoot(s.jimm, s.params) +func (s *apiServer) ServeWS(ctx context.Context, conn *websocket.Conn) { + identityId := auth.SessionIdentityFromContext(ctx) + controllerRoot := newControllerRoot(s.jimm, s.params, identityId) s.cleanup = controllerRoot.cleanup Dblogger := controllerRoot.newAuditLogger() serveRoot(context.Background(), controllerRoot, Dblogger, conn) @@ -128,6 +135,11 @@ func modelInfoFromPath(path string) (uuid string, finalPath string, err error) { return matches[modelIndex], matches[finalPathIndex], nil } +// GetAuthenticationService returns JIMM's oauth authentication service. +func (s modelProxyServer) GetAuthenticationService() jimm.OAuthAuthenticator { + return s.jimm.OAuthAuthenticator +} + // ServeWS implements jimmhttp.WSServer. func (s modelProxyServer) ServeWS(ctx context.Context, clientConn *websocket.Conn) { jwtGenerator := jimm.NewJWTGenerator(&s.jimm.Database, s.jimm, s.jimm.JWTService) diff --git a/internal/jujuapi/websocket_test.go b/internal/jujuapi/websocket_test.go index e0fb1bb24..833d8fd39 100644 --- a/internal/jujuapi/websocket_test.go +++ b/internal/jujuapi/websocket_test.go @@ -5,6 +5,7 @@ package jujuapi_test import ( "bytes" "context" + "crypto/tls" "encoding/pem" "fmt" "net/http" @@ -12,6 +13,7 @@ import ( "net/url" "github.com/juju/juju/api" + "github.com/juju/juju/rpc/jsoncodec" jujuparams "github.com/juju/juju/rpc/params" "github.com/juju/names/v5" gc "gopkg.in/check.v1" @@ -99,7 +101,12 @@ func (s *websocketSuite) TearDownTest(c *gc.C) { // openNoAssert creates a new websocket connection to the test server, using the // connection info specified in info, authenticating as the given user. // If info is nil then default values will be used. -func (s *websocketSuite) openNoAssert(c *gc.C, info *api.Info, username string) (api.Connection, error) { +func (s *websocketSuite) openNoAssert( + c *gc.C, + info *api.Info, + username string, + dialWebsocket func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error), +) (api.Connection, error) { var inf api.Info if info != nil { inf = *info @@ -119,14 +126,31 @@ func (s *websocketSuite) openNoAssert(c *gc.C, info *api.Info, username string) lp := jimmtest.NewUserSessionLogin(c, username) - return api.Open(&inf, api.DialOpts{ + dialOpts := api.DialOpts{ InsecureSkipVerify: true, LoginProvider: lp, - }) + } + + if dialWebsocket != nil { + dialOpts.DialWebsocket = dialWebsocket + } + + return api.Open(&inf, dialOpts) } func (s *websocketSuite) open(c *gc.C, info *api.Info, username string) api.Connection { - conn, err := s.openNoAssert(c, info, username) + conn, err := s.openNoAssert(c, info, username, nil) + c.Assert(err, gc.Equals, nil) + return conn +} + +func (s *websocketSuite) openWithDialWebsocket( + c *gc.C, + info *api.Info, + username string, + dialWebsocket func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error), +) api.Connection { + conn, err := s.openNoAssert(c, info, username, dialWebsocket) c.Assert(err, gc.Equals, nil) return conn } diff --git a/service.go b/service.go index 72936970c..29d88115f 100644 --- a/service.go +++ b/service.go @@ -75,6 +75,8 @@ type OAuthAuthenticatorParams struct { // SessionTokenExpiry holds the expiry duration for issued JWTs // for user (CLI) to JIMM authentication. SessionTokenExpiry time.Duration + // SessionCookieMaxAge holds the max age for session cookies. + SessionCookieMaxAge int } // A Params structure contains the parameters required to initialise a new @@ -168,9 +170,6 @@ type Params struct { // SecureSessionCookies determines if HTTPS must be enabled in order for JIMM // to set cookies when creating browser based sessions. SecureSessionCookies bool - - // SessionCookieExpiry is how long the cookie will be valid before expiring in seconds. - SessionCookieExpiry int } // A Service is the implementation of a JIMM server. @@ -266,7 +265,6 @@ func NewService(ctx context.Context, p Params) (*Service, error) { // Cleanup expired session every 30 minutes defer sessionStore.StopCleanup(sessionStore.Cleanup(time.Minute * 30)) - s.jimm.CookieSessionStore = sessionStore if p.AuditLogRetentionPeriodInDays != "" { period, err := strconv.Atoi(p.AuditLogRetentionPeriodInDays) @@ -293,12 +291,14 @@ func NewService(ctx context.Context, p Params) (*Service, error) { authSvc, err := auth.NewAuthenticationService( ctx, auth.AuthenticationServiceParams{ - IssuerURL: p.OAuthAuthenticatorParams.IssuerURL, - ClientID: p.OAuthAuthenticatorParams.ClientID, - ClientSecret: p.OAuthAuthenticatorParams.ClientSecret, - Scopes: p.OAuthAuthenticatorParams.Scopes, - SessionTokenExpiry: p.OAuthAuthenticatorParams.SessionTokenExpiry, - Store: &s.jimm.Database, + IssuerURL: p.OAuthAuthenticatorParams.IssuerURL, + ClientID: p.OAuthAuthenticatorParams.ClientID, + ClientSecret: p.OAuthAuthenticatorParams.ClientSecret, + Scopes: p.OAuthAuthenticatorParams.Scopes, + SessionTokenExpiry: p.OAuthAuthenticatorParams.SessionTokenExpiry, + SessionCookieMaxAge: p.OAuthAuthenticatorParams.SessionCookieMaxAge, + Store: &s.jimm.Database, + SessionStore: sessionStore, }, ) s.jimm.OAuthAuthenticator = authSvc @@ -353,9 +353,7 @@ func NewService(ctx context.Context, p Params) (*Service, error) { oauthHandler, err := jimmhttp.NewOAuthHandler(jimmhttp.OAuthHandlerParams{ Authenticator: authSvc, DashboardFinalRedirectURL: p.DashboardFinalRedirectURL, - SessionStore: sessionStore, SecureCookies: p.SecureSessionCookies, - CookieExpiry: p.SessionCookieExpiry, }) if err != nil { return nil, errors.E(op, err, "failed to setup authentication handler") diff --git a/service_test.go b/service_test.go index 17d61bef2..b243fa0c3 100644 --- a/service_test.go +++ b/service_test.go @@ -47,10 +47,11 @@ func TestDefaultService(t *testing.T) { OpenFGAParams: cofgaParamsToJIMMOpenFGAParams(*cofgaParams), InsecureSecretStorage: true, OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", }) @@ -72,10 +73,11 @@ func TestServiceStartsWithoutSecretStore(t *testing.T) { DSN: jimmtest.CreateEmptyDatabase(c), OpenFGAParams: cofgaParamsToJIMMOpenFGAParams(*cofgaParams), OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", }) @@ -96,10 +98,11 @@ func TestAuthenticator(t *testing.T) { OpenFGAParams: cofgaParamsToJIMMOpenFGAParams(*cofgaParams), InsecureSecretStorage: true, OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", } @@ -169,10 +172,11 @@ func TestVault(t *testing.T) { VaultSecretFile: "./local/vault/approle.json", OpenFGAParams: cofgaParamsToJIMMOpenFGAParams(*cofgaParams), OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", } @@ -241,10 +245,11 @@ func TestPostgresSecretStore(t *testing.T) { OpenFGAParams: cofgaParamsToJIMMOpenFGAParams(*cofgaParams), InsecureSecretStorage: true, OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", } @@ -266,10 +271,11 @@ func TestOpenFGA(t *testing.T) { ControllerAdmins: []string{"alice", "eve"}, InsecureSecretStorage: true, OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", } @@ -326,10 +332,11 @@ func TestPublicKey(t *testing.T) { PrivateKey: "c1VkV05+iWzCxMwMVcWbr0YJWQSEO62v+z3EQ2BhFMw=", PublicKey: "pC8MEk9MS9S8fhyRnOJ4qARTcTAwoM9L1nH/Yq0MwWU=", OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", } @@ -414,10 +421,11 @@ func TestThirdPartyCaveatDischarge(t *testing.T) { PrivateKey: "c1VkV05+iWzCxMwMVcWbr0YJWQSEO62v+z3EQ2BhFMw=", PublicKey: "pC8MEk9MS9S8fhyRnOJ4qARTcTAwoM9L1nH/Yq0MwWU=", OAuthAuthenticatorParams: jimm.OAuthAuthenticatorParams{ - IssuerURL: "http://localhost:8082/realms/jimm", - ClientID: "jimm-device", - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - SessionTokenExpiry: time.Duration(time.Hour), + IssuerURL: "http://localhost:8082/realms/jimm", + ClientID: "jimm-device", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + SessionTokenExpiry: time.Duration(time.Hour), + SessionCookieMaxAge: 60, }, DashboardFinalRedirectURL: "", }