From a9661bf5939fc2a0485deed5ae62c2f5f8cfe15e Mon Sep 17 00:00:00 2001 From: Garrett Ladley <92384606+garrettladley@users.noreply.github.com> Date: Fri, 24 May 2024 21:29:38 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20fiber=20locals=20reorg?= =?UTF-8?q?=20(#905)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/entities/oauth/base/controller.go | 9 ++++---- backend/entities/users/base/controller.go | 4 ++-- backend/locals/claims.go | 27 +++++++++++++++++++++++ backend/locals/type.go | 8 +++++++ backend/locals/user_id.go | 27 +++++++++++++++++++++++ backend/middleware/auth/auth.go | 9 ++++---- backend/middleware/auth/club.go | 4 ++-- backend/middleware/auth/event.go | 4 ++-- backend/middleware/auth/user.go | 4 ++-- 9 files changed, 80 insertions(+), 16 deletions(-) create mode 100644 backend/locals/claims.go create mode 100644 backend/locals/type.go create mode 100644 backend/locals/user_id.go diff --git a/backend/entities/oauth/base/controller.go b/backend/entities/oauth/base/controller.go index 3a3e3d0c..2259f2f4 100644 --- a/backend/entities/oauth/base/controller.go +++ b/backend/entities/oauth/base/controller.go @@ -4,8 +4,9 @@ import ( "errors" "net/http" - "github.com/GenerateNU/sac/backend/auth" "github.com/GenerateNU/sac/backend/entities/models" + "github.com/GenerateNU/sac/backend/locals" + "github.com/gofiber/fiber/v2" ) @@ -25,7 +26,7 @@ func (oc *OAuthController) Authorize(c *fiber.Ctx) error { } // Extract the user making the call: - userID, err := auth.UserIDFrom(c) + userID, err := locals.UserID(c) if err != nil { return err } @@ -53,7 +54,7 @@ func (oc *OAuthController) Token(c *fiber.Ctx) error { } // Extract the user making the call: - userID, err := auth.UserIDFrom(c) + userID, err := locals.UserID(c) if err != nil { return err } @@ -75,7 +76,7 @@ func (oc *OAuthController) Revoke(c *fiber.Ctx) error { } // Extract the user making the call: - userID, err := auth.UserIDFrom(c) + userID, err := locals.UserID(c) if err != nil { return err } diff --git a/backend/entities/users/base/controller.go b/backend/entities/users/base/controller.go index 47ca36ac..91c596ac 100644 --- a/backend/entities/users/base/controller.go +++ b/backend/entities/users/base/controller.go @@ -3,8 +3,8 @@ package base import ( "net/http" - "github.com/GenerateNU/sac/backend/auth" authEntities "github.com/GenerateNU/sac/backend/entities/auth" + "github.com/GenerateNU/sac/backend/locals" "github.com/GenerateNU/sac/backend/utilities" "github.com/garrettladley/fiberpaginate" @@ -63,7 +63,7 @@ func (u *UserController) GetUsers(c *fiber.Ctx) error { // @Failure 500 {object} error // @Router /auth/me [get] func (u *UserController) GetMe(c *fiber.Ctx) error { - userID, err := auth.UserIDFrom(c) + userID, err := locals.UserID(c) if err != nil { return err } diff --git a/backend/locals/claims.go b/backend/locals/claims.go new file mode 100644 index 00000000..67155ef5 --- /dev/null +++ b/backend/locals/claims.go @@ -0,0 +1,27 @@ +package locals + +import ( + "fmt" + + "github.com/GenerateNU/sac/backend/auth" + "github.com/GenerateNU/sac/backend/utilities" + "github.com/gofiber/fiber/v2" +) + +func CustomClaims(c *fiber.Ctx) (*auth.CustomClaims, error) { + rawClaims := c.Locals(claimsKey) + if rawClaims == nil { + return nil, utilities.Forbidden() + } + + claims, ok := rawClaims.(*auth.CustomClaims) + if !ok { + return nil, fmt.Errorf("claims are not of type auth.CustomClaims. got: %T", rawClaims) + } + + return claims, nil +} + +func SetCustomClaims(c *fiber.Ctx, claims *auth.CustomClaims) { + c.Locals(claimsKey, claims) +} diff --git a/backend/locals/type.go b/backend/locals/type.go new file mode 100644 index 00000000..3c7c8036 --- /dev/null +++ b/backend/locals/type.go @@ -0,0 +1,8 @@ +package locals + +type localsKey byte + +const ( + claimsKey localsKey = 0 + userIDKey localsKey = 1 +) diff --git a/backend/locals/user_id.go b/backend/locals/user_id.go new file mode 100644 index 00000000..f2962438 --- /dev/null +++ b/backend/locals/user_id.go @@ -0,0 +1,27 @@ +package locals + +import ( + "fmt" + + "github.com/GenerateNU/sac/backend/utilities" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" +) + +func UserID(c *fiber.Ctx) (*uuid.UUID, error) { + userID := c.Locals(userIDKey) + if userID == nil { + return nil, utilities.Forbidden() + } + + id, ok := userID.(*uuid.UUID) + if !ok { + return nil, fmt.Errorf("userID is not of type uuid.UUID. got: %T", userID) + } + + return id, nil +} + +func SetUserID(c *fiber.Ctx, id *uuid.UUID) { + c.Locals(userIDKey, id) +} diff --git a/backend/middleware/auth/auth.go b/backend/middleware/auth/auth.go index e5a51fdb..eb950b3b 100644 --- a/backend/middleware/auth/auth.go +++ b/backend/middleware/auth/auth.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/GenerateNU/sac/backend/auth" + "github.com/GenerateNU/sac/backend/locals" "github.com/GenerateNU/sac/backend/utilities" "github.com/google/uuid" @@ -17,7 +18,7 @@ import ( ) func (m *AuthMiddlewareService) IsSuper(c *fiber.Ctx) bool { - claims, err := auth.CustomClaimsFrom(c) + claims, err := locals.CustomClaims(c) if err != nil { _ = err return false @@ -67,7 +68,7 @@ func (m *AuthMiddlewareService) Authenticate(c *fiber.Ctx) error { // return errors.Unauthorized.FiberError(c) // } - auth.SetClaims(c, claims) + locals.SetCustomClaims(c, claims) rawUserID := claims.Issuer userID, err := uuid.Parse(rawUserID) @@ -75,7 +76,7 @@ func (m *AuthMiddlewareService) Authenticate(c *fiber.Ctx) error { return fmt.Errorf("invalid user id: %s", rawUserID) } - auth.SetUserID(c, &userID) + locals.SetUserID(c, &userID) return nil }(c) @@ -88,7 +89,7 @@ func (m *AuthMiddlewareService) Authorize(requiredPermissions ...auth.Permission return utilities.Unauthorized() } - claims, err := auth.CustomClaimsFrom(c) + claims, err := locals.CustomClaims(c) if err != nil { return err } diff --git a/backend/middleware/auth/club.go b/backend/middleware/auth/club.go index d44d636f..8531c78a 100644 --- a/backend/middleware/auth/club.go +++ b/backend/middleware/auth/club.go @@ -3,8 +3,8 @@ package auth import ( "slices" - "github.com/GenerateNU/sac/backend/auth" "github.com/GenerateNU/sac/backend/entities/clubs" + "github.com/GenerateNU/sac/backend/locals" "github.com/GenerateNU/sac/backend/utilities" "github.com/gofiber/fiber/v2" ) @@ -25,7 +25,7 @@ func (m *AuthMiddlewareService) ClubAuthorizeById(c *fiber.Ctx, extractor Extrac return err } - userID, err := auth.UserIDFrom(c) + userID, err := locals.UserID(c) if err != nil { return err } diff --git a/backend/middleware/auth/event.go b/backend/middleware/auth/event.go index 23bf6db1..5d0c13c3 100644 --- a/backend/middleware/auth/event.go +++ b/backend/middleware/auth/event.go @@ -3,8 +3,8 @@ package auth import ( "slices" - "github.com/GenerateNU/sac/backend/auth" "github.com/GenerateNU/sac/backend/entities/events" + "github.com/GenerateNU/sac/backend/locals" "github.com/GenerateNU/sac/backend/utilities" "github.com/gofiber/fiber/v2" ) @@ -25,7 +25,7 @@ func (m *AuthMiddlewareService) EventAuthorizeById(c *fiber.Ctx, extractor Extra return err } - userID, err := auth.UserIDFrom(c) + userID, err := locals.UserID(c) if err != nil { return err } diff --git a/backend/middleware/auth/user.go b/backend/middleware/auth/user.go index fe72f7e4..64757064 100644 --- a/backend/middleware/auth/user.go +++ b/backend/middleware/auth/user.go @@ -1,7 +1,7 @@ package auth import ( - "github.com/GenerateNU/sac/backend/auth" + "github.com/GenerateNU/sac/backend/locals" "github.com/GenerateNU/sac/backend/utilities" "github.com/gofiber/fiber/v2" ) @@ -21,7 +21,7 @@ func (m *AuthMiddlewareService) UserAuthorizeById(c *fiber.Ctx) error { return err } - userID, err := auth.UserIDFrom(c) + userID, err := locals.UserID(c) if err != nil { return err }