Skip to content

Commit

Permalink
clean up auth middleware and redirect /me
Browse files Browse the repository at this point in the history
  • Loading branch information
kayra1 committed Jul 12, 2024
1 parent d16dd50 commit e9a8e5b
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 28 deletions.
32 changes: 28 additions & 4 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ func NewGoCertRouter(env *Environment) http.Handler {
apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env))
apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env))

apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env))
apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env))
apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env))
apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env))
apiV1Router.HandleFunc("DELETE /accounts/{id}", DeleteUserAccount(env))
apiV1Router.HandleFunc("POST /accounts/{id}/change_password", ChangeUserAccountPassword(env))

Expand Down Expand Up @@ -300,7 +300,17 @@ func GetUserAccounts(env *Environment) http.HandlerFunc {
func GetUserAccount(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
userAccount, err := env.DB.RetrieveUser(id)
var userAccount certdb.User
var err error
if id == "me" {
claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if headerErr != nil {
logErrorAndWriteResponse(headerErr.Error(), http.StatusUnauthorized, w)
}
userAccount, err = env.DB.RetrieveUserByUsername(claims.Username)
} else {
userAccount, err = env.DB.RetrieveUser(id)
}
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
Expand Down Expand Up @@ -409,6 +419,17 @@ func DeleteUserAccount(env *Environment) http.HandlerFunc {
func ChangeUserAccountPassword(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
if id == "me" {
claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w)
}
userAccount, err := env.DB.RetrieveUserByUsername(claims.Username)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusUnauthorized, w)
}
id = strconv.Itoa(userAccount.ID)
}
var user certdb.User
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w)
Expand Down Expand Up @@ -471,7 +492,7 @@ func Login(env *Environment) http.HandlerFunc {
logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w)
return
}
jwt, err := generateJWT(userRequest.Username, env.JWTSecret, userAccount.Permissions)
jwt, err := generateJWT(userAccount.ID, userAccount.Username, env.JWTSecret, userAccount.Permissions)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
Expand All @@ -489,6 +510,7 @@ func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) {
log.Println(errMsg)
w.WriteHeader(status)
if _, err := w.Write([]byte(errMsg)); err != nil {
// TODO: how did this get merged?
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
Expand Down Expand Up @@ -554,8 +576,9 @@ func validatePassword(password string) bool {
}

// Helper function to generate a JWT
func generateJWT(username string, jwtSecret []byte, permissions int) (string, error) {
func generateJWT(id int, username string, jwtSecret []byte, permissions int) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtGocertClaims{
ID: id,
Username: username,
Permissions: permissions,
StandardClaims: jwt.StandardClaims{
Expand All @@ -571,6 +594,7 @@ func generateJWT(username string, jwtSecret []byte, permissions int) (string, er
}

type jwtGocertClaims struct {
ID int `json:"id"`
Username string `json:"username"`
Permissions int `json:"permissions"`
jwt.StandardClaims
Expand Down
18 changes: 18 additions & 0 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,24 @@ func TestAuthorization(t *testing.T) {
response: "",
status: http.StatusForbidden,
},
{
desc: "user can change self password with /me",
method: "POST",
path: "/api/v1/accounts/me/change_password",
data: `{"password":"BetterPW1!"}`,
auth: nonAdminToken,
response: "",
status: http.StatusOK,
},
{
desc: "user can login with new password",
method: "POST",
path: "/login",
data: `{"username":"testuser","password":"BetterPW1!"}`,
auth: nonAdminToken,
response: "",
status: http.StatusOK,
},
{
desc: "admin can't delete itself",
method: "DELETE",
Expand Down
83 changes: 59 additions & 24 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,16 @@ func loggingMiddleware(ctx *middlewareContext) middleware {
// authMiddleware intercepts requests that need authorization to check if the user's token exists and is
// permitted to use the endpoint
func authMiddleware(ctx *middlewareContext) middleware {
AdminOnlyPaths := []struct{ method, path string }{
{"POST", `accounts`},
{"GET", `accounts`},
{"GET", `accounts\/\d+$`},
{"DELETE", `accounts\/\d+$`},
{"POST", `accounts\/\d+\/change_password$`},
RestrictedPaths := []struct {
method, pathRegex string
SelfAuthorizedAllowed bool
}{
{"POST", `accounts$`, false},
{"GET", `accounts$`, false},
{"DELETE", `accounts\/(\d+)$`, false},
{"GET", `accounts\/(\d+)$`, true},
{"POST", `accounts\/(\d+)\/change_password$`, true},
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/api/v1/") {
Expand All @@ -115,35 +117,29 @@ func authMiddleware(ctx *middlewareContext) middleware {
}
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
logErrorAndWriteResponse("authorization header not found", http.StatusUnauthorized, w)
return
}
bearerToken := strings.Split(authHeader, " ")
if len(bearerToken) != 2 || bearerToken[0] != "Bearer" {
logErrorAndWriteResponse("authorization header couldn't be processed. The expected format is 'Bearer <token>'", http.StatusUnauthorized, w)
return
}
claims, err := getClaimsFromJWT(bearerToken[1], ctx.jwtSecret)
claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), ctx.jwtSecret)
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("token is not valid: %s", err.Error()), http.StatusUnauthorized, w)
logErrorAndWriteResponse(fmt.Sprintf("auth failed: %s", err.Error()), http.StatusUnauthorized, w)
return
}
if claims.Permissions == 0 {
for _, v := range AdminOnlyPaths {
matched, err := regexp.Match(v.path, []byte(r.URL.Path))
for _, v := range RestrictedPaths {
pathMatched, idMatchedIfExistsInPath, err := permitter(claims, v.pathRegex, r.URL.Path, v.SelfAuthorizedAllowed)
if err != nil {
logErrorAndWriteResponse(fmt.Sprintf("ran into issue parsing path: %s", err.Error()), http.StatusInternalServerError, w)
logErrorAndWriteResponse(fmt.Sprintf("error processing path: %s", err.Error()), http.StatusInternalServerError, w)
return
}
if r.Method == v.method && matched {
if pathMatched && !v.SelfAuthorizedAllowed {
logErrorAndWriteResponse("forbidden", http.StatusForbidden, w)
return
}
if pathMatched && v.SelfAuthorizedAllowed && !idMatchedIfExistsInPath {
logErrorAndWriteResponse("forbidden", http.StatusForbidden, w)
return
}
}
}
if claims.Permissions == 1 && r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") {
if r.Method == "DELETE" && strings.HasSuffix(r.URL.Path, "accounts/1") {
logErrorAndWriteResponse("can't delete admin account", http.StatusConflict, w)
return
}
Expand All @@ -152,6 +148,45 @@ func authMiddleware(ctx *middlewareContext) middleware {
}
}

func getClaimsFromAuthorizationHeader(header string, jwtSecret []byte) (*jwtGocertClaims, error) {
if header == "" {
return nil, fmt.Errorf("authorization header not found")
}
bearerToken := strings.Split(header, " ")
if len(bearerToken) != 2 || bearerToken[0] != "Bearer" {
return nil, fmt.Errorf("authorization header couldn't be processed. The expected format is 'Bearer <token>'")
}
claims, err := getClaimsFromJWT(bearerToken[1], jwtSecret)
if err != nil {
return nil, fmt.Errorf("token is not valid: %s", err)
}
return claims, nil
}

func permitter(claims *jwtGocertClaims, regex, path string, SelfAuthorizedAllowed bool) (pathMatched, idMatchedIfExistsInPath bool, err error) {
regexChallenge, err := regexp.Compile(regex)
if err != nil {
return pathMatched, idMatchedIfExistsInPath, fmt.Errorf("regex couldn't compile: %s", err)
}
matches := regexChallenge.FindStringSubmatch(path)
if len(matches) > 0 {
pathMatched = true
}
if len(matches) == 1 {
idMatchedIfExistsInPath = true
}
if len(matches) > 1 && SelfAuthorizedAllowed {
matchedID, err := strconv.Atoi(matches[1])
if err != nil {
return pathMatched, idMatchedIfExistsInPath, fmt.Errorf("error converting url id to string: %s", err)
}
if matchedID == claims.ID {
idMatchedIfExistsInPath = true
}
}
return pathMatched, idMatchedIfExistsInPath, nil
}

func getClaimsFromJWT(bearerToken string, jwtSecret []byte) (*jwtGocertClaims, error) {
claims := jwtGocertClaims{}
token, err := jwt.ParseWithClaims(bearerToken, &claims, func(token *jwt.Token) (interface{}, error) {
Expand Down

0 comments on commit e9a8e5b

Please sign in to comment.