diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index bdd21e2eb..1a05ff148 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -281,10 +281,9 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st multiRouter.Routers = append(multiRouter.Routers, &scdV1Router) } - handler := logging.HTTPMiddleware(logger, *dumpRequests, - healthyEndpointMiddleware(logger, - &multiRouter, - )) + handler := logging.HTTPMiddleware(logger, *dumpRequests, healthyEndpointMiddleware(logger, &multiRouter)) + + handler = auth.DecoderMiddleware(authorizer, handler) httpServer := &http.Server{ Addr: address, diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index f49b468e3..ea6dca949 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -185,29 +185,42 @@ func (a *Authorizer) setKeys(keys []interface{}) { // Authorize extracts and verifies bearer tokens from a http.Request. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - tknStr, ok := getToken(r) - if !ok { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")} - } - - a.keyGuard.RLock() - keys := a.keys - a.keyGuard.RUnlock() + missing := false validated := false - var err error - var keyClaims claims + keyClaims := claims{} + err := error(nil) + ok := false + + // if the decoding middleware wasn't triggered, attempt decoding + if _, ok = r.Context().Value("authDecoded").(bool); !ok { + missing, validated, keyClaims, err = a.extractClaims(r) + } else { + // Use previously decoded values + missing, ok = r.Context().Value("authTokenMissing").(bool) + if !ok { + missing = true + } - for _, key := range keys { - keyClaims = claims{} - key := key - _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { - return key, nil - }) - if err == nil { - validated = true - break + validated, ok = r.Context().Value("authValidated").(bool) + if !ok { + validated = false + } + + err, ok = r.Context().Value("authError").(error) + if !ok { + err = nil + } + + keyClaims, ok = r.Context().Value("authClaims").(claims) + if !ok { + keyClaims = claims{} } } + + if missing { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")} + } + if !validated { return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed")} } @@ -300,3 +313,59 @@ func getToken(r *http.Request) (string, bool) { } return authHeader[7:], true } + +// Extract valid claims from a JWT token. +func (a *Authorizer) extractClaims(r *http.Request) (missing bool, validated bool, keyClaims claims, err error) { + tknStr, ok := getToken(r) + + if !ok { + return true, false, claims{}, nil + } + + a.keyGuard.RLock() + keys := a.keys + a.keyGuard.RUnlock() + validated = false + missing = false + + for _, key := range keys { + keyClaims = claims{} + key := key + _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { + return key, nil + }) + if err == nil { + validated = true + break + } + } + + return missing, validated, keyClaims, err + +} + +// Decoder Middleware +func DecoderMiddleware(a *Authorizer, handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + missing, validated, keyClaims, err := a.extractClaims(r) + + ctx := context.WithValue(r.Context(), "authValidated", validated) + ctx = context.WithValue(ctx, "authTokenMissing", missing) + ctx = context.WithValue(ctx, "authClaims", keyClaims) + ctx = context.WithValue(ctx, "authError", err) + + if validated && err == nil { + // If the token is valid, we can extract the subject from the token and add them to the context. + ctx = context.WithValue(ctx, "authSubject", keyClaims.Subject) + } + + // Add a flag to the context to indicate that the token was decoded + ctx = context.WithValue(ctx, "authDecoded", true) + + // Create a new request with the updated context + r = r.WithContext(ctx) + + handler.ServeHTTP(w, r) + }) +} diff --git a/pkg/logging/http.go b/pkg/logging/http.go index ce9a5473c..2a933fcc7 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -67,6 +67,12 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha // replace req.Body with a copy r.Body = io.NopCloser(bytes.NewReader(reqData)) } + + subject, ok := r.Context().Value("authSubject").(string) + if !ok { + subject = "" + } + logger = logger.With(zap.String("req_sub", subject)) } handler.ServeHTTP(trw, r)