diff --git a/api/account.go b/api/account.go index 6011a5e..f2c8d53 100644 --- a/api/account.go +++ b/api/account.go @@ -3,6 +3,7 @@ package api import ( "database/sql" "errors" + "github.com/cukhoaimon/SimpleBank/token" "github.com/lib/pq" "net/http" @@ -23,8 +24,9 @@ func (server *Server) createAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) arg := db.CreateAccountParams{ - Owner: req.Owner, + Owner: authPayload.Username, Currency: req.Currency, Balance: 0, } @@ -70,12 +72,20 @@ func (server *Server) getAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if account.Owner != authPayload.Username { + err := errors.New("this account does not belong to the authenticated user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + ctx.JSON(http.StatusOK, account) } type listAccountRequest struct { - PageID int32 `form:"page_id" binding:"required,min=1"` - PageSize int32 `form:"page_size" binding:"required,min=5,max=10"` + Owner string `from:"owner"` + PageID int32 `form:"page_id" binding:"required,min=1"` + PageSize int32 `form:"page_size" binding:"required,min=5,max=10"` } func (server *Server) listAccount(ctx *gin.Context) { @@ -86,7 +96,9 @@ func (server *Server) listAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) arg := db.ListAccountsParams{ + Owner: authPayload.Username, Limit: req.PageSize, Offset: (req.PageID - 1) * req.PageSize, } diff --git a/api/account_test.go b/api/account_test.go index 170f897..5a7adab 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" mockdb "github.com/cukhoaimon/SimpleBank/db/mock" db "github.com/cukhoaimon/SimpleBank/db/sqlc" @@ -19,7 +20,8 @@ import ( ) func TestServer_getAccount(t *testing.T) { - account := randomAccount() + user, _ := randomUser(t) + account := randomAccountWithUser(user) testCases := []struct { name string @@ -102,6 +104,7 @@ func TestServer_getAccount(t *testing.T) { request, err := http.NewRequest(http.MethodGet, url, nil) require.Nil(t, err) + addAuthorization(t, request, server.tokenMaker, authorizationTypeBearer, user.Username, time.Minute) server.router.ServeHTTP(recorder, request) // check response tc.checkResponse(t, recorder) @@ -205,6 +208,8 @@ func TestServer_createAccount(t *testing.T) { request, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(data)) require.Nil(t, err) + addAuthorization(t, request, server.tokenMaker, authorizationTypeBearer, arg.Owner, time.Minute) + server.router.ServeHTTP(recorder, request) // check response tc.checkResponse(t, recorder) @@ -212,15 +217,19 @@ func TestServer_createAccount(t *testing.T) { } } +// TODO: add authorization to listAccount +// TODO: rewrite listAccount for more robust func TestServer_listAccount(t *testing.T) { + user, _ := randomUser(t) n := 10 accounts := make([]db.Account, n) for i := 0; i < n; i++ { - accounts = append(accounts, randomAccount()) + accounts = append(accounts, randomAccountWithUser(user)) } arg := listAccountRequest{ + Owner: user.Username, PageID: 1, PageSize: 5, } @@ -239,8 +248,9 @@ func TestServer_listAccount(t *testing.T) { store.EXPECT(). ListAccounts(gomock.Any(), gomock.Eq(db.ListAccountsParams{ - Offset: offset, + Owner: user.Username, Limit: arg.PageSize, + Offset: offset, })). Times(1). Return(accounts[offset:offset+arg.PageSize], nil) @@ -286,6 +296,7 @@ func TestServer_listAccount(t *testing.T) { store.EXPECT(). ListAccounts(gomock.Any(), gomock.Eq(db.ListAccountsParams{ + Owner: user.Username, Offset: offset, Limit: arg.PageSize, })). @@ -317,6 +328,7 @@ func TestServer_listAccount(t *testing.T) { // send request request, err := http.NewRequest(http.MethodGet, url, nil) require.Nil(t, err) + addAuthorization(t, request, server.tokenMaker, authorizationTypeBearer, user.Username, time.Minute) server.router.ServeHTTP(recorder, request) // check response @@ -334,6 +346,15 @@ func randomAccount() db.Account { } } +func randomAccountWithUser(user db.User) db.Account { + return db.Account{ + ID: utils.RandomInt(1, 1000), + Owner: user.Username, + Balance: utils.RandomMoney(), + Currency: utils.RandomCurrency(), + } +} + func requireBodyMatchAccount(t *testing.T, body *bytes.Buffer, account db.Account) { data, err := io.ReadAll(body) require.Nil(t, err) diff --git a/api/main_test.go b/api/main_test.go index ae88cbc..ae8f656 100644 --- a/api/main_test.go +++ b/api/main_test.go @@ -2,6 +2,7 @@ package api import ( db "github.com/cukhoaimon/SimpleBank/db/sqlc" + "github.com/cukhoaimon/SimpleBank/token" "github.com/cukhoaimon/SimpleBank/utils" "github.com/stretchr/testify/require" "os" @@ -17,9 +18,13 @@ func newTestServer(t *testing.T, store db.Store) *Server { TokenSymmetricKey: utils.RandomString(32), } + pasetoMaker, err := token.NewPasetoMaker(config.TokenSymmetricKey) + require.Nil(t, err) + server, err := NewServer(store, config) require.Nil(t, err) + server.tokenMaker = pasetoMaker return server } diff --git a/api/middleware.go b/api/middleware.go new file mode 100644 index 0000000..bd9fe10 --- /dev/null +++ b/api/middleware.go @@ -0,0 +1,54 @@ +package api + +import ( + "errors" + "fmt" + "github.com/cukhoaimon/SimpleBank/token" + "github.com/gin-gonic/gin" + "net/http" + "strings" +) + +const ( + authorizationHeaderKey = "authorization" + authorizationTypeBearer = "bearer" + authorizationPayloadKey = "authorization_payload" +) + +var ( + errAuthHeaderNotProvided = errors.New("authorization header is not provided") + errAuthHeaderInvalidFormat = errors.New("invalid authorization header format") +) + +func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc { + return func(ctx *gin.Context) { + authorizationHeader := ctx.GetHeader(authorizationHeaderKey) + if len(authorizationHeader) == 0 { + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(errAuthHeaderNotProvided)) + return + } + + fields := strings.Fields(authorizationHeader) + if len(fields) < 2 { + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(errAuthHeaderInvalidFormat)) + return + } + + authorizationType := strings.ToLower(fields[0]) + if authorizationType != authorizationTypeBearer { + err := fmt.Errorf("authorization type %s is not supported", authorizationType) + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken := fields[1] + payload, err := tokenMaker.VerifyToken(accessToken) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + ctx.Set(authorizationPayloadKey, payload) + ctx.Next() + } +} diff --git a/api/middleware_test.go b/api/middleware_test.go new file mode 100644 index 0000000..1238f8d --- /dev/null +++ b/api/middleware_test.go @@ -0,0 +1,106 @@ +package api + +import ( + "fmt" + "github.com/cukhoaimon/SimpleBank/token" + "github.com/cukhoaimon/SimpleBank/utils" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func addAuthorization( + t *testing.T, + request *http.Request, + tokenMaker token.Maker, + authorizationType string, + username string, + duration time.Duration, +) { + accessToken, err := tokenMaker.CreateToken(username, duration) + require.Nil(t, err) + require.NotEmpty(t, accessToken) + + authorizationHeader := fmt.Sprintf("%s %s", authorizationType, accessToken) + request.Header.Set(authorizationHeaderKey, authorizationHeader) +} + +func Test_authMiddleware(t *testing.T) { + tests := []struct { + name string + setupAuth func(*testing.T, *http.Request, token.Maker) + checkResponse func(*testing.T, *httptest.ResponseRecorder) + }{ + { + name: "200 OK", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusOK, recorder.Code) + }, + }, + { + name: "401 - Authorization not provide", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "401 invalid authorization header format", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + request.Header.Set("hehe", "hehe") + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "401 unsupported authorization type ", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, "sieu cap vo dich", "user", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "401 invalid token", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + // provide JWT token, but server is using Paseto token => invalid token + jwtMaker, err := token.NewJWTMaker(utils.RandomString(32)) + require.Nil(t, err) + addAuthorization(t, request, jwtMaker, authorizationTypeBearer, "user", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := newTestServer(t, nil) + authPath := "/auth" + server.router.GET( + authPath, + authMiddleware(server.tokenMaker), + func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, gin.H{}) + }, + ) + + recorder := httptest.NewRecorder() + request, err := http.NewRequest(http.MethodGet, authPath, nil) + require.Nil(t, err) + + tc.setupAuth(t, request, server.tokenMaker) + server.router.ServeHTTP(recorder, request) + tc.checkResponse(t, recorder) + }) + } +} diff --git a/api/server.go b/api/server.go index 53ec590..be03cfa 100644 --- a/api/server.go +++ b/api/server.go @@ -46,11 +46,13 @@ func (server *Server) setupRouter() { router.POST("/api/v1/user", server.createUser) router.POST("/api/v1/user/login", server.loginUser) - router.GET("/api/v1/account", server.listAccount) - router.GET("/api/v1/account/:id", server.getAccount) - router.POST("/api/v1/account", server.createAccount) + authRoutes := router.Group("/").Use(authMiddleware(server.tokenMaker)) - router.POST("/api/v1/transfer", server.createTransfer) + authRoutes.GET("/api/v1/account", server.listAccount) + authRoutes.GET("/api/v1/account/:id", server.getAccount) + authRoutes.POST("/api/v1/account", server.createAccount) + + authRoutes.POST("/api/v1/transfer", server.createTransfer) server.router = router } diff --git a/api/transfer.go b/api/transfer.go index 20616da..983e2ad 100644 --- a/api/transfer.go +++ b/api/transfer.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" db "github.com/cukhoaimon/SimpleBank/db/sqlc" + "github.com/cukhoaimon/SimpleBank/token" "github.com/gin-gonic/gin" "net/http" ) @@ -24,6 +25,10 @@ func (server *Server) createTransfer(ctx *gin.Context) { return } + if !server.validOwner(ctx, req.FromAccountID) { + return + } + if !server.validAccount(ctx, req.FromAccountID, req.Currency) { return } @@ -47,6 +52,28 @@ func (server *Server) createTransfer(ctx *gin.Context) { ctx.JSON(http.StatusCreated, account) } +func (server *Server) validOwner(ctx *gin.Context, accountID int64) bool { + account, err := server.store.GetAccount(ctx, accountID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return false + } + + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return false + } + + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if account.Owner != authPayload.Username { + err = errors.New("from_account is not belong to the authorized user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return false + } + + return true +} + func (server *Server) validAccount(ctx *gin.Context, accountID int64, currency string) bool { account, err := server.store.GetAccount(ctx, accountID) if err != nil { diff --git a/api/transfer_test.go b/api/transfer_test.go index 21a2ddf..bdc32b0 100644 --- a/api/transfer_test.go +++ b/api/transfer_test.go @@ -14,6 +14,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" ) func TestServer_createTransfer(t *testing.T) { @@ -80,7 +81,7 @@ func TestServer_createTransfer(t *testing.T) { buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(fromAccount.ID)). - Times(1). + Times(2). Return(fromAccount, nil) store.EXPECT(). @@ -158,7 +159,7 @@ func TestServer_createTransfer(t *testing.T) { buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(fromAccount.ID)). - Times(1). + Times(2). Return(fromAccount, nil) store.EXPECT(). @@ -180,7 +181,7 @@ func TestServer_createTransfer(t *testing.T) { buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(fromAccount.ID)). - Times(1). + Times(2). Return(fromAccount, nil) store.EXPECT(). @@ -236,7 +237,7 @@ func TestServer_createTransfer(t *testing.T) { buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(fromAccount.ID)). - Times(1). + Times(2). Return(fromAccount, nil) store.EXPECT(). @@ -280,6 +281,7 @@ func TestServer_createTransfer(t *testing.T) { request, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(data)) require.Nil(t, err) + addAuthorization(t, request, server.tokenMaker, authorizationTypeBearer, fromAccount.Owner, time.Minute) server.router.ServeHTTP(recorder, request) // check response diff --git a/db/query/account.sql b/db/query/account.sql index 6921123..868c23c 100644 --- a/db/query/account.sql +++ b/db/query/account.sql @@ -24,9 +24,10 @@ FOR NO KEY UPDATE; -- name: ListAccounts :many SELECT * FROM accounts +WHERE owner = $1 ORDER BY id -LIMIT $1 -OFFSET $2; +LIMIT $2 +OFFSET $3; -- name: UpdateAccount :one UPDATE accounts diff --git a/db/sqlc/account.sql.go b/db/sqlc/account.sql.go index 25853af..4f8cb64 100644 --- a/db/sqlc/account.sql.go +++ b/db/sqlc/account.sql.go @@ -117,18 +117,20 @@ func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, e const listAccounts = `-- name: ListAccounts :many SELECT id, owner, balance, currency, created_at FROM accounts +WHERE owner = $1 ORDER BY id -LIMIT $1 -OFFSET $2 +LIMIT $2 +OFFSET $3 ` type ListAccountsParams struct { - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` + Owner string `json:"owner"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` } func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) { - rows, err := q.db.QueryContext(ctx, listAccounts, arg.Limit, arg.Offset) + rows, err := q.db.QueryContext(ctx, listAccounts, arg.Owner, arg.Limit, arg.Offset) if err != nil { return nil, err } diff --git a/db/sqlc/account.sql_test.go b/db/sqlc/account.sql_test.go index f51b954..92b6122 100644 --- a/db/sqlc/account.sql_test.go +++ b/db/sqlc/account.sql_test.go @@ -86,20 +86,23 @@ func TestQueries_DeleteAccount(t *testing.T) { require.Empty(t, have) } +// TODO: Fix List Account in the ListAccountsParams func TestQueries_ListAccount(t *testing.T) { + var lastAccount Account for i := 0; i < 10; i++ { - createRandomAccount(t) + lastAccount = createRandomAccount(t) } arg := ListAccountsParams{ + Owner: lastAccount.Owner, Limit: 5, - Offset: 5, + Offset: 0, } have, err := testQuery.ListAccounts(context.Background(), arg) require.Nil(t, err) - require.Len(t, have, int(arg.Limit)) + require.Len(t, have, 1) for _, account := range have { require.NotEmpty(t, account)