From c89bab4dab2bb7554f45789b53d0d20de6139664 Mon Sep 17 00:00:00 2001 From: Ahmad ElRouby Date: Fri, 6 Sep 2024 15:17:20 +0200 Subject: [PATCH 1/5] split auth into decode and authorize + logging jwt sub --- cmds/core-service/main.go | 7 ++-- pkg/auth/auth.go | 84 +++++++++++++++++++++++++++++++-------- pkg/logging/http.go | 6 +++ 3 files changed, 76 insertions(+), 21 deletions(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index bdd21e2eb..8dfc784c7 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -281,10 +281,9 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st multiRouter.Routers = append(multiRouter.Routers, &scdV1Router) } - handler := logging.HTTPMiddleware(logger, *dumpRequests, - healthyEndpointMiddleware(logger, - &multiRouter, - )) + handler := logging.HTTPMiddleware(logger, *dumpRequests, &multiRouter) + + handler = auth.DecoderMiddleware(authorizer, healthyEndpointMiddleware(logger, handler)) httpServer := &http.Server{ Addr: address, diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index f49b468e3..fa96f74fb 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -185,29 +185,26 @@ func (a *Authorizer) setKeys(keys []interface{}) { // Authorize extracts and verifies bearer tokens from a http.Request. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - tknStr, ok := getToken(r) - if !ok { + missing, ok := r.Context().Value("authTokenMissing").(bool) + if !ok || missing { return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")} } - a.keyGuard.RLock() - keys := a.keys - a.keyGuard.RUnlock() - validated := false - var err error - var keyClaims claims + validated, ok := r.Context().Value("authValidated").(bool) + if !ok { + validated = false + } - for _, key := range keys { + err, ok := r.Context().Value("authError").(error) + if !ok { + err = nil + } + + keyClaims, ok := r.Context().Value("authClaims").(claims) + if !ok { keyClaims = claims{} - key := key - _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { - return key, nil - }) - if err == nil { - validated = true - break - } } + if !validated { return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed")} } @@ -300,3 +297,56 @@ func getToken(r *http.Request) (string, bool) { } return authHeader[7:], true } + +// Extract valid claims from a JWT token. +func (a *Authorizer) extractClaims(r *http.Request) (missing bool, validated bool, keyClaims claims, err error) { + tknStr, ok := getToken(r) + + if !ok { + return true, false, claims{}, nil + } + + a.keyGuard.RLock() + keys := a.keys + a.keyGuard.RUnlock() + validated = false + missing = false + + for _, key := range keys { + keyClaims = claims{} + key := key + _, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) { + return key, nil + }) + if err == nil { + validated = true + break + } + } + + return missing, validated, keyClaims, err + +} + +// Decoder Middleware +func DecoderMiddleware(a *Authorizer, handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + missing, validated, keyClaims, err := a.extractClaims(r) + + ctx := context.WithValue(r.Context(), "authValidated", validated) + ctx = context.WithValue(ctx, "authTokenMissing", missing) + ctx = context.WithValue(ctx, "authClaims", keyClaims) + ctx = context.WithValue(ctx, "authError", err) + + if validated && err == nil { + // If the token is valid, we can extract the subject from the token and add them to the context. + ctx = context.WithValue(ctx, "authSubject", keyClaims.Subject) + } + + // Create a new request with the updated context + r = r.WithContext(ctx) + + handler.ServeHTTP(w, r) + }) +} diff --git a/pkg/logging/http.go b/pkg/logging/http.go index ce9a5473c..2a933fcc7 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -67,6 +67,12 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha // replace req.Body with a copy r.Body = io.NopCloser(bytes.NewReader(reqData)) } + + subject, ok := r.Context().Value("authSubject").(string) + if !ok { + subject = "" + } + logger = logger.With(zap.String("req_sub", subject)) } handler.ServeHTTP(trw, r) From 9e059e464f13d91df706f9dd6ac5efb45bf2d8c2 Mon Sep 17 00:00:00 2001 From: Ahmad ElRouby Date: Fri, 6 Sep 2024 15:29:53 +0200 Subject: [PATCH 2/5] run decode if middleware not triggered --- pkg/auth/auth.go | 47 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index fa96f74fb..be41ddf1c 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -185,24 +185,38 @@ func (a *Authorizer) setKeys(keys []interface{}) { // Authorize extracts and verifies bearer tokens from a http.Request. func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult { - missing, ok := r.Context().Value("authTokenMissing").(bool) - if !ok || missing { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")} - } + missing := false + validated := false + keyClaims := claims{} + err := error(nil) + ok := false + + // if the decoding middleware wasn't triggered, attempt decoding + if _, ok = r.Context().Value("authDecoded").(bool); !ok { + a.logger.Info("Decoding middleware not triggered") + missing, validated, keyClaims, err = a.extractClaims(r) + } else { + a.logger.Info("Decoding middleware already triggered") + // Use previously decoded values + missing, ok = r.Context().Value("authTokenMissing").(bool) + if !ok || missing { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")} + } - validated, ok := r.Context().Value("authValidated").(bool) - if !ok { - validated = false - } + validated, ok = r.Context().Value("authValidated").(bool) + if !ok { + validated = false + } - err, ok := r.Context().Value("authError").(error) - if !ok { - err = nil - } + err, ok = r.Context().Value("authError").(error) + if !ok { + err = nil + } - keyClaims, ok := r.Context().Value("authClaims").(claims) - if !ok { - keyClaims = claims{} + keyClaims, ok = r.Context().Value("authClaims").(claims) + if !ok { + keyClaims = claims{} + } } if !validated { @@ -344,6 +358,9 @@ func DecoderMiddleware(a *Authorizer, handler http.Handler) http.Handler { ctx = context.WithValue(ctx, "authSubject", keyClaims.Subject) } + // Add a flag to the context to indicate that the token was decoded + ctx = context.WithValue(ctx, "authDecoded", true) + // Create a new request with the updated context r = r.WithContext(ctx) From 805ea3e53de289ce76bf40f1a83d3e84f3e6ce20 Mon Sep 17 00:00:00 2001 From: Ahmad ElRouby Date: Fri, 6 Sep 2024 15:32:57 +0200 Subject: [PATCH 3/5] delete extra logs --- pkg/auth/auth.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index be41ddf1c..d8634e02a 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -193,10 +193,8 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio // if the decoding middleware wasn't triggered, attempt decoding if _, ok = r.Context().Value("authDecoded").(bool); !ok { - a.logger.Info("Decoding middleware not triggered") missing, validated, keyClaims, err = a.extractClaims(r) } else { - a.logger.Info("Decoding middleware already triggered") // Use previously decoded values missing, ok = r.Context().Value("authTokenMissing").(bool) if !ok || missing { From 3862a074b7b27d910d69f177de0629ee70ef4949 Mon Sep 17 00:00:00 2001 From: Ahmad ElRouby Date: Fri, 6 Sep 2024 15:50:12 +0200 Subject: [PATCH 4/5] return missing error --- pkg/auth/auth.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index d8634e02a..ea6dca949 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -197,8 +197,8 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio } else { // Use previously decoded values missing, ok = r.Context().Value("authTokenMissing").(bool) - if !ok || missing { - return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")} + if !ok { + missing = true } validated, ok = r.Context().Value("authValidated").(bool) @@ -217,6 +217,10 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio } } + if missing { + return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")} + } + if !validated { return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed")} } From 9efe2d4705ca2fbd4f5263ef9c0a64d77c707d2b Mon Sep 17 00:00:00 2001 From: Ahmad ElRouby Date: Fri, 6 Sep 2024 16:02:44 +0200 Subject: [PATCH 5/5] flip the order of middlewares --- cmds/core-service/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmds/core-service/main.go b/cmds/core-service/main.go index 8dfc784c7..1a05ff148 100644 --- a/cmds/core-service/main.go +++ b/cmds/core-service/main.go @@ -281,9 +281,9 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st multiRouter.Routers = append(multiRouter.Routers, &scdV1Router) } - handler := logging.HTTPMiddleware(logger, *dumpRequests, &multiRouter) + handler := logging.HTTPMiddleware(logger, *dumpRequests, healthyEndpointMiddleware(logger, &multiRouter)) - handler = auth.DecoderMiddleware(authorizer, healthyEndpointMiddleware(logger, handler)) + handler = auth.DecoderMiddleware(authorizer, handler) httpServer := &http.Server{ Addr: address,