Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: redirect /accounts/me to the correct /accounts/{id} #43

Merged
merged 5 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ func NewGoCertRouter(env *Environment) http.Handler {
apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env))
apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env))

apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env))
apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env))
apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env))
apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env))
apiV1Router.HandleFunc("DELETE /accounts/{id}", DeleteUserAccount(env))
apiV1Router.HandleFunc("POST /accounts/{id}/change_password", ChangeUserAccountPassword(env))

Expand Down Expand Up @@ -300,7 +300,17 @@ func GetUserAccounts(env *Environment) http.HandlerFunc {
func GetUserAccount(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
userAccount, err := env.DB.RetrieveUser(id)
var userAccount certdb.User
var err error
if id == "me" {
claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if headerErr != nil {
logErrorAndWriteResponse(headerErr.Error(), http.StatusUnauthorized, w)
}
userAccount, err = env.DB.RetrieveUserByUsername(claims.Username)
} else {
userAccount, err = env.DB.RetrieveUser(id)
}
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
Expand Down Expand Up @@ -409,6 +419,17 @@ func DeleteUserAccount(env *Environment) http.HandlerFunc {
func ChangeUserAccountPassword(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
if id == "me" {
claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w)
}
userAccount, err := env.DB.RetrieveUserByUsername(claims.Username)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w)
}
id = strconv.Itoa(userAccount.ID)
}
var user certdb.User
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w)
Expand Down Expand Up @@ -471,7 +492,7 @@ func Login(env *Environment) http.HandlerFunc {
logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w)
return
}
jwt, err := generateJWT(userRequest.Username, env.JWTSecret, userAccount.Permissions)
jwt, err := generateJWT(userAccount.ID, userAccount.Username, env.JWTSecret, userAccount.Permissions)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
Expand All @@ -489,7 +510,7 @@ func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) {
log.Println(errMsg)
w.WriteHeader(status)
if _, err := w.Write([]byte(errMsg)); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
log.Printf("error writing response: %s", err.Error())
}
}

Expand Down Expand Up @@ -554,8 +575,9 @@ func validatePassword(password string) bool {
}

// Helper function to generate a JWT
func generateJWT(username string, jwtSecret []byte, permissions int) (string, error) {
func generateJWT(id int, username string, jwtSecret []byte, permissions int) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtGocertClaims{
ID: id,
Username: username,
Permissions: permissions,
StandardClaims: jwt.StandardClaims{
Expand All @@ -571,6 +593,7 @@ func generateJWT(username string, jwtSecret []byte, permissions int) (string, er
}

type jwtGocertClaims struct {
ID int `json:"id"`
Username string `json:"username"`
Permissions int `json:"permissions"`
jwt.StandardClaims
Expand Down
18 changes: 18 additions & 0 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,24 @@ func TestAuthorization(t *testing.T) {
response: "",
status: http.StatusForbidden,
},
{
desc: "user can change self password with /me",
method: "POST",
path: "/api/v1/accounts/me/change_password",
data: `{"password":"BetterPW1!"}`,
auth: nonAdminToken,
response: "",
status: http.StatusOK,
},
{
desc: "user can login with new password",
method: "POST",
path: "/login",
data: `{"username":"testuser","password":"BetterPW1!"}`,
auth: nonAdminToken,
response: "",
status: http.StatusOK,
},
{
desc: "admin can't delete itself",
method: "DELETE",
Expand Down
97 changes: 73 additions & 24 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,16 @@ func loggingMiddleware(ctx *middlewareContext) middleware {
// authMiddleware intercepts requests that need authorization to check if the user's token exists and is
// permitted to use the endpoint
func authMiddleware(ctx *middlewareContext) middleware {
AdminOnlyPaths := []struct{ method, path string }{
{"POST", `accounts`},
{"GET", `accounts`},
{"GET", `accounts\/\d+$`},
{"DELETE", `accounts\/\d+$`},
{"POST", `accounts\/\d+\/change_password$`},
RestrictedPaths := []struct {
method, pathRegex string
SelfAuthorizedAllowed bool
}{
{"POST", `accounts$`, false},
{"GET", `accounts$`, false},
{"DELETE", `accounts\/(\d+)$`, false},
{"GET", `accounts\/(\d+)$`, true},
{"POST", `accounts\/(\d+)\/change_password$`, true},
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/api/v1/") {
Expand All @@ -115,35 +117,25 @@ func authMiddleware(ctx *middlewareContext) middleware {
}
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
logErrorAndWriteResponse("authorization header not found", http.StatusUnauthorized, w)
return
}
bearerToken := strings.Split(authHeader, " ")
if len(bearerToken) != 2 || bearerToken[0] != "Bearer" {
logErrorAndWriteResponse("authorization header couldn't be processed. The expected format is 'Bearer <token>'", http.StatusUnauthorized, w)
return
}
claims, err := getClaimsFromJWT(bearerToken[1], ctx.jwtSecret)
claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), ctx.jwtSecret)
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("token is not valid: %s", err.Error()), http.StatusUnauthorized, w)
logErrorAndWriteResponse(fmt.Sprintf("auth failed: %s", err.Error()), http.StatusUnauthorized, w)
return
}
if claims.Permissions == 0 {
for _, v := range AdminOnlyPaths {
matched, err := regexp.Match(v.path, []byte(r.URL.Path))
for _, v := range RestrictedPaths {
pathMatched, requestAllowed, err := AllowRequest(claims, v.pathRegex, r.URL.Path, v.SelfAuthorizedAllowed)
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("ran into issue parsing path: %s", err.Error()), http.StatusInternalServerError, w)
logErrorAndWriteResponse(fmt.Sprintf("error processing path: %s", err.Error()), http.StatusInternalServerError, w)
return
}
if r.Method == v.method && matched {
if pathMatched && !requestAllowed {
logErrorAndWriteResponse("forbidden", http.StatusForbidden, w)
return
}
}
}
if claims.Permissions == 1 && r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") {
if r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") {
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
logErrorAndWriteResponse("can't delete admin account", http.StatusConflict, w)
return
}
Expand All @@ -152,6 +144,63 @@ func authMiddleware(ctx *middlewareContext) middleware {
}
}

func getClaimsFromAuthorizationHeader(header string, jwtSecret []byte) (*jwtGocertClaims, error) {
if header == "" {
return nil, fmt.Errorf("authorization header not found")
}
bearerToken := strings.Split(header, " ")
if len(bearerToken) != 2 || bearerToken[0] != "Bearer" {
return nil, fmt.Errorf("authorization header couldn't be processed. The expected format is 'Bearer <token>'")
}
claims, err := getClaimsFromJWT(bearerToken[1], jwtSecret)
if err != nil {
return nil, fmt.Errorf("token is not valid: %s", err)
}
return claims, nil
}

// AllowRequest looks at the user data to determine the following things:
// Is this user trying to access a path that's restricted? The answer to this question is the pathMatched variable
//
// There are two types of restricted paths: admin only paths that only admins can access, and self authorized paths,
// which users are allowed to use only if they are taking an action on their own user ID. The second question answers the
// question "If the path requires an ID, is the user attempting to access their own ID?"
//
// For all endpoints and permission permutations, there are only 2 cases when users are allowed to use endpoints:
// If the URL path is not restricted to admins
// If the URL path is restricted to self authorized endpoints, and the user is taking action to their own ID
// This function returns if the URL matches the restricted path, and if that request should be allowed according to the conditions above.
func AllowRequest(claims *jwtGocertClaims, regex, path string, SelfAuthorizedAllowed bool) (pathMatched, requestAllowed bool, err error) {
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
var idMatchedIfExistsInPath bool
regexChallenge, err := regexp.Compile(regex)
if err != nil {
return pathMatched, requestAllowed, fmt.Errorf("regex couldn't compile: %s", err)
}
matches := regexChallenge.FindStringSubmatch(path)
if len(matches) > 0 {
pathMatched = true
}
if len(matches) == 1 {
idMatchedIfExistsInPath = true
}
if len(matches) > 1 && SelfAuthorizedAllowed {
matchedID, err := strconv.Atoi(matches[1])
if err != nil {
return pathMatched, requestAllowed, fmt.Errorf("error converting url id to string: %s", err)
}
if matchedID == claims.ID {
idMatchedIfExistsInPath = true
}
}
if pathMatched && !SelfAuthorizedAllowed {
return pathMatched, false, nil
}
if pathMatched && SelfAuthorizedAllowed && !idMatchedIfExistsInPath {
return pathMatched, false, nil
}
return pathMatched, true, nil
}

func getClaimsFromJWT(bearerToken string, jwtSecret []byte) (*jwtGocertClaims, error) {
claims := jwtGocertClaims{}
token, err := jwt.ParseWithClaims(bearerToken, &claims, func(token *jwt.Token) (interface{}, error) {
Expand Down
Loading