diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 6a0d180..6737b30 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -36,9 +36,9 @@ func NewGoCertRouter(env *Environment) http.Handler { apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env)) apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env)) - apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env)) apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env)) apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env)) + apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env)) apiV1Router.HandleFunc("DELETE /accounts/{id}", DeleteUserAccount(env)) apiV1Router.HandleFunc("POST /accounts/{id}/change_password", ChangeUserAccountPassword(env)) @@ -300,7 +300,17 @@ func GetUserAccounts(env *Environment) http.HandlerFunc { func GetUserAccount(env *Environment) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - userAccount, err := env.DB.RetrieveUser(id) + var userAccount certdb.User + var err error + if id == "me" { + claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if headerErr != nil { + logErrorAndWriteResponse(headerErr.Error(), http.StatusUnauthorized, w) + } + userAccount, err = env.DB.RetrieveUserByUsername(claims.Username) + } else { + userAccount, err = env.DB.RetrieveUser(id) + } if err != nil { if errors.Is(err, certdb.ErrIdNotFound) { logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w) @@ -409,6 +419,17 @@ func DeleteUserAccount(env *Environment) http.HandlerFunc { func ChangeUserAccountPassword(env *Environment) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") + if id == "me" { + claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) + if err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w) + } + userAccount, err := env.DB.RetrieveUserByUsername(claims.Username) + if err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w) + } + id = strconv.Itoa(userAccount.ID) + } var user certdb.User if err := json.NewDecoder(r.Body).Decode(&user); err != nil { logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w) @@ -471,7 +492,7 @@ func Login(env *Environment) http.HandlerFunc { logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w) return } - jwt, err := generateJWT(userRequest.Username, env.JWTSecret, userAccount.Permissions) + jwt, err := generateJWT(userAccount.ID, userAccount.Username, env.JWTSecret, userAccount.Permissions) if err != nil { logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) return @@ -489,7 +510,7 @@ func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) { log.Println(errMsg) w.WriteHeader(status) if _, err := w.Write([]byte(errMsg)); err != nil { - logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + log.Printf("error writing response: %s", err.Error()) } } @@ -554,8 +575,9 @@ func validatePassword(password string) bool { } // Helper function to generate a JWT -func generateJWT(username string, jwtSecret []byte, permissions int) (string, error) { +func generateJWT(id int, username string, jwtSecret []byte, permissions int) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtGocertClaims{ + ID: id, Username: username, Permissions: permissions, StandardClaims: jwt.StandardClaims{ @@ -571,6 +593,7 @@ func generateJWT(username string, jwtSecret []byte, permissions int) (string, er } type jwtGocertClaims struct { + ID int `json:"id"` Username string `json:"username"` Permissions int `json:"permissions"` jwt.StandardClaims diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index 46b18e5..8e0eb33 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -737,6 +737,24 @@ func TestAuthorization(t *testing.T) { response: "", status: http.StatusForbidden, }, + { + desc: "user can change self password with /me", + method: "POST", + path: "/api/v1/accounts/me/change_password", + data: `{"password":"BetterPW1!"}`, + auth: nonAdminToken, + response: "", + status: http.StatusOK, + }, + { + desc: "user can login with new password", + method: "POST", + path: "/login", + data: `{"username":"testuser","password":"BetterPW1!"}`, + auth: nonAdminToken, + response: "", + status: http.StatusOK, + }, { desc: "admin can't delete itself", method: "DELETE", diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 75b952b..f74657a 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -13,6 +13,11 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) +const ( + USER_ACCOUNT = 0 + ADMIN_ACCOUNT = 1 +) + type middleware func(http.Handler) http.Handler // The middlewareContext type helps middleware receive and pass along information through the middleware chain. @@ -94,14 +99,6 @@ func loggingMiddleware(ctx *middlewareContext) middleware { // authMiddleware intercepts requests that need authorization to check if the user's token exists and is // permitted to use the endpoint func authMiddleware(ctx *middlewareContext) middleware { - AdminOnlyPaths := []struct{ method, path string }{ - {"POST", `accounts`}, - {"GET", `accounts`}, - {"GET", `accounts\/\d+$`}, - {"DELETE", `accounts\/\d+$`}, - {"POST", `accounts\/\d+\/change_password$`}, - } - return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.HasPrefix(r.URL.Path, "/api/v1/") { @@ -115,35 +112,23 @@ func authMiddleware(ctx *middlewareContext) middleware { } return } - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - logErrorAndWriteResponse("authorization header not found", http.StatusUnauthorized, w) - return - } - bearerToken := strings.Split(authHeader, " ") - if len(bearerToken) != 2 || bearerToken[0] != "Bearer" { - logErrorAndWriteResponse("authorization header couldn't be processed. The expected format is 'Bearer '", http.StatusUnauthorized, w) - return - } - claims, err := getClaimsFromJWT(bearerToken[1], ctx.jwtSecret) + claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), ctx.jwtSecret) if err != nil { - logErrorAndWriteResponse(fmt.Sprintf("token is not valid: %s", err.Error()), http.StatusUnauthorized, w) + logErrorAndWriteResponse(fmt.Sprintf("auth failed: %s", err.Error()), http.StatusUnauthorized, w) return } - if claims.Permissions == 0 { - for _, v := range AdminOnlyPaths { - matched, err := regexp.Match(v.path, []byte(r.URL.Path)) - if err != nil { - logErrorAndWriteResponse(fmt.Sprintf("ran into issue parsing path: %s", err.Error()), http.StatusInternalServerError, w) - return - } - if r.Method == v.method && matched { - logErrorAndWriteResponse("forbidden", http.StatusForbidden, w) - return - } + if claims.Permissions == USER_ACCOUNT { + requestAllowed, err := AllowRequest(claims, r.Method, r.URL.Path) + if err != nil { + logErrorAndWriteResponse(fmt.Sprintf("error processing path: %s", err.Error()), http.StatusInternalServerError, w) + return + } + if !requestAllowed { + logErrorAndWriteResponse("forbidden", http.StatusForbidden, w) + return } } - if claims.Permissions == 1 && r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") { + if r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") { logErrorAndWriteResponse("can't delete admin account", http.StatusConflict, w) return } @@ -152,6 +137,73 @@ func authMiddleware(ctx *middlewareContext) middleware { } } +func getClaimsFromAuthorizationHeader(header string, jwtSecret []byte) (*jwtGocertClaims, error) { + if header == "" { + return nil, fmt.Errorf("authorization header not found") + } + bearerToken := strings.Split(header, " ") + if len(bearerToken) != 2 || bearerToken[0] != "Bearer" { + return nil, fmt.Errorf("authorization header couldn't be processed. The expected format is 'Bearer '") + } + claims, err := getClaimsFromJWT(bearerToken[1], jwtSecret) + if err != nil { + return nil, fmt.Errorf("token is not valid: %s", err) + } + return claims, nil +} + +// AllowRequest looks at the user data to determine the following things: +// The first question is "Is this user trying to access a path that's restricted?" +// +// There are two types of restricted paths: admin only paths that only admins can access, and self authorized paths, +// which users are allowed to use only if they are taking an action on their own user ID. The second question is +// "If the path requires an ID, is the user attempting to access their own ID?" +// +// For all endpoints and permission permutations, there are only 2 cases when users are allowed to use endpoints: +// If the URL path is not restricted to admins +// If the URL path is restricted to self authorized endpoints, and the user is taking action with their own ID +// This function validates that the user the with the given claims is allowed to use the endpoints by passing the above checks. +func AllowRequest(claims *jwtGocertClaims, method, path string) (bool, error) { + restrictedPaths := []struct { + method, pathRegex string + SelfAuthorizedAllowed bool + }{ + {"POST", `accounts$`, false}, + {"GET", `accounts$`, false}, + {"DELETE", `accounts\/(\d+)$`, false}, + {"GET", `accounts\/(\d+)$`, true}, + {"POST", `accounts\/(\d+)\/change_password$`, true}, + } + for _, pr := range restrictedPaths { + regexChallenge, err := regexp.Compile(pr.pathRegex) + if err != nil { + return false, fmt.Errorf("regex couldn't compile: %s", err) + } + matches := regexChallenge.FindStringSubmatch(path) + restrictedPathMatchedToRequestedPath := len(matches) > 0 && method == pr.method + if !restrictedPathMatchedToRequestedPath { + continue + } + if !pr.SelfAuthorizedAllowed { + return false, nil + } + matchedID, err := strconv.Atoi(matches[1]) + if err != nil { + return true, fmt.Errorf("error converting url id to string: %s", err) + } + var requestedIDMatchesTheClaimant bool + if matchedID == claims.ID { + requestedIDMatchesTheClaimant = true + } + IDRequiredForPath := len(matches) > 1 + if IDRequiredForPath && !requestedIDMatchesTheClaimant { + return false, nil + } + return true, nil + } + return true, nil +} + func getClaimsFromJWT(bearerToken string, jwtSecret []byte) (*jwtGocertClaims, error) { claims := jwtGocertClaims{} token, err := jwt.ParseWithClaims(bearerToken, &claims, func(token *jwt.Token) (interface{}, error) {