From b9303fbbffd9e0583a66819e82f4e264a4e04f64 Mon Sep 17 00:00:00 2001 From: kayra1 Date: Fri, 12 Jul 2024 14:12:31 +0300 Subject: [PATCH] clean up auth middleware and redirect /me --- internal/api/handlers.go | 32 ++++++++++++-- internal/api/handlers_test.go | 18 ++++++++ internal/api/middleware.go | 83 +++++++++++++++++++++++++---------- 3 files changed, 105 insertions(+), 28 deletions(-) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 6a0d180..c2aee3c 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -36,9 +36,9 @@ func NewGoCertRouter(env *Environment) http.Handler { apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env)) apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env)) - apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env)) apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env)) apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env)) + apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env)) apiV1Router.HandleFunc("DELETE /accounts/{id}", DeleteUserAccount(env)) apiV1Router.HandleFunc("POST /accounts/{id}/change_password", ChangeUserAccountPassword(env)) @@ -300,7 +300,17 @@ func GetUserAccounts(env *Environment) http.HandlerFunc { func GetUserAccount(env *Environment) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - userAccount, err := env.DB.RetrieveUser(id) + var userAccount certdb.User + var err error + if id == "me" { + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + logErrorAndWriteResponse(headerErr.Error(), http.StatusUnauthorized, w) + } + userAccount, err = env.DB.RetrieveUserByUsername(claims.Username) + } else { + userAccount, err = env.DB.RetrieveUser(id) + } if err != nil { if errors.Is(err, certdb.ErrIdNotFound) { logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w) @@ -409,6 +419,17 @@ func DeleteUserAccount(env *Environment) http.HandlerFunc { func ChangeUserAccountPassword(env *Environment) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") + if id == "me" { + claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w) + } + userAccount, err := env.DB.RetrieveUserByUsername(claims.Username) + if err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w) + } + id = strconv.Itoa(userAccount.ID) + } var user certdb.User if err := json.NewDecoder(r.Body).Decode(&user); err != nil { logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w) @@ -471,7 +492,7 @@ func Login(env *Environment) http.HandlerFunc { logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w) return } - jwt, err := generateJWT(userRequest.Username, env.JWTSecret, userAccount.Permissions) + jwt, err := generateJWT(userAccount.ID, userAccount.Username, env.JWTSecret, userAccount.Permissions) if err != nil { logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) return @@ -489,6 +510,7 @@ func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) { log.Println(errMsg) w.WriteHeader(status) if _, err := w.Write([]byte(errMsg)); err != nil { + // TODO: how did this get merged? logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) } } @@ -554,8 +576,9 @@ func validatePassword(password string) bool { } // Helper function to generate a JWT -func generateJWT(username string, jwtSecret []byte, permissions int) (string, error) { +func generateJWT(id int, username string, jwtSecret []byte, permissions int) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtGocertClaims{ + ID: id, Username: username, Permissions: permissions, StandardClaims: jwt.StandardClaims{ @@ -571,6 +594,7 @@ func generateJWT(username string, jwtSecret []byte, permissions int) (string, er } type jwtGocertClaims struct { + ID int `json:"id"` Username string `json:"username"` Permissions int `json:"permissions"` jwt.StandardClaims diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index 46b18e5..8e0eb33 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -737,6 +737,24 @@ func TestAuthorization(t *testing.T) { response: "", status: http.StatusForbidden, }, + { + desc: "user can change self password with /me", + method: "POST", + path: "/api/v1/accounts/me/change_password", + data: `{"password":"BetterPW1!"}`, + auth: nonAdminToken, + response: "", + status: http.StatusOK, + }, + { + desc: "user can login with new password", + method: "POST", + path: "/login", + data: `{"username":"testuser","password":"BetterPW1!"}`, + auth: nonAdminToken, + response: "", + status: http.StatusOK, + }, { desc: "admin can't delete itself", method: "DELETE", diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 75b952b..98cc9bb 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -94,14 +94,16 @@ func loggingMiddleware(ctx *middlewareContext) middleware { // authMiddleware intercepts requests that need authorization to check if the user's token exists and is // permitted to use the endpoint func authMiddleware(ctx *middlewareContext) middleware { - AdminOnlyPaths := []struct{ method, path string }{ - {"POST", `accounts`}, - {"GET", `accounts`}, - {"GET", `accounts\/\d+$`}, - {"DELETE", `accounts\/\d+$`}, - {"POST", `accounts\/\d+\/change_password$`}, + RestrictedPaths := []struct { + method, pathRegex string + SelfAuthorizedAllowed bool + }{ + {"POST", `accounts$`, false}, + {"GET", `accounts$`, false}, + {"DELETE", `accounts\/(\d+)$`, false}, + {"GET", `accounts\/(\d+)$`, true}, + {"POST", `accounts\/(\d+)\/change_password$`, true}, } - return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.HasPrefix(r.URL.Path, "/api/v1/") { @@ -115,35 +117,29 @@ func authMiddleware(ctx *middlewareContext) middleware { } return } - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - logErrorAndWriteResponse("authorization header not found", http.StatusUnauthorized, w) - return - } - bearerToken := strings.Split(authHeader, " ") - if len(bearerToken) != 2 || bearerToken[0] != "Bearer" { - logErrorAndWriteResponse("authorization header couldn't be processed. The expected format is 'Bearer '", http.StatusUnauthorized, w) - return - } - claims, err := getClaimsFromJWT(bearerToken[1], ctx.jwtSecret) + claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), ctx.jwtSecret) if err != nil { - logErrorAndWriteResponse(fmt.Sprintf("token is not valid: %s", err.Error()), http.StatusUnauthorized, w) + logErrorAndWriteResponse(fmt.Sprintf("auth failed: %s", err.Error()), http.StatusUnauthorized, w) return } if claims.Permissions == 0 { - for _, v := range AdminOnlyPaths { - matched, err := regexp.Match(v.path, []byte(r.URL.Path)) + for _, v := range RestrictedPaths { + pathMatched, idMatchedIfExistsInPath, err := permitter(claims, v.pathRegex, r.URL.Path, v.SelfAuthorizedAllowed) if err != nil { - logErrorAndWriteResponse(fmt.Sprintf("ran into issue parsing path: %s", err.Error()), http.StatusInternalServerError, w) + logErrorAndWriteResponse(fmt.Sprintf("error processing path: %s", err.Error()), http.StatusInternalServerError, w) return } - if r.Method == v.method && matched { + if pathMatched && !v.SelfAuthorizedAllowed { + logErrorAndWriteResponse("forbidden", http.StatusForbidden, w) + return + } + if pathMatched && v.SelfAuthorizedAllowed && !idMatchedIfExistsInPath { logErrorAndWriteResponse("forbidden", http.StatusForbidden, w) return } } } - if claims.Permissions == 1 && r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") { + if r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") { logErrorAndWriteResponse("can't delete admin account", http.StatusConflict, w) return } @@ -152,6 +148,45 @@ func authMiddleware(ctx *middlewareContext) middleware { } } +func getClaimsFromAuthorizationHeader(header string, jwtSecret []byte) (*jwtGocertClaims, error) { + if header == "" { + return nil, fmt.Errorf("authorization header not found") + } + bearerToken := strings.Split(header, " ") + if len(bearerToken) != 2 || bearerToken[0] != "Bearer" { + return nil, fmt.Errorf("authorization header couldn't be processed. The expected format is 'Bearer '") + } + claims, err := getClaimsFromJWT(bearerToken[1], jwtSecret) + if err != nil { + return nil, fmt.Errorf("token is not valid: %s", err) + } + return claims, nil +} + +func permitter(claims *jwtGocertClaims, regex, path string, SelfAuthorizedAllowed bool) (pathMatched, idMatchedIfExistsInPath bool, err error) { + regexChallenge, err := regexp.Compile(regex) + if err != nil { + return pathMatched, idMatchedIfExistsInPath, fmt.Errorf("regex couldn't compile: %s", err) + } + matches := regexChallenge.FindStringSubmatch(path) + if len(matches) > 0 { + pathMatched = true + } + if len(matches) == 1 { + idMatchedIfExistsInPath = true + } + if len(matches) > 1 && SelfAuthorizedAllowed { + matchedID, err := strconv.Atoi(matches[1]) + if err != nil { + return pathMatched, idMatchedIfExistsInPath, fmt.Errorf("error converting url id to string: %s", err) + } + if matchedID == claims.ID { + idMatchedIfExistsInPath = true + } + } + return pathMatched, idMatchedIfExistsInPath, nil +} + func getClaimsFromJWT(bearerToken string, jwtSecret []byte) (*jwtGocertClaims, error) { claims := jwtGocertClaims{} token, err := jwt.ParseWithClaims(bearerToken, &claims, func(token *jwt.Token) (interface{}, error) {