Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: redirect /accounts/me to the correct /accounts/{id} #43

Merged
merged 5 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ func NewGoCertRouter(env *Environment) http.Handler {
apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env))
apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env))

apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env))
apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env))
apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env))
apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env))
apiV1Router.HandleFunc("DELETE /accounts/{id}", DeleteUserAccount(env))
apiV1Router.HandleFunc("POST /accounts/{id}/change_password", ChangeUserAccountPassword(env))

Expand Down Expand Up @@ -300,7 +300,17 @@ func GetUserAccounts(env *Environment) http.HandlerFunc {
func GetUserAccount(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
userAccount, err := env.DB.RetrieveUser(id)
var userAccount certdb.User
var err error
if id == "me" {
claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if headerErr != nil {
logErrorAndWriteResponse(headerErr.Error(), http.StatusUnauthorized, w)
}
userAccount, err = env.DB.RetrieveUserByUsername(claims.Username)
} else {
userAccount, err = env.DB.RetrieveUser(id)
}
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
Expand Down Expand Up @@ -409,6 +419,17 @@ func DeleteUserAccount(env *Environment) http.HandlerFunc {
func ChangeUserAccountPassword(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
if id == "me" {
claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w)
}
userAccount, err := env.DB.RetrieveUserByUsername(claims.Username)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w)
}
id = strconv.Itoa(userAccount.ID)
}
var user certdb.User
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w)
Expand Down Expand Up @@ -471,7 +492,7 @@ func Login(env *Environment) http.HandlerFunc {
logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w)
return
}
jwt, err := generateJWT(userRequest.Username, env.JWTSecret, userAccount.Permissions)
jwt, err := generateJWT(userAccount.ID, userAccount.Username, env.JWTSecret, userAccount.Permissions)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
Expand All @@ -489,7 +510,7 @@ func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) {
log.Println(errMsg)
w.WriteHeader(status)
if _, err := w.Write([]byte(errMsg)); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
log.Printf("error writing response: %s", err.Error())
}
}

Expand Down Expand Up @@ -554,8 +575,9 @@ func validatePassword(password string) bool {
}

// Helper function to generate a JWT
func generateJWT(username string, jwtSecret []byte, permissions int) (string, error) {
func generateJWT(id int, username string, jwtSecret []byte, permissions int) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtGocertClaims{
ID: id,
Username: username,
Permissions: permissions,
StandardClaims: jwt.StandardClaims{
Expand All @@ -571,6 +593,7 @@ func generateJWT(username string, jwtSecret []byte, permissions int) (string, er
}

type jwtGocertClaims struct {
ID int `json:"id"`
Username string `json:"username"`
Permissions int `json:"permissions"`
jwt.StandardClaims
Expand Down
18 changes: 18 additions & 0 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,24 @@ func TestAuthorization(t *testing.T) {
response: "",
status: http.StatusForbidden,
},
{
desc: "user can change self password with /me",
method: "POST",
path: "/api/v1/accounts/me/change_password",
data: `{"password":"BetterPW1!"}`,
auth: nonAdminToken,
response: "",
status: http.StatusOK,
},
{
desc: "user can login with new password",
method: "POST",
path: "/login",
data: `{"username":"testuser","password":"BetterPW1!"}`,
auth: nonAdminToken,
response: "",
status: http.StatusOK,
},
{
desc: "admin can't delete itself",
method: "DELETE",
Expand Down
116 changes: 84 additions & 32 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
)

const (
USER_ACCOUNT = 0
ADMIN_ACCOUNT = 1
)

type middleware func(http.Handler) http.Handler

// The middlewareContext type helps middleware receive and pass along information through the middleware chain.
Expand Down Expand Up @@ -94,14 +99,6 @@ func loggingMiddleware(ctx *middlewareContext) middleware {
// authMiddleware intercepts requests that need authorization to check if the user's token exists and is
// permitted to use the endpoint
func authMiddleware(ctx *middlewareContext) middleware {
AdminOnlyPaths := []struct{ method, path string }{
{"POST", `accounts`},
{"GET", `accounts`},
{"GET", `accounts\/\d+$`},
{"DELETE", `accounts\/\d+$`},
{"POST", `accounts\/\d+\/change_password$`},
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/api/v1/") {
Expand All @@ -115,35 +112,23 @@ func authMiddleware(ctx *middlewareContext) middleware {
}
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
logErrorAndWriteResponse("authorization header not found", http.StatusUnauthorized, w)
return
}
bearerToken := strings.Split(authHeader, " ")
if len(bearerToken) != 2 || bearerToken[0] != "Bearer" {
logErrorAndWriteResponse("authorization header couldn't be processed. The expected format is 'Bearer <token>'", http.StatusUnauthorized, w)
return
}
claims, err := getClaimsFromJWT(bearerToken[1], ctx.jwtSecret)
claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), ctx.jwtSecret)
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("token is not valid: %s", err.Error()), http.StatusUnauthorized, w)
logErrorAndWriteResponse(fmt.Sprintf("auth failed: %s", err.Error()), http.StatusUnauthorized, w)
return
}
if claims.Permissions == 0 {
for _, v := range AdminOnlyPaths {
matched, err := regexp.Match(v.path, []byte(r.URL.Path))
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("ran into issue parsing path: %s", err.Error()), http.StatusInternalServerError, w)
return
}
if r.Method == v.method && matched {
logErrorAndWriteResponse("forbidden", http.StatusForbidden, w)
return
}
if claims.Permissions == USER_ACCOUNT {
requestAllowed, err := AllowRequest(claims, r.Method, r.URL.Path)
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("error processing path: %s", err.Error()), http.StatusInternalServerError, w)
return
}
if !requestAllowed {
logErrorAndWriteResponse("forbidden", http.StatusForbidden, w)
return
}
}
if claims.Permissions == 1 && r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") {
if r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") {
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
logErrorAndWriteResponse("can't delete admin account", http.StatusConflict, w)
return
}
Expand All @@ -152,6 +137,73 @@ func authMiddleware(ctx *middlewareContext) middleware {
}
}

func getClaimsFromAuthorizationHeader(header string, jwtSecret []byte) (*jwtGocertClaims, error) {
if header == "" {
return nil, fmt.Errorf("authorization header not found")
}
bearerToken := strings.Split(header, " ")
if len(bearerToken) != 2 || bearerToken[0] != "Bearer" {
return nil, fmt.Errorf("authorization header couldn't be processed. The expected format is 'Bearer <token>'")
}
claims, err := getClaimsFromJWT(bearerToken[1], jwtSecret)
if err != nil {
return nil, fmt.Errorf("token is not valid: %s", err)
}
return claims, nil
}

// AllowRequest looks at the user data to determine the following things:
// The first question is "Is this user trying to access a path that's restricted?"
//
// There are two types of restricted paths: admin only paths that only admins can access, and self authorized paths,
// which users are allowed to use only if they are taking an action on their own user ID. The second question is
// "If the path requires an ID, is the user attempting to access their own ID?"
//
// For all endpoints and permission permutations, there are only 2 cases when users are allowed to use endpoints:
// If the URL path is not restricted to admins
// If the URL path is restricted to self authorized endpoints, and the user is taking action with their own ID
// This function validates that the user the with the given claims is allowed to use the endpoints by passing the above checks.
func AllowRequest(claims *jwtGocertClaims, method, path string) (bool, error) {
restrictedPaths := []struct {
method, pathRegex string
SelfAuthorizedAllowed bool
}{
{"POST", `accounts$`, false},
{"GET", `accounts$`, false},
{"DELETE", `accounts\/(\d+)$`, false},
{"GET", `accounts\/(\d+)$`, true},
{"POST", `accounts\/(\d+)\/change_password$`, true},
}
for _, pr := range restrictedPaths {
regexChallenge, err := regexp.Compile(pr.pathRegex)
if err != nil {
return false, fmt.Errorf("regex couldn't compile: %s", err)
}
matches := regexChallenge.FindStringSubmatch(path)
restrictedPathMatchedToRequestedPath := len(matches) > 0 && method == pr.method
if !restrictedPathMatchedToRequestedPath {
continue
}
if !pr.SelfAuthorizedAllowed {
return false, nil
}
matchedID, err := strconv.Atoi(matches[1])
if err != nil {
return true, fmt.Errorf("error converting url id to string: %s", err)
}
var requestedIDMatchesTheClaimant bool
if matchedID == claims.ID {
requestedIDMatchesTheClaimant = true
}
IDRequiredForPath := len(matches) > 1
if IDRequiredForPath && !requestedIDMatchesTheClaimant {
return false, nil
}
return true, nil
}
return true, nil
}

func getClaimsFromJWT(bearerToken string, jwtSecret []byte) (*jwtGocertClaims, error) {
claims := jwtGocertClaims{}
token, err := jwt.ParseWithClaims(bearerToken, &claims, func(token *jwt.Token) (interface{}, error) {
Expand Down
Loading