From 044ea2a78300e6f9bcb073bcd21b6d34b47d7fff Mon Sep 17 00:00:00 2001 From: garrettladley Date: Sun, 4 Feb 2024 18:35:45 -0500 Subject: [PATCH] tests passing | auth adjustments --- backend/src/controllers/auth.go | 11 +- backend/src/controllers/user.go | 6 +- backend/src/errors/auth.go | 18 +++ backend/src/errors/common.go | 4 - backend/src/middleware/auth.go | 25 ++++- backend/src/middleware/club.go | 2 +- backend/src/middleware/user.go | 4 +- backend/src/models/user.go | 1 - backend/src/server/routes/auth.go | 20 ++++ backend/src/server/routes/category.go | 21 ++++ backend/src/server/routes/category_tag.go | 16 +++ backend/src/server/routes/club.go | 26 +++++ backend/src/server/routes/tag.go | 18 +++ backend/src/server/routes/user.go | 28 +++++ backend/src/server/routes/user_tag.go | 16 +++ backend/src/server/routes/utility.go | 13 +++ backend/src/server/server.go | 127 +++------------------- backend/src/services/user.go | 8 +- backend/src/types/custom_claims.go | 20 +++- backend/tests/api/club_test.go | 18 +-- backend/tests/api/user_tag_test.go | 6 +- backend/tests/api/user_test.go | 21 +--- 22 files changed, 256 insertions(+), 173 deletions(-) create mode 100644 backend/src/errors/auth.go create mode 100644 backend/src/server/routes/auth.go create mode 100644 backend/src/server/routes/category.go create mode 100644 backend/src/server/routes/category_tag.go create mode 100644 backend/src/server/routes/club.go create mode 100644 backend/src/server/routes/tag.go create mode 100644 backend/src/server/routes/user.go create mode 100644 backend/src/server/routes/user_tag.go create mode 100644 backend/src/server/routes/utility.go diff --git a/backend/src/controllers/auth.go b/backend/src/controllers/auth.go index b6af16f29..f666419e1 100644 --- a/backend/src/controllers/auth.go +++ b/backend/src/controllers/auth.go @@ -8,6 +8,7 @@ import ( "github.com/GenerateNU/sac/backend/src/errors" "github.com/GenerateNU/sac/backend/src/models" "github.com/GenerateNU/sac/backend/src/services" + "github.com/GenerateNU/sac/backend/src/types" "github.com/GenerateNU/sac/backend/src/utilities" "github.com/gofiber/fiber/v2" ) @@ -33,15 +34,7 @@ func NewAuthController(authService services.AuthServiceInterface, authSettings c // @Failure 401 {string} string "failed to get current user" // @Router /api/v1/auth/me [get] func (a *AuthController) Me(c *fiber.Ctx) error { - // Extract token values from cookies - accessTokenValue := c.Cookies("access_token") - - claims, err := auth.ExtractAccessClaims(accessTokenValue, a.AuthSettings.AccessToken) - if err != nil { - return err.FiberError(c) - } - - user, err := a.authService.Me(claims.Issuer) + user, err := a.authService.Me(types.From(c).Issuer) if err != nil { return err.FiberError(c) } diff --git a/backend/src/controllers/user.go b/backend/src/controllers/user.go index 1b509bf3d..60746d3dd 100644 --- a/backend/src/controllers/user.go +++ b/backend/src/controllers/user.go @@ -81,7 +81,7 @@ func (u *UserController) GetUsers(c *fiber.Ctx) error { // @Failure 500 {string} string "failed to get user" // @Router /api/v1/users/:id [get] func (u *UserController) GetUser(c *fiber.Ctx) error { - user, err := u.userService.GetUser(c.Params("id")) + user, err := u.userService.GetUser(c.Params("userID")) if err != nil { return err.FiberError(c) } @@ -110,7 +110,7 @@ func (u *UserController) UpdateUser(c *fiber.Ctx) error { return errors.FailedToParseRequestBody.FiberError(c) } - updatedUser, err := u.userService.UpdateUser(c.Params("id"), user) + updatedUser, err := u.userService.UpdateUser(c.Params("userID"), user) if err != nil { return err.FiberError(c) } @@ -130,7 +130,7 @@ func (u *UserController) UpdateUser(c *fiber.Ctx) error { // @Failure 500 {string} string "failed to get all users" // @Router /api/v1/users/:id [delete] func (u *UserController) DeleteUser(c *fiber.Ctx) error { - err := u.userService.DeleteUser(c.Params("id")) + err := u.userService.DeleteUser(c.Params("userID")) if err != nil { return err.FiberError(c) } diff --git a/backend/src/errors/auth.go b/backend/src/errors/auth.go new file mode 100644 index 000000000..37d9b1a8b --- /dev/null +++ b/backend/src/errors/auth.go @@ -0,0 +1,18 @@ +package errors + +import "github.com/gofiber/fiber/v2" + +var ( + PassedAuthenticateMiddlewareButNilClaims = Error{ + StatusCode: fiber.StatusInternalServerError, + Message: "passed authenticate middleware but claims is nil", + } + FailedToCastToCustomClaims = Error{ + StatusCode: fiber.StatusInternalServerError, + Message: "failed to cast to custom claims", + } + ExpectedClaimsButGotNil = Error{ + StatusCode: fiber.StatusInternalServerError, + Message: "expected claims but got nil", + } +) diff --git a/backend/src/errors/common.go b/backend/src/errors/common.go index a42f9b4fd..343594d5b 100644 --- a/backend/src/errors/common.go +++ b/backend/src/errors/common.go @@ -63,8 +63,4 @@ var ( StatusCode: fiber.StatusUnauthorized, Message: "failed to validate access token", } - FailedToParseUUID = Error{ - StatusCode: fiber.StatusBadRequest, - Message: "failed to parse uuid", - } ) diff --git a/backend/src/middleware/auth.go b/backend/src/middleware/auth.go index 68bf5bf7b..fb05f6179 100644 --- a/backend/src/middleware/auth.go +++ b/backend/src/middleware/auth.go @@ -9,6 +9,7 @@ import ( "github.com/GenerateNU/sac/backend/src/types" "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/skip" ) var paths = []string{ @@ -18,6 +19,17 @@ var paths = []string{ "/api/v1/auth/logout", } +func SuperSkipper(h fiber.Handler) fiber.Handler { + return skip.New(h, func(c *fiber.Ctx) bool { + claims, err := types.From(c) + if err != nil { + err.FiberError(c) + return false + } + return claims.Role == string(models.Super) + }) +} + func (m *MiddlewareService) Authenticate(c *fiber.Ctx) error { if slices.Contains(paths, c.Path()) { return c.Next() @@ -28,7 +40,7 @@ func (m *MiddlewareService) Authenticate(c *fiber.Ctx) error { return errors.FailedToParseAccessToken.FiberError(c) } - _, ok := token.Claims.(*types.CustomClaims) + claims, ok := token.Claims.(*types.CustomClaims) if !ok || !token.Valid { return errors.FailedToValidateAccessToken.FiberError(c) } @@ -37,11 +49,22 @@ func (m *MiddlewareService) Authenticate(c *fiber.Ctx) error { return errors.Unauthorized.FiberError(c) } + c.Locals("claims", claims) + return c.Next() } func (m *MiddlewareService) Authorize(requiredPermissions ...types.Permission) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { + claims, fromErr := types.From(c) + if fromErr != nil { + return fromErr.FiberError(c) + } + + if claims != nil && claims.Role == string(models.Super) { + return c.Next() + } + role, err := auth.GetRoleFromToken(c.Cookies("access_token"), m.AuthSettings.AccessToken) if err != nil { return errors.FailedToParseAccessToken.FiberError(c) diff --git a/backend/src/middleware/club.go b/backend/src/middleware/club.go index 725b0f793..7d99b76b2 100644 --- a/backend/src/middleware/club.go +++ b/backend/src/middleware/club.go @@ -14,7 +14,7 @@ import ( func (m *MiddlewareService) ClubAuthorizeById(c *fiber.Ctx) error { clubUUID, err := utilities.ValidateID(c.Params("id")) if err != nil { - return errors.FailedToParseUUID.FiberError(c) + return errors.FailedToValidateID.FiberError(c) } token, tokenErr := auth.ParseAccessToken(c.Cookies("access_token"), m.AuthSettings.AccessToken) diff --git a/backend/src/middleware/user.go b/backend/src/middleware/user.go index 12514ff71..d431907cd 100644 --- a/backend/src/middleware/user.go +++ b/backend/src/middleware/user.go @@ -11,7 +11,7 @@ import ( func (m *MiddlewareService) UserAuthorizeById(c *fiber.Ctx) error { idAsUUID, err := utilities.ValidateID(c.Params("id")) if err != nil { - return errors.FailedToParseUUID.FiberError(c) + return errors.FailedToValidateID.FiberError(c) } token, tokenErr := auth.ParseAccessToken(c.Cookies("access_token"), m.AuthSettings.AccessToken) @@ -26,7 +26,7 @@ func (m *MiddlewareService) UserAuthorizeById(c *fiber.Ctx) error { issuerIDAsUUID, err := utilities.ValidateID(claims.Issuer) if err != nil { - return errors.FailedToParseUUID.FiberError(c) + return errors.FailedToValidateID.FiberError(c) } if issuerIDAsUUID.String() == idAsUUID.String() { diff --git a/backend/src/models/user.go b/backend/src/models/user.go index ad27132c2..fefd911b6 100644 --- a/backend/src/models/user.go +++ b/backend/src/models/user.go @@ -72,7 +72,6 @@ type UpdateUserRequestBody struct { FirstName string `json:"first_name" validate:"omitempty,max=255"` LastName string `json:"last_name" validate:"omitempty,max=255"` Email string `json:"email" validate:"omitempty,email,neu_email,max=255"` - Password string `json:"password" validate:"omitempty,password"` College College `json:"college" validate:"omitempty,oneof=CAMD DMSB KCCS CE BCHS SL CPS CS CSSH"` Year Year `json:"year" validate:"omitempty,min=1,max=6"` } diff --git a/backend/src/server/routes/auth.go b/backend/src/server/routes/auth.go new file mode 100644 index 000000000..7efe8d127 --- /dev/null +++ b/backend/src/server/routes/auth.go @@ -0,0 +1,20 @@ +package routes + +import ( + "github.com/GenerateNU/sac/backend/src/config" + "github.com/GenerateNU/sac/backend/src/controllers" + "github.com/GenerateNU/sac/backend/src/services" + "github.com/gofiber/fiber/v2" +) + +func Auth(router fiber.Router, authService services.AuthServiceInterface, authSettings config.AuthSettings) { + authController := controllers.NewAuthController(authService, authSettings) + + // api/v1/auth/* + auth := router.Group("/auth") + + auth.Post("/login", authController.Login) + auth.Get("/logout", authController.Logout) + auth.Get("/refresh", authController.Refresh) + auth.Get("/me", authController.Me) +} diff --git a/backend/src/server/routes/category.go b/backend/src/server/routes/category.go new file mode 100644 index 000000000..dccdfbb57 --- /dev/null +++ b/backend/src/server/routes/category.go @@ -0,0 +1,21 @@ +package routes + +import ( + "github.com/GenerateNU/sac/backend/src/controllers" + "github.com/GenerateNU/sac/backend/src/services" + "github.com/gofiber/fiber/v2" +) + +func Category(router fiber.Router, categoryService services.CategoryServiceInterface) fiber.Router { + categoryController := controllers.NewCategoryController(categoryService) + + categories := router.Group("/categories") + + categories.Post("/", categoryController.CreateCategory) + categories.Get("/", categoryController.GetCategories) + categories.Get("/:id", categoryController.GetCategory) + categories.Delete("/:id", categoryController.DeleteCategory) + categories.Patch("/:id", categoryController.UpdateCategory) + + return categories +} diff --git a/backend/src/server/routes/category_tag.go b/backend/src/server/routes/category_tag.go new file mode 100644 index 000000000..7720e7bb5 --- /dev/null +++ b/backend/src/server/routes/category_tag.go @@ -0,0 +1,16 @@ +package routes + +import ( + "github.com/GenerateNU/sac/backend/src/controllers" + "github.com/GenerateNU/sac/backend/src/services" + "github.com/gofiber/fiber/v2" +) + +func CategoryTag(router fiber.Router, categoryTagService services.CategoryTagServiceInterface) { + categoryTagController := controllers.NewCategoryTagController(categoryTagService) + + categoryTags := router.Group("/:categoryID/tags") + + categoryTags.Get("/", categoryTagController.GetTagsByCategory) + categoryTags.Get("/:tagID", categoryTagController.GetTagByCategory) +} diff --git a/backend/src/server/routes/club.go b/backend/src/server/routes/club.go new file mode 100644 index 000000000..89b42191f --- /dev/null +++ b/backend/src/server/routes/club.go @@ -0,0 +1,26 @@ +package routes + +import ( + "github.com/GenerateNU/sac/backend/src/controllers" + "github.com/GenerateNU/sac/backend/src/middleware" + "github.com/GenerateNU/sac/backend/src/services" + "github.com/GenerateNU/sac/backend/src/types" + "github.com/gofiber/fiber/v2" +) + +func Club(router fiber.Router, clubService services.ClubServiceInterface, middlewareService middleware.MiddlewareInterface) { + clubController := controllers.NewClubController(clubService) + + clubs := router.Group("/clubs") + + clubs.Get("/", middlewareService.Authorize(types.ClubReadAll), clubController.GetAllClubs) + clubs.Post("/", clubController.CreateClub) + + // api/v1/clubs/:id/* + clubsID := clubs.Group("/:id") + clubsID.Use(middleware.SuperSkipper(middlewareService.UserAuthorizeById)) + + clubsID.Get("/", clubController.GetClub) + clubsID.Patch("/", middlewareService.Authorize(types.ClubWrite), clubController.UpdateClub) + clubsID.Delete("/", middleware.SuperSkipper(middlewareService.Authorize(types.ClubDelete)), clubController.DeleteClub) +} diff --git a/backend/src/server/routes/tag.go b/backend/src/server/routes/tag.go new file mode 100644 index 000000000..6bd9bf8b8 --- /dev/null +++ b/backend/src/server/routes/tag.go @@ -0,0 +1,18 @@ +package routes + +import ( + "github.com/GenerateNU/sac/backend/src/controllers" + "github.com/GenerateNU/sac/backend/src/services" + "github.com/gofiber/fiber/v2" +) + +func Tag(router fiber.Router, tagService services.TagServiceInterface) { + tagController := controllers.NewTagController(tagService) + + tags := router.Group("/tags") + + tags.Get("/:tagID", tagController.GetTag) + tags.Post("/", tagController.CreateTag) + tags.Patch("/:tagID", tagController.UpdateTag) + tags.Delete("/:tagID", tagController.DeleteTag) +} diff --git a/backend/src/server/routes/user.go b/backend/src/server/routes/user.go new file mode 100644 index 000000000..2926e875c --- /dev/null +++ b/backend/src/server/routes/user.go @@ -0,0 +1,28 @@ +package routes + +import ( + "github.com/GenerateNU/sac/backend/src/controllers" + "github.com/GenerateNU/sac/backend/src/middleware" + "github.com/GenerateNU/sac/backend/src/services" + "github.com/GenerateNU/sac/backend/src/types" + "github.com/gofiber/fiber/v2" +) + +func User(router fiber.Router, userService services.UserServiceInterface, middlewareService middleware.MiddlewareInterface) fiber.Router { + userController := controllers.NewUserController(userService) + + // api/v1/users/* + users := router.Group("/users") + users.Post("/", userController.CreateUser) + users.Get("/", middleware.SuperSkipper(middlewareService.Authorize(types.UserReadAll)), userController.GetUsers) + + // api/v1/users/:userID/* + usersID := users.Group("/:userID") + usersID.Use(middleware.SuperSkipper(middlewareService.UserAuthorizeById)) + + usersID.Get("/", userController.GetUser) + usersID.Patch("/", userController.UpdateUser) + usersID.Delete("/", userController.DeleteUser) + + return users +} diff --git a/backend/src/server/routes/user_tag.go b/backend/src/server/routes/user_tag.go new file mode 100644 index 000000000..1b777ff47 --- /dev/null +++ b/backend/src/server/routes/user_tag.go @@ -0,0 +1,16 @@ +package routes + +import ( + "github.com/GenerateNU/sac/backend/src/controllers" + "github.com/GenerateNU/sac/backend/src/services" + "github.com/gofiber/fiber/v2" +) + +func UserTag(router fiber.Router, userTagService services.UserTagServiceInterface) { + userTagController := controllers.NewUserTagController(userTagService) + + userTags := router.Group("/:userID/tags") + + userTags.Post("/", userTagController.CreateUserTags) + userTags.Get("/", userTagController.GetUserTags) +} diff --git a/backend/src/server/routes/utility.go b/backend/src/server/routes/utility.go new file mode 100644 index 000000000..5730a021d --- /dev/null +++ b/backend/src/server/routes/utility.go @@ -0,0 +1,13 @@ +package routes + +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/swagger" +) + +func Utility(router fiber.Router) { + router.Get("/swagger/*", swagger.HandlerDefault) + router.Get("/health", func(c *fiber.Ctx) error { + return c.SendStatus(200) + }) +} diff --git a/backend/src/server/server.go b/backend/src/server/server.go index 268624f08..f65b70c23 100644 --- a/backend/src/server/server.go +++ b/backend/src/server/server.go @@ -2,10 +2,9 @@ package server import ( "github.com/GenerateNU/sac/backend/src/config" - "github.com/GenerateNU/sac/backend/src/controllers" "github.com/GenerateNU/sac/backend/src/middleware" + "github.com/GenerateNU/sac/backend/src/server/routes" "github.com/GenerateNU/sac/backend/src/services" - "github.com/GenerateNU/sac/backend/src/types" "github.com/GenerateNU/sac/backend/src/utilities" "github.com/goccy/go-json" @@ -13,7 +12,6 @@ import ( "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/requestid" - "github.com/gofiber/swagger" "gorm.io/gorm" ) @@ -34,14 +32,19 @@ func Init(db *gorm.DB, settings config.Settings) *fiber.App { apiv1 := app.Group("/api/v1") apiv1.Use(middlewareService.Authenticate) - utilityRoutes(app) - authRoutes(apiv1, services.NewAuthService(db, validate), settings.Auth) - userRouter := userRoutes(apiv1, services.NewUserService(db, validate), middlewareService) - userTagRouter(userRouter, services.NewUserTagService(db, validate)) - clubRoutes(apiv1, services.NewClubService(db, validate), middlewareService) - categoryRouter := categoryRoutes(apiv1, services.NewCategoryService(db, validate)) - tagRoutes(apiv1, services.NewTagService(db, validate)) - categoryTagRoutes(categoryRouter, services.NewCategoryTagService(db, validate)) + routes.Utility(app) + + routes.Auth(apiv1, services.NewAuthService(db, validate), settings.Auth) + + userRouter := routes.User(apiv1, services.NewUserService(db, validate), middlewareService) + routes.UserTag(userRouter, services.NewUserTagService(db, validate)) + + routes.Club(apiv1, services.NewClubService(db, validate), middlewareService) + + routes.Tag(apiv1, services.NewTagService(db, validate)) + + categoryRouter := routes.Category(apiv1, services.NewCategoryService(db, validate)) + routes.CategoryTag(categoryRouter, services.NewCategoryTagService(db, validate)) return app } @@ -63,105 +66,3 @@ func newFiberApp() *fiber.App { return app } - -func utilityRoutes(router fiber.Router) { - router.Get("/swagger/*", swagger.HandlerDefault) - router.Get("/health", func(c *fiber.Ctx) error { - return c.SendStatus(200) - }) -} - -func userRoutes(router fiber.Router, userService services.UserServiceInterface, middlewareService middleware.MiddlewareInterface) fiber.Router { - userController := controllers.NewUserController(userService) - - // api/v1/users/* - users := router.Group("/users") - users.Post("/", userController.CreateUser) - users.Get("/", middlewareService.Authorize(types.UserReadAll), userController.GetUsers) - - // api/v1/users/:id/* - usersID := users.Group("/:id") - usersID.Use(middlewareService.UserAuthorizeById) - - usersID.Get("/", middlewareService.Authorize(types.UserRead), userController.GetUser) - usersID.Patch("/", middlewareService.Authorize(types.UserWrite), userController.UpdateUser) - usersID.Delete("/", middlewareService.Authorize(types.UserDelete), userController.DeleteUser) - - users.Get("/", userController.GetUsers) - users.Get("/:id", userController.GetUser) - users.Patch("/:id", userController.UpdateUser) - users.Delete("/:id", userController.DeleteUser) - - return users -} - -func userTagRouter(router fiber.Router, userTagService services.UserTagServiceInterface) { - userTagController := controllers.NewUserTagController(userTagService) - - userTags := router.Group("/:userID/tags") - - userTags.Post("/", userTagController.CreateUserTags) - userTags.Get("/", userTagController.GetUserTags) -} - -func clubRoutes(router fiber.Router, clubService services.ClubServiceInterface, middlewareService middleware.MiddlewareInterface) { - clubController := controllers.NewClubController(clubService) - - clubs := router.Group("/clubs") - - clubs.Get("/", middlewareService.Authorize(types.ClubReadAll), clubController.GetAllClubs) - clubs.Post("/", clubController.CreateClub) - - // api/v1/clubs/:id/* - clubsID := clubs.Group("/:id") - clubsID.Use(middlewareService.ClubAuthorizeById) - - clubsID.Get("/", clubController.GetClub) - clubsID.Patch("/", middlewareService.Authorize(types.ClubWrite), clubController.UpdateClub) - clubsID.Delete("/", middlewareService.Authorize(types.ClubDelete), clubController.DeleteClub) -} - -func authRoutes(router fiber.Router, authService services.AuthServiceInterface, authSettings config.AuthSettings) { - authController := controllers.NewAuthController(authService, authSettings) - - // api/v1/auth/* - auth := router.Group("/auth") - auth.Post("/login", authController.Login) - auth.Get("/logout", authController.Logout) - auth.Get("/refresh", authController.Refresh) - auth.Get("/me", authController.Me) -} - -func categoryRoutes(router fiber.Router, categoryService services.CategoryServiceInterface) fiber.Router { - categoryController := controllers.NewCategoryController(categoryService) - - categories := router.Group("/categories") - - categories.Post("/", categoryController.CreateCategory) - categories.Get("/", categoryController.GetCategories) - categories.Get("/:id", categoryController.GetCategory) - categories.Delete("/:id", categoryController.DeleteCategory) - categories.Patch("/:id", categoryController.UpdateCategory) - - return categories -} - -func tagRoutes(router fiber.Router, tagService services.TagServiceInterface) { - tagController := controllers.NewTagController(tagService) - - tags := router.Group("/tags") - - tags.Get("/:tagID", tagController.GetTag) - tags.Post("/", tagController.CreateTag) - tags.Patch("/:tagID", tagController.UpdateTag) - tags.Delete("/:tagID", tagController.DeleteTag) -} - -func categoryTagRoutes(router fiber.Router, categoryTagService services.CategoryTagServiceInterface) { - categoryTagController := controllers.NewCategoryTagController(categoryTagService) - - categoryTags := router.Group("/:categoryID/tags") - - categoryTags.Get("/", categoryTagController.GetTagsByCategory) - categoryTags.Get("/:tagID", categoryTagController.GetTagByCategory) -} diff --git a/backend/src/services/user.go b/backend/src/services/user.go index 646976e09..82f4c431b 100644 --- a/backend/src/services/user.go +++ b/backend/src/services/user.go @@ -1,6 +1,7 @@ package services import ( + "fmt" "strings" "github.com/GenerateNU/sac/backend/src/auth" @@ -86,17 +87,12 @@ func (u *UserService) UpdateUser(id string, userBody models.UpdateUserRequestBod return nil, &errors.FailedToValidateUser } - passwordHash, err := auth.ComputePasswordHash(userBody.Password) - if err != nil { - return nil, &errors.FailedToComputePasswordHash - } - user, err := utilities.MapRequestToModel(userBody, &models.User{}) if err != nil { return nil, &errors.FailedToMapRequestToModel } - user.PasswordHash = *passwordHash + fmt.Println(user) return transactions.UpdateUser(u.DB, *idAsUUID, *user) } diff --git a/backend/src/types/custom_claims.go b/backend/src/types/custom_claims.go index b53da552e..b475d96d7 100644 --- a/backend/src/types/custom_claims.go +++ b/backend/src/types/custom_claims.go @@ -1,8 +1,26 @@ package types -import "github.com/golang-jwt/jwt" +import ( + "github.com/GenerateNU/sac/backend/src/errors" + "github.com/gofiber/fiber/v2" + "github.com/golang-jwt/jwt" +) type CustomClaims struct { jwt.StandardClaims Role string `json:"role"` } + +func From(c *fiber.Ctx) (*CustomClaims, *errors.Error) { + rawClaims := c.Locals("claims") + if rawClaims == nil { + return nil, &errors.ExpectedClaimsButGotNil + } + + claims, ok := rawClaims.(*CustomClaims) + if !ok { + return nil, &errors.FailedToCastToCustomClaims + } + + return claims, nil +} diff --git a/backend/tests/api/club_test.go b/backend/tests/api/club_test.go index 778b28fb8..bdce0892e 100644 --- a/backend/tests/api/club_test.go +++ b/backend/tests/api/club_test.go @@ -70,7 +70,6 @@ func AssertClubBodyRespDB(app h.TestApp, assert *assert.A, resp *http.Response, assert.Equal((*body)["name"].(string), dbClub.Name) assert.Equal((*body)["preview"].(string), dbClub.Preview) assert.Equal((*body)["description"].(string), dbClub.Description) - assert.Equal((*body)["num_members"].(int), dbClub.NumMembers) assert.Equal((*body)["is_recruiting"].(bool), dbClub.IsRecruiting) assert.Equal(models.RecruitmentCycle((*body)["recruitment_cycle"].(string)), dbClub.RecruitmentCycle) assert.Equal(models.RecruitmentType((*body)["recruitment_type"].(string)), dbClub.RecruitmentType) @@ -309,7 +308,6 @@ func TestCreateClubFailsOnInvalidLogo(t *testing.T) { ) } -// TODO: need to be able to join the club func TestUpdateClubWorks(t *testing.T) { appAssert, studentUUID, clubUUID := CreateSampleClub(h.InitTest(t)) @@ -333,7 +331,6 @@ func TestUpdateClubWorks(t *testing.T) { ).Close() } -// TODO: need to be able to join the club to try to update func TestUpdateClubFailsOnInvalidBody(t *testing.T) { appAssert, studentUUID, clubUUID := CreateSampleClub(h.InitTest(t)) @@ -374,11 +371,10 @@ func TestUpdateClubFailsOnInvalidBody(t *testing.T) { assert.Equal(1, len(dbAdmins)) - assert.Equal((*body)["user_id"].(uuid.UUID), dbAdmins[0].ID) + assert.Equal(*(*body)["user_id"].(*uuid.UUID), dbAdmins[0].ID) assert.Equal((*body)["name"].(string), dbClub.Name) assert.Equal((*body)["preview"].(string), dbClub.Preview) assert.Equal((*body)["description"].(string), dbClub.Description) - assert.Equal((*body)["num_members"].(int), dbClub.NumMembers) assert.Equal((*body)["is_recruiting"].(bool), dbClub.IsRecruiting) assert.Equal(models.RecruitmentCycle((*body)["recruitment_cycle"].(string)), dbClub.RecruitmentCycle) assert.Equal(models.RecruitmentType((*body)["recruitment_type"].(string)), dbClub.RecruitmentType) @@ -412,14 +408,13 @@ func TestUpdateClubFailsBadRequest(t *testing.T) { Body: h.SampleStudentJSONFactory(sampleStudent, rawPassword), Role: &models.Super, }, - errors.FailedToParseUUID, + errors.FailedToValidateID, ) } appAssert.Close() } -// TODO: should this be unauthorized or not found? func TestUpdateClubFailsOnClubIdNotExist(t *testing.T) { uuid := uuid.New() @@ -427,7 +422,7 @@ func TestUpdateClubFailsOnClubIdNotExist(t *testing.T) { Method: fiber.MethodPatch, Path: fmt.Sprintf("/api/v1/clubs/%s", uuid), Body: SampleClubFactory(nil), - Role: &models.Student, + Role: &models.Super, TestUserIDReplaces: h.StringToPointer("user_id"), }, h.ErrorWithTester{ @@ -443,7 +438,6 @@ func TestUpdateClubFailsOnClubIdNotExist(t *testing.T) { ).Close() } -// TODO: need to be able to join the club func TestDeleteClubWorks(t *testing.T) { appAssert, _, clubUUID := CreateSampleClub(h.InitTest(t)) @@ -460,7 +454,6 @@ func TestDeleteClubWorks(t *testing.T) { ).Close() } -// TODO: should this be unauthorized or not found? func TestDeleteClubNotExist(t *testing.T) { uuid := uuid.New() h.InitTest(t).TestOnErrorAndDB( @@ -470,7 +463,7 @@ func TestDeleteClubNotExist(t *testing.T) { Role: &models.Super, }, h.ErrorWithTester{ - Error: errors.Unauthorized, + Error: errors.ClubNotFound, Tester: func(app h.TestApp, assert *assert.A, resp *http.Response) { var club models.Club @@ -501,7 +494,8 @@ func TestDeleteClubBadRequest(t *testing.T) { Method: fiber.MethodDelete, Path: fmt.Sprintf("/api/v1/clubs/%s", badRequest), Role: &models.Super, - }, errors.FailedToParseUUID, + }, + errors.FailedToValidateID, ) } diff --git a/backend/tests/api/user_tag_test.go b/backend/tests/api/user_tag_test.go index e35319f08..013b5ddb9 100644 --- a/backend/tests/api/user_tag_test.go +++ b/backend/tests/api/user_tag_test.go @@ -192,7 +192,7 @@ func TestCreateUserTagsFailsOnInvalidUserID(t *testing.T) { Body: SampleTagIDsFactory(nil), Role: &models.Student, }, - errors.FailedToParseUUID, + errors.FailedToValidateID, ).Close() } } @@ -225,7 +225,6 @@ func TestCreateUserTagsFailsOnInvalidKey(t *testing.T) { } } -// TODO: should this be unauthorized or not found? func TestCreateUserTagsFailsOnNonExistentUser(t *testing.T) { uuid := uuid.New() @@ -242,7 +241,7 @@ func TestCreateUserTagsFailsOnNonExistentUser(t *testing.T) { var dbUser models.User err := app.Conn.First(&dbUser, uuid).Error - assert.Error(err) + assert.Assert(err != nil) }, }, ).Close() @@ -296,7 +295,6 @@ func TestCreateUserTagsNoneAddedIfInvalid(t *testing.T) { ).Close() } -// TODO: should this be unauthorized or not found? func TestGetUserTagsFailsOnNonExistentUser(t *testing.T) { h.InitTest(t).TestOnError( h.TestRequest{ diff --git a/backend/tests/api/user_test.go b/backend/tests/api/user_test.go index 57cd7db3e..5ba723114 100644 --- a/backend/tests/api/user_test.go +++ b/backend/tests/api/user_test.go @@ -128,14 +128,13 @@ func TestGetUserFailsBadRequest(t *testing.T) { Path: fmt.Sprintf("/api/v1/users/%s", badRequest), Role: &models.Super, }, - errors.FailedToParseUUID, + errors.FailedToValidateID, ) } appAssert.Close() } -// TODO: should this be not found or unauthorized? func TestGetUserFailsNotExist(t *testing.T) { uuid := uuid.New() @@ -158,7 +157,6 @@ func TestGetUserFailsNotExist(t *testing.T) { ).Close() } -// TODO: should this be unathorized or be allowed? func TestUpdateUserWorks(t *testing.T) { newFirstName := "Michael" newLastName := "Brennan" @@ -211,18 +209,14 @@ func TestUpdateUserWorks(t *testing.T) { ).Close() } -// TODO: should this be unauthorized or fail on processing request func TestUpdateUserFailsOnInvalidBody(t *testing.T) { - appAssert := h.InitTest(t) - for _, invalidData := range []map[string]interface{}{ {"email": "not.northeastern@gmail.com"}, {"nuid": "1800-123-4567"}, - {"password": "1234"}, {"year": 1963}, {"college": "UT-Austin"}, } { - appAssert.TestOnErrorAndDB( + h.InitTest(t).TestOnErrorAndDB( h.TestRequest{ Method: fiber.MethodPatch, Path: "/api/v1/users/:userID", @@ -234,10 +228,8 @@ func TestUpdateUserFailsOnInvalidBody(t *testing.T) { Error: errors.FailedToValidateUser, Tester: TestNumUsersRemainsAt2, }, - ) + ).Close() } - - appAssert.Close() } func TestUpdateUserFailsBadRequest(t *testing.T) { @@ -260,12 +252,11 @@ func TestUpdateUserFailsBadRequest(t *testing.T) { Body: slightlyDifferentSampleStudentJSON, Role: &models.Student, }, - errors.FailedToParseUUID, + errors.FailedToValidateID, ).Close() } } -// TODO: should this be unauthorized or not found? func TestUpdateUserFailsOnIdNotExist(t *testing.T) { uuid := uuid.New() @@ -291,7 +282,6 @@ func TestUpdateUserFailsOnIdNotExist(t *testing.T) { ).Close() } -// TODO: should this be unauthorized? func TestDeleteUserWorks(t *testing.T) { h.InitTest(t).TestOnStatusAndDB( h.TestRequest{ @@ -307,7 +297,6 @@ func TestDeleteUserWorks(t *testing.T) { ).Close() } -// TODO: how should this work now? func TestDeleteUserNotExist(t *testing.T) { uuid := uuid.New() @@ -350,7 +339,7 @@ func TestDeleteUserBadRequest(t *testing.T) { Role: &models.Super, }, h.ErrorWithTester{ - Error: errors.FailedToParseUUID, + Error: errors.FailedToValidateID, Tester: TestNumUsersRemainsAt1, }, )