From 53bef96cc39b8b61f2091e349b770d0be2216924 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20B=C3=B6hlke?= Date: Tue, 24 Sep 2024 15:26:08 +0200 Subject: [PATCH] Load token only from header, ignore cookie --- pkg/auth/auth.go | 124 +++++++++++++---------------------------------- 1 file changed, 34 insertions(+), 90 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 1af4877f..5792e19b 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -4,7 +4,6 @@ package auth import ( "context" - "encoding/json" "errors" "fmt" "net/http" @@ -122,12 +121,14 @@ func (a *Auth) Authenticate(w http.ResponseWriter, r *http.Request) (context.Con ctx := r.Context() - p := new(payload) + p := new(OpenSlidesClaims) + // 0 means anonymous user + p.UserID = "0" if err := a.loadToken(w, r, p); err != nil { return nil, fmt.Errorf("reading token: %w", err) } - if p.UserID == 0 { + if p.UserID == "0" { return a.AuthenticatedContext(ctx, 0), nil } @@ -141,7 +142,9 @@ func (a *Auth) Authenticate(w http.ResponseWriter, r *http.Request) (context.Con } } - ctx, cancelCtx := context.WithCancel(a.AuthenticatedContext(ctx, p.UserID)) + // convert p.UserID to int + userID, err := strconv.Atoi(p.UserID) + ctx, cancelCtx := context.WithCancel(a.AuthenticatedContext(ctx, userID)) go func() { defer cancelCtx() @@ -228,67 +231,50 @@ func (a *Auth) pruneOldData(ctx context.Context) { } } -// loadToken loads and validates the token. If the token is expires, it tries -// to renews it and writes the new token to the responsewriter. -func (a *Auth) loadToken(w http.ResponseWriter, r *http.Request, payload jwt.Claims) error { - header := r.Header.Get(authHeader) - cookie, err := r.Cookie(cookieName) - if err != nil && err != http.ErrNoCookie { - return fmt.Errorf("reading cookie: %w", err) +func TrimPrefixCaseInsensitive(s, prefix string) string { + if strings.HasPrefix(strings.ToLower(s), strings.ToLower(prefix)) { + return s[len(prefix):] } + return s +} - encodedToken := strings.TrimPrefix(header, "bearer ") +// loadToken loads and validates the token. If the token is expired, it tries +// to renew it and writes the new token to the responsewriter. +func (a *Auth) loadToken(w http.ResponseWriter, r *http.Request, payload *OpenSlidesClaims) error { + header := r.Header.Get(authHeader) - if cookie == nil && header == encodedToken { - // No token and no auth cookie. Handle the request as anonymous requst. + encodedToken := TrimPrefixCaseInsensitive(header, "bearer ") + + if header == encodedToken { + println("no bearer") return nil } - if cookie == nil && header != encodedToken { - return authError{"Can not find auth cookie", nil} - } + token, err := jwt.ParseWithClaims(encodedToken, payload, func(token *jwt.Token) (interface{}, error) { + return []byte(a.tokenKey), nil + }) - if cookie != nil && header == encodedToken { - return authError{"Can not find auth token", nil} - } + claims, _ := token.Claims.(*OpenSlidesClaims) + fmt.Printf("UserID: %s\n", payload.UserID) + //fmt.Printf("Issuer: %s\n", claims.Issuer) - encodedCookie := strings.TrimPrefix(cookie.Value, "bearer%20") + payload.UserID = claims.UserID - _, err = jwt.Parse(encodedCookie, func(token *jwt.Token) (interface{}, error) { - return []byte(a.cookieKey), nil - }) if err != nil { var invalid *jwt.ValidationError if errors.As(err, &invalid) { - return authError{"Invalid auth token", err} - } - return fmt.Errorf("validating auth cookie: %w", err) - } - - _, err = jwt.ParseWithClaims(encodedToken, payload, func(token *jwt.Token) (interface{}, error) { - return []byte(a.tokenKey), nil - }) - if err != nil { - var invalid *jwt.ValidationError - if errors.As(err, &invalid) { - return a.handleInvalidToken(r.Context(), invalid, w, encodedToken, encodedCookie) + return a.handleInvalidToken(r.Context(), invalid, w, encodedToken) } } return nil } -func (a *Auth) handleInvalidToken(ctx context.Context, invalid *jwt.ValidationError, w http.ResponseWriter, encodedToken, encodedCookie string) error { - if !tokenExpired(invalid.Errors) { - return authError{"Invalid auth token", invalid} - } - - token, err := a.refreshToken(ctx, encodedToken, encodedCookie) - if err != nil { - return fmt.Errorf("refreshing token: %w", err) +func (a *Auth) handleInvalidToken(ctx context.Context, invalid *jwt.ValidationError, w http.ResponseWriter, encodedToken string) error { + if tokenExpired(invalid.Errors) { + return authError{"auth token is expired", invalid} } - w.Header().Set(authHeader, token) return nil } @@ -296,56 +282,14 @@ func tokenExpired(errNo uint32) bool { return errNo&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 } -func (a *Auth) refreshToken(ctx context.Context, token, cookie string) (string, error) { - req, err := http.NewRequestWithContext(ctx, "POST", a.authServiceURL+authPath, nil) - if err != nil { - return "", fmt.Errorf("creating auth request: %w", err) - } - - req.Header.Add(authHeader, "bearer "+token) - req.AddCookie(&http.Cookie{Name: cookieName, Value: "bearer " + cookie, HttpOnly: true, Secure: true}) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - // TODO External ERROR - return "", fmt.Errorf("send request to auth service: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - if resp.StatusCode == 403 { - return "", authError{msg: "Invalid Session", wrapped: err} - } - // TODO LAST ERROR - return "", fmt.Errorf("auth-service returned status %s", resp.Status) - } - - newToken := resp.Header.Get(authHeader) - if newToken == "" { - var rPayload struct { - Message string `json:"message"` - } - if err := json.NewDecoder(resp.Body).Decode(&rPayload); err != nil { - return "", fmt.Errorf("decoding auth response: %w", err) - } - if rPayload.Message == "" { - rPayload.Message = "Can not refresh token" - } - return "", authError{rPayload.Message, nil} - - } - - return newToken, nil -} - type authString string const ( userIDType authString = "user_id" ) -type payload struct { +type OpenSlidesClaims struct { jwt.StandardClaims - UserID int `json:"userId"` - SessionID string `json:"sessionId"` + UserID string `json:"userId"` + SessionID string `json:"sid"` }