Skip to content

Commit

Permalink
feat: add authorization logic for all api
Browse files Browse the repository at this point in the history
  • Loading branch information
cukhoaimon committed Feb 4, 2024
1 parent e3e94a0 commit b3bf75a
Show file tree
Hide file tree
Showing 11 changed files with 259 additions and 24 deletions.
18 changes: 15 additions & 3 deletions api/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"database/sql"
"errors"
"github.com/cukhoaimon/SimpleBank/token"
"github.com/lib/pq"
"net/http"

Expand All @@ -23,8 +24,9 @@ func (server *Server) createAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
arg := db.CreateAccountParams{
Owner: req.Owner,
Owner: authPayload.Username,
Currency: req.Currency,
Balance: 0,
}
Expand Down Expand Up @@ -70,12 +72,20 @@ func (server *Server) getAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
if account.Owner != authPayload.Username {
err := errors.New("this account does not belong to the authenticated user")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}

ctx.JSON(http.StatusOK, account)
}

type listAccountRequest struct {
PageID int32 `form:"page_id" binding:"required,min=1"`
PageSize int32 `form:"page_size" binding:"required,min=5,max=10"`
Owner string `from:"owner"`
PageID int32 `form:"page_id" binding:"required,min=1"`
PageSize int32 `form:"page_size" binding:"required,min=5,max=10"`
}

func (server *Server) listAccount(ctx *gin.Context) {
Expand All @@ -86,7 +96,9 @@ func (server *Server) listAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
arg := db.ListAccountsParams{
Owner: authPayload.Username,
Limit: req.PageSize,
Offset: (req.PageID - 1) * req.PageSize,
}
Expand Down
27 changes: 24 additions & 3 deletions api/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

mockdb "github.com/cukhoaimon/SimpleBank/db/mock"
db "github.com/cukhoaimon/SimpleBank/db/sqlc"
Expand All @@ -19,7 +20,8 @@ import (
)

func TestServer_getAccount(t *testing.T) {
account := randomAccount()
user, _ := randomUser(t)
account := randomAccountWithUser(user)

testCases := []struct {
name string
Expand Down Expand Up @@ -102,6 +104,7 @@ func TestServer_getAccount(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, url, nil)
require.Nil(t, err)

addAuthorization(t, request, server.tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
server.router.ServeHTTP(recorder, request)
// check response
tc.checkResponse(t, recorder)
Expand Down Expand Up @@ -205,22 +208,28 @@ func TestServer_createAccount(t *testing.T) {
request, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(data))
require.Nil(t, err)

addAuthorization(t, request, server.tokenMaker, authorizationTypeBearer, arg.Owner, time.Minute)

server.router.ServeHTTP(recorder, request)
// check response
tc.checkResponse(t, recorder)
})
}
}

// TODO: add authorization to listAccount
// TODO: rewrite listAccount for more robust
func TestServer_listAccount(t *testing.T) {
user, _ := randomUser(t)
n := 10
accounts := make([]db.Account, n)

for i := 0; i < n; i++ {
accounts = append(accounts, randomAccount())
accounts = append(accounts, randomAccountWithUser(user))
}

arg := listAccountRequest{
Owner: user.Username,
PageID: 1,
PageSize: 5,
}
Expand All @@ -239,8 +248,9 @@ func TestServer_listAccount(t *testing.T) {

store.EXPECT().
ListAccounts(gomock.Any(), gomock.Eq(db.ListAccountsParams{
Offset: offset,
Owner: user.Username,
Limit: arg.PageSize,
Offset: offset,
})).
Times(1).
Return(accounts[offset:offset+arg.PageSize], nil)
Expand Down Expand Up @@ -286,6 +296,7 @@ func TestServer_listAccount(t *testing.T) {

store.EXPECT().
ListAccounts(gomock.Any(), gomock.Eq(db.ListAccountsParams{
Owner: user.Username,
Offset: offset,
Limit: arg.PageSize,
})).
Expand Down Expand Up @@ -317,6 +328,7 @@ func TestServer_listAccount(t *testing.T) {
// send request
request, err := http.NewRequest(http.MethodGet, url, nil)
require.Nil(t, err)
addAuthorization(t, request, server.tokenMaker, authorizationTypeBearer, user.Username, time.Minute)

server.router.ServeHTTP(recorder, request)
// check response
Expand All @@ -334,6 +346,15 @@ func randomAccount() db.Account {
}
}

func randomAccountWithUser(user db.User) db.Account {
return db.Account{
ID: utils.RandomInt(1, 1000),
Owner: user.Username,
Balance: utils.RandomMoney(),
Currency: utils.RandomCurrency(),
}
}

func requireBodyMatchAccount(t *testing.T, body *bytes.Buffer, account db.Account) {
data, err := io.ReadAll(body)
require.Nil(t, err)
Expand Down
5 changes: 5 additions & 0 deletions api/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
db "github.com/cukhoaimon/SimpleBank/db/sqlc"
"github.com/cukhoaimon/SimpleBank/token"
"github.com/cukhoaimon/SimpleBank/utils"
"github.com/stretchr/testify/require"
"os"
Expand All @@ -17,9 +18,13 @@ func newTestServer(t *testing.T, store db.Store) *Server {
TokenSymmetricKey: utils.RandomString(32),
}

pasetoMaker, err := token.NewPasetoMaker(config.TokenSymmetricKey)
require.Nil(t, err)

server, err := NewServer(store, config)
require.Nil(t, err)

server.tokenMaker = pasetoMaker
return server
}

Expand Down
54 changes: 54 additions & 0 deletions api/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package api

import (
"errors"
"fmt"
"github.com/cukhoaimon/SimpleBank/token"
"github.com/gin-gonic/gin"
"net/http"
"strings"
)

const (
authorizationHeaderKey = "authorization"
authorizationTypeBearer = "bearer"
authorizationPayloadKey = "authorization_payload"
)

var (
errAuthHeaderNotProvided = errors.New("authorization header is not provided")
errAuthHeaderInvalidFormat = errors.New("invalid authorization header format")
)

func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc {
return func(ctx *gin.Context) {
authorizationHeader := ctx.GetHeader(authorizationHeaderKey)
if len(authorizationHeader) == 0 {
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(errAuthHeaderNotProvided))
return
}

fields := strings.Fields(authorizationHeader)
if len(fields) < 2 {
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(errAuthHeaderInvalidFormat))
return
}

authorizationType := strings.ToLower(fields[0])
if authorizationType != authorizationTypeBearer {
err := fmt.Errorf("authorization type %s is not supported", authorizationType)
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

accessToken := fields[1]
payload, err := tokenMaker.VerifyToken(accessToken)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

ctx.Set(authorizationPayloadKey, payload)
ctx.Next()
}
}
106 changes: 106 additions & 0 deletions api/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package api

import (
"fmt"
"github.com/cukhoaimon/SimpleBank/token"
"github.com/cukhoaimon/SimpleBank/utils"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"net/http"
"net/http/httptest"
"testing"
"time"
)

func addAuthorization(
t *testing.T,
request *http.Request,
tokenMaker token.Maker,
authorizationType string,
username string,
duration time.Duration,
) {
accessToken, err := tokenMaker.CreateToken(username, duration)
require.Nil(t, err)
require.NotEmpty(t, accessToken)

authorizationHeader := fmt.Sprintf("%s %s", authorizationType, accessToken)
request.Header.Set(authorizationHeaderKey, authorizationHeader)
}

func Test_authMiddleware(t *testing.T) {
tests := []struct {
name string
setupAuth func(*testing.T, *http.Request, token.Maker)
checkResponse func(*testing.T, *httptest.ResponseRecorder)
}{
{
name: "200 OK",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, recorder.Code)
},
},
{
name: "401 - Authorization not provide",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "401 invalid authorization header format",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
request.Header.Set("hehe", "hehe")
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "401 unsupported authorization type ",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, "sieu cap vo dich", "user", time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "401 invalid token",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
// provide JWT token, but server is using Paseto token => invalid token
jwtMaker, err := token.NewJWTMaker(utils.RandomString(32))
require.Nil(t, err)
addAuthorization(t, request, jwtMaker, authorizationTypeBearer, "user", time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := newTestServer(t, nil)
authPath := "/auth"
server.router.GET(
authPath,
authMiddleware(server.tokenMaker),
func(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{})
},
)

recorder := httptest.NewRecorder()
request, err := http.NewRequest(http.MethodGet, authPath, nil)
require.Nil(t, err)

tc.setupAuth(t, request, server.tokenMaker)
server.router.ServeHTTP(recorder, request)
tc.checkResponse(t, recorder)
})
}
}
10 changes: 6 additions & 4 deletions api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ func (server *Server) setupRouter() {
router.POST("/api/v1/user", server.createUser)
router.POST("/api/v1/user/login", server.loginUser)

router.GET("/api/v1/account", server.listAccount)
router.GET("/api/v1/account/:id", server.getAccount)
router.POST("/api/v1/account", server.createAccount)
authRoutes := router.Group("/").Use(authMiddleware(server.tokenMaker))

router.POST("/api/v1/transfer", server.createTransfer)
authRoutes.GET("/api/v1/account", server.listAccount)
authRoutes.GET("/api/v1/account/:id", server.getAccount)
authRoutes.POST("/api/v1/account", server.createAccount)

authRoutes.POST("/api/v1/transfer", server.createTransfer)

server.router = router
}
Expand Down
Loading

0 comments on commit b3bf75a

Please sign in to comment.