Skip to content

Commit

Permalink
feat: auth middleware (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
ImSoZRious authored Dec 27, 2023
1 parent 3411069 commit e366ab4
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 65 deletions.
9 changes: 2 additions & 7 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"fmt"

"github.com/gin-gonic/gin"
"github.com/isd-sgcu/oph66-backend/di"
"github.com/isd-sgcu/oph66-backend/docs"
swaggerFiles "github.com/swaggo/files"
Expand All @@ -27,12 +26,8 @@ func main() {

docs.SwaggerInfo.Host = container.Config.AppConfig.Host

if !container.Config.AppConfig.IsDevelopment() {
gin.SetMode(gin.ReleaseMode)
}
r := gin.Default()
r := container.Router

r.Use(gin.HandlerFunc(container.CorsHandler))
r.GET("/_hc", container.HcHandler.HealthCheck)
r.GET("/live", container.FeatureflagHandler.GetLivestreamInfo)
r.GET("/events", container.EventHandler.GetAllEvents)
Expand All @@ -42,7 +37,7 @@ func main() {
r.GET("/auth/login", container.AuthHandler.GoogleLogin)
r.GET("/auth/callback", container.AuthHandler.GoogleCallback)

if container.Config.AppConfig.Env == "development" {
if container.Config.AppConfig.IsDevelopment() {
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
}
if err := r.Run(fmt.Sprintf(":%v", container.Config.AppConfig.Port)); err != nil {
Expand Down
17 changes: 16 additions & 1 deletion di/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
event "github.com/isd-sgcu/oph66-backend/internal/event"
featureflag "github.com/isd-sgcu/oph66-backend/internal/feature_flag"
healthcheck "github.com/isd-sgcu/oph66-backend/internal/health_check"
"github.com/isd-sgcu/oph66-backend/internal/middleware"
"github.com/isd-sgcu/oph66-backend/internal/router"
"github.com/isd-sgcu/oph66-backend/logger"
"go.uber.org/zap"
)
Expand All @@ -24,9 +26,19 @@ type Container struct {
Config *cfgldr.Config
Logger *zap.Logger
CorsHandler cfgldr.CorsHandler
Router *router.Router
}

func newContainer(eventHandler event.Handler, hcHandler healthcheck.Handler, featureflagHandler featureflag.Handler, authHandler auth.Handler, config *cfgldr.Config, logger *zap.Logger, corsHandler cfgldr.CorsHandler) Container {
func newContainer(
eventHandler event.Handler,
hcHandler healthcheck.Handler,
featureflagHandler featureflag.Handler,
authHandler auth.Handler,
config *cfgldr.Config,
logger *zap.Logger,
corsHandler cfgldr.CorsHandler,
router *router.Router,
) Container {
return Container{
eventHandler,
hcHandler,
Expand All @@ -35,6 +47,7 @@ func newContainer(eventHandler event.Handler, hcHandler healthcheck.Handler, fea
config,
logger,
corsHandler,
router,
}
}

Expand All @@ -58,6 +71,8 @@ func Init() (Container, error) {
auth.NewService,
auth.NewRepository,
logger.InitLogger,
router.NewRouter,
middleware.NewAuthMiddleware,
)

return Container{}, nil
Expand Down
18 changes: 15 additions & 3 deletions di/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 17 additions & 15 deletions internal/auth/auth.handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package auth

import (
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/isd-sgcu/oph66-backend/apperror"
Expand Down Expand Up @@ -84,23 +83,25 @@ func (h *handlerImpl) GoogleCallback(c *gin.Context) {
func (h *handlerImpl) Register(c *gin.Context) {
var data RegisterRequestDTO
var user model.User
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
emailRaw, exist := c.Get("email")
if !exist {
utils.ReturnError(c, apperror.Unauthorized)
return
}
if !strings.HasPrefix(authHeader, "Bearer ") {
utils.ReturnError(c, apperror.InvalidToken)

email, ok := emailRaw.(string)
if !ok {
h.logger.Error("email string assertion failed", zap.Any("emailRaw", emailRaw))
utils.ReturnError(c, apperror.InternalError)
return
}
token := strings.Replace(authHeader, "Bearer ", "", 1)

if err := c.ShouldBindJSON(&data); err != nil {
utils.ReturnError(c, apperror.BadRequest)
return
}

apperr := h.svc.Register(c, &data, token, &user)
apperr := h.svc.Register(email, &data, &user)
if apperr != nil {
utils.ReturnError(c, apperr)
return
Expand All @@ -125,20 +126,21 @@ func (h *handlerImpl) Register(c *gin.Context) {
// @Failure 401 {object} auth.GetProfileUnauthorized
// @Failure 404 {object} auth.GetProfileUserNotFound
func (h *handlerImpl) GetProfile(c *gin.Context) {
var user model.User
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
emailRaw, exist := c.Get("email")
if !exist {
utils.ReturnError(c, apperror.Unauthorized)
return
}
if !strings.HasPrefix(authHeader, "Bearer ") {
utils.ReturnError(c, apperror.InvalidToken)

email, ok := emailRaw.(string)
if !ok {
h.logger.Error("email string assertion failed", zap.Any("emailRaw", emailRaw))
utils.ReturnError(c, apperror.InternalError)
return
}

token := strings.Replace(authHeader, "Bearer ", "", 1)

apperr := h.svc.GetUserFromJWTToken(c, token, &user)
var user model.User
apperr := h.svc.GetUserFromJWTToken(email, &user)
if apperr != nil {
utils.ReturnError(c, apperr)
return
Expand Down
52 changes: 13 additions & 39 deletions internal/auth/auth.service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@ package auth

import (
"context"
"errors"

"github.com/isd-sgcu/oph66-backend/apperror"
"github.com/isd-sgcu/oph66-backend/cfgldr"
"github.com/isd-sgcu/oph66-backend/internal/model"
"go.uber.org/zap"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
"gorm.io/gorm"
)

type Service interface {
GoogleLogin() (url string)
GoogleCallback(ctx context.Context, code string) (idToken string, appErr *apperror.AppError)
Register(ctx context.Context, data *RegisterRequestDTO, tokenString string, user *model.User) *apperror.AppError
GetUserFromJWTToken(ctx context.Context, tokenString string, user *model.User) *apperror.AppError
Register(email string, data *RegisterRequestDTO, user *model.User) *apperror.AppError
GetUserFromJWTToken(email string, user *model.User) *apperror.AppError
}

func NewService(repo Repository, logger *zap.Logger, cfg *cfgldr.Config) Service {
Expand Down Expand Up @@ -63,51 +64,24 @@ func (s *serviceImpl) GoogleCallback(ctx context.Context, code string) (idToken
return rawIdToken.(string), nil
}

func (s *serviceImpl) Register(ctx context.Context, data *RegisterRequestDTO, token string, user *model.User) *apperror.AppError {
email, apperr := getEmailFromToken(ctx, token, s.cfg.OAuth2Config.ClientID)
if apperr != nil {
return apperr
}

err := s.repo.GetUserByEmail(user, email)
if err != nil {
user = ConvertRegisterRequestDTOToUser(data, email)
err = s.repo.CreateUser(user)
if err != nil {
s.logger.Error("Failed to create user", zap.Error(err))
return apperror.InternalError
}
} else {
func (s *serviceImpl) Register(email string, data *RegisterRequestDTO, user *model.User) *apperror.AppError {
user = ConvertRegisterRequestDTOToUser(data, email)
err := s.repo.CreateUser(user)
if errors.Is(err, gorm.ErrDuplicatedKey) {
return apperror.DuplicateEmail
} else if err != nil {
s.logger.Error("Failed to create user", zap.Error(err))
return apperror.InternalError
}

return nil
}

func (s *serviceImpl) GetUserFromJWTToken(ctx context.Context, token string, user *model.User) *apperror.AppError {
email, apperr := getEmailFromToken(ctx, token, s.cfg.OAuth2Config.ClientID)
if apperr != nil {
return apperr
}

err := s.repo.GetUserByEmail(user, email)
func (s *serviceImpl) GetUserFromJWTToken(email string, result *model.User) *apperror.AppError {
err := s.repo.GetUserByEmail(result, email)
if err != nil {
return apperror.UserNotFound
}

return nil
}

func getEmailFromToken(ctx context.Context, tokenString string, clientID string) (email string, appErr *apperror.AppError) {
token, err := idtoken.Validate(ctx, tokenString, clientID)
if err != nil {
return "", apperror.InvalidToken
}

email, ok := token.Claims["email"].(string)
if !ok || email == "" {
return "", apperror.InvalidToken
}

return email, nil
}
74 changes: 74 additions & 0 deletions internal/middleware/auth.middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package middleware

import (
"strings"

"github.com/gin-gonic/gin"
"github.com/isd-sgcu/oph66-backend/apperror"
"github.com/isd-sgcu/oph66-backend/cfgldr"
"github.com/isd-sgcu/oph66-backend/internal/auth"
"github.com/isd-sgcu/oph66-backend/utils"
"google.golang.org/api/idtoken"
)

type AuthMiddleware gin.HandlerFunc

func NewAuthMiddleware(userRepo auth.Repository, cfg *cfgldr.Config) AuthMiddleware {

return func(c *gin.Context) {

authHeader := c.GetHeader("Authorization")

if authHeader == "" {

c.Next()

return

}

if !strings.HasPrefix(authHeader, "Bearer ") {

utils.ReturnError(c, apperror.InvalidToken)

c.Abort()

return

}

tokenString := strings.Replace(authHeader, "Bearer ", "", 1)

token, err := idtoken.Validate(c, tokenString, cfg.OAuth2Config.ClientID)

if err != nil {

utils.ReturnError(c, apperror.InvalidToken)

c.Abort()

return

}

if email, ok := token.Claims["email"].(string); ok {

c.Set("email", email)

c.Next()

return

} else {

utils.ReturnError(c, apperror.ServiceUnavailable)

c.Abort()

return

}

}

}
29 changes: 29 additions & 0 deletions internal/router/router.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package router

import (
"github.com/gin-gonic/gin"
"github.com/isd-sgcu/oph66-backend/cfgldr"
"github.com/isd-sgcu/oph66-backend/internal/middleware"
)

type Router struct {
*gin.Engine
}

func NewRouter(config *cfgldr.Config, corsHandler cfgldr.CorsHandler, authMiddleware middleware.AuthMiddleware) *Router {

if !config.AppConfig.IsDevelopment() {

gin.SetMode(gin.ReleaseMode)

}

r := gin.Default()

r.Use(gin.HandlerFunc(corsHandler))

r.Use(gin.HandlerFunc(authMiddleware))

return &Router{r}

}

0 comments on commit e366ab4

Please sign in to comment.