diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 44120ffbf..d9505092a 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -3,6 +3,7 @@ package middleware import ( + "context" "net/http" rebac_handlers "github.com/canonical/rebac-admin-ui-handlers/v1" @@ -10,9 +11,14 @@ import ( "go.uber.org/zap" "github.com/canonical/jimm/v3/internal/auth" + "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jujuapi" + "github.com/canonical/jimm/v3/internal/openfga" ) +// identityContextKey is the unique key to extract user from context for basic-auth authentication +type identityContextKey struct{} + // AuthenticateViaCookie performs browser session authentication and puts an identity in the request's context func AuthenticateViaCookie(next http.Handler, jimm jujuapi.JIMM) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -57,3 +63,40 @@ func AuthenticateRebac(next http.Handler, jimm jujuapi.JIMM) http.Handler { next.ServeHTTP(w, r.WithContext(ctx)) }), jimm) } + +// AuthenticateWithSessionTokenViaBasicAuth performs basic auth authentication and puts an identity in the request's context. +// The basic-auth is composed of an empty user, and as a password a jwt token that we parse and use to authenticate the user. +func AuthenticateWithSessionTokenViaBasicAuth(next http.Handler, jimm jujuapi.JIMM) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + // extract auth token + _, password, ok := r.BasicAuth() + if !ok { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("authentication missing")) + return + } + user, err := jimm.LoginWithSessionToken(ctx, password) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("error authenticating the user")) + return + } + next.ServeHTTP(w, r.WithContext(withIdentity(ctx, user))) + }) +} + +// IdentityFromContext extracts the user from the context. +func IdentityFromContext(ctx context.Context) (*openfga.User, error) { + identity := ctx.Value(identityContextKey{}) + user, ok := identity.(*openfga.User) + if !ok { + return nil, errors.E("cannot extract user from context") + } + return user, nil +} + +// withIdentity sets the user into the context and return the context +func withIdentity(ctx context.Context, user *openfga.User) context.Context { + return context.WithValue(ctx, identityContextKey{}, user) +} diff --git a/internal/middleware/auth_test.go b/internal/middleware/auth_test.go index a622f542b..a16bc6ab0 100644 --- a/internal/middleware/auth_test.go +++ b/internal/middleware/auth_test.go @@ -5,6 +5,7 @@ package middleware_test import ( "context" "errors" + "io" "net/http" "net/http/httptest" "testing" @@ -14,6 +15,7 @@ import ( "github.com/canonical/jimm/v3/internal/auth" "github.com/canonical/jimm/v3/internal/dbmodel" + jimm_errors "github.com/canonical/jimm/v3/internal/errors" "github.com/canonical/jimm/v3/internal/jimmtest" "github.com/canonical/jimm/v3/internal/jimmtest/mocks" "github.com/canonical/jimm/v3/internal/middleware" @@ -96,3 +98,71 @@ func TestAuthenticateRebac(t *testing.T) { }) } } + +func TestAuthenticateViaBasicAuth(t *testing.T) { + testUser := "test-user@canonical.com" + jt := jimmtest.JIMM{ + LoginService: mocks.LoginService{ + LoginWithSessionToken_: func(ctx context.Context, sessionToken string) (*openfga.User, error) { + if sessionToken != "good" { + return nil, jimm_errors.E(jimm_errors.CodeSessionTokenInvalid) + } + user := dbmodel.Identity{Name: testUser} + return &openfga.User{Identity: &user, JimmAdmin: true}, nil + }, + }, + } + tests := []struct { + name string + jimmAdmin bool + expectedStatus int + basicAuthPassword string + errorExpected string + }{ + { + name: "success", + jimmAdmin: true, + expectedStatus: http.StatusOK, + basicAuthPassword: "good", + }, + { + name: "failure", + expectedStatus: http.StatusUnauthorized, + basicAuthPassword: "bad", + errorExpected: "error authenticating the user", + }, + { + name: "no basic auth", + expectedStatus: http.StatusUnauthorized, + errorExpected: "authentication missing", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + if tt.basicAuthPassword != "" { + req.SetBasicAuth("", tt.basicAuthPassword) + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, err := middleware.IdentityFromContext(r.Context()) + c.Assert(err, qt.IsNil) + c.Assert(user.Name, qt.Equals, testUser) + w.WriteHeader(http.StatusOK) + }) + middleware := middleware.AuthenticateWithSessionTokenViaBasicAuth(handler, &jt) + middleware.ServeHTTP(w, req) + c.Assert(w.Code, qt.Equals, tt.expectedStatus) + b := w.Result().Body + defer b.Close() + body, err := io.ReadAll(b) + c.Assert(err, qt.IsNil) + if tt.errorExpected != "" { + c.Assert(string(body), qt.Matches, tt.errorExpected) + } + + }) + } +}