diff --git a/.env.example b/.env.example index 1c4a42c..f662ab8 100644 --- a/.env.example +++ b/.env.example @@ -13,3 +13,6 @@ CAPTCHA_SECRET= HTTPS_ENABLED=false HTTPS_CRT= HTTPS_KEY= + +CSRF_ENABLED=true +REDIS_CONNECTION=localhost:6379 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ded091c..7338357 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,8 +2,6 @@ name: build on: workflow_dispatch: - release: - types: [ published ] env: REGISTRY: docker.io diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 0dee917..6a32a02 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -18,6 +18,7 @@ env: jobs: deploy: runs-on: [self-hosted, Linux, x64] + environment: prod permissions: contents: read packages: write diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..9a511af --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,139 @@ +name: release + +on: + workflow_dispatch: + release: + types: [ published ] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: false + +env: + REGISTRY: docker.io + REPO: cashtrack/gateway + INFRA_REPO: cash-track/infra + INFRA_REPO_REF: main + CLUSTER: k8s-cash-track + NAMESPACE: cash-track + KUBECTL_BIN: ${{ vars.KUBECTL_BIN_URL }} + +jobs: + build: + runs-on: [self-hosted, Linux, x64] + permissions: + contents: read + packages: write + id-token: write + attestations: write + + steps: + - name: Checkout repository + if: github.event_name != 'pull_request' + uses: actions/checkout@v4 + + # Login against a Docker registry except on PR + # https://github.com/docker/login-action + - name: Login to Docker Hub + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_HUB_USER }} + password: ${{ secrets.DOCKER_HUB_TOKEN }} + + # Extract metadata (tags, labels) for Docker + # https://github.com/docker/metadata-action + - name: Extract Docker metadata + if: github.event_name != 'pull_request' + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REPO }} + tags: | + type=sha + type=semver,pattern={{version}} + + # Setup BuildX + # https://github.com/docker/setup-buildx-action + - name: Setup BuildX + uses: docker/setup-buildx-action@v3 + id: buildx + with: + install: true + + # Build and push Docker image with Build (don't push on PR) + # https://github.com/docker/build-push-action + - name: Build and push + uses: docker/build-push-action@v6 + id: push + with: + context: . + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + build-args: | + GIT_COMMIT=${{ github.sha }} + GIT_TAG=${{ github.ref_name }} + + - name: Attest + uses: actions/attest-build-provenance@v1 + id: attest + with: + subject-name: ${{ env.REGISTRY }}/${{ env.REPO }} + subject-digest: ${{ steps.push.outputs.digest }} + push-to-registry: true + + deploy: + runs-on: [self-hosted, Linux, x64] + environment: prod + permissions: + contents: read + packages: write + + steps: + - name: Checkout infra repository + uses: actions/checkout@v4 + with: + repository: ${{ env.INFRA_REPO }} + ref: ${{ env.INFRA_REPO_REF }} + path: deploy + + - name: Install doctl + uses: digitalocean/action-doctl@v2 + with: + token: ${{ secrets.DIGITALOCEAN_ACCESS_TOKEN }} + + - name: Install kubectl + run: | + curl -LO ${{ env.KUBECTL_BIN }} + chmod +x ./kubectl + sudo mv ./kubectl /usr/local/bin/kubectl + + - name: Configure kubectl + run: doctl kubernetes cluster kubeconfig save --expiry-seconds 600 ${{ env.CLUSTER }} + + # Extract metadata (tags, labels) for Docker + # https://github.com/docker/metadata-action + - name: Extract Docker metadata + if: github.event_name != 'pull_request' + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REPO }} + tags: | + type=semver,pattern={{version}} + + - name: Update deployment + env: + IMAGE: ${{ env.REPO }}:${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} + run: sed -i 's|${{ env.REPO }}:latest|'${IMAGE}'|' $GITHUB_WORKSPACE/deploy/services/gateway/deployment.yml + + - name: Update definition + run: kubectl apply -f $GITHUB_WORKSPACE/deploy/services/gateway/ + + - name: Verify deployment + run: kubectl -n ${{ env.NAMESPACE }} rollout status deployment/gateway + + - name: Verify service ready + run: kubectl -n ${{ env.NAMESPACE }} wait pods -l app=gateway --for condition=Ready --timeout=120s + diff --git a/Makefile b/Makefile index 1c5c324..551ec76 100644 --- a/Makefile +++ b/Makefile @@ -50,4 +50,5 @@ mock-gen: mockgen -source=captcha/provider.go -package=mocks -destination=mocks/captcha_provider_mock.go -mock_names=Provider=CaptchaProviderMock mockgen -source=service/api/service.go -package=mocks -destination=mocks/api_service_mock.go -mock_names=Service=ApiServiceMock mockgen -source=router/api/handler.go -package=mocks -destination=mocks/api_handler_mock.go -mock_names=Handler=ApiHandlerMock + mockgen -source=router/csrf/handler.go -package=mocks -destination=mocks/csrf_handler_mock.go -mock_names=Handler=CsrfHandlerMock diff --git a/config/config.go b/config/config.go index ef16606..e014ab5 100644 --- a/config/config.go +++ b/config/config.go @@ -28,6 +28,9 @@ type Config struct { CorsAllowedOrigins map[string]bool DebugHttp bool + + CsrfEnabled bool + RedisConnection string } var Global Config @@ -57,6 +60,9 @@ func (c *Config) Load() { c.CookieSecure = getCookieSecure(c.GatewayUrl) c.CorsAllowedOrigins = getCorsAllowedOrigins(getEnv("CORS_ALLOWED_ORIGINS", "")) + + c.CsrfEnabled = getEnv("CSRF_ENABLED", "") == "true" + c.RedisConnection = getEnv("REDIS_CONNECTION", "localhost:6379") } func getEnv(key, def string) string { diff --git a/config/config_test.go b/config/config_test.go index 2f6dd5b..3ca1ca9 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -15,6 +15,8 @@ func TestConfigLoad(t *testing.T) { _ = os.Setenv("GATEWAY_URL", "https://gateway.dev.cash-track.app:8081") _ = os.Setenv("HTTPS_ENABLED", "true") _ = os.Setenv("CORS_ALLOWED_ORIGINS", "https://My.dev.cash-track.app:3001,https://Dev.cash-track.app:3000") + _ = os.Setenv("CSRF_ENABLED", "true") + _ = os.Setenv("REDIS_CONNECTION", "redis:1234") config := &Config{} config.Load() @@ -42,6 +44,9 @@ func TestConfigLoad(t *testing.T) { _, ok = config.CorsAllowedOrigins["https://dev.cash-track.app:3000"] assert.Equal(t, true, ok) + + assert.Equal(t, true, config.CsrfEnabled) + assert.Equal(t, "redis:1234", config.RedisConnection) } func TestConfigLoadUnexpectedApiUrl(t *testing.T) { diff --git a/go.mod b/go.mod index df0cb42..b09d752 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,10 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-redis/redismock/v9 v9.2.0 // indirect + github.com/golang-jwt/jwt/v4 v4.5.1 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/kr/text v0.2.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect @@ -23,6 +27,7 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.59.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect + github.com/redis/go-redis/v9 v9.7.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect diff --git a/go.sum b/go.sum index b5d4d74..76f9b6f 100644 --- a/go.sum +++ b/go.sum @@ -7,12 +7,20 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/fasthttp/router v1.5.2 h1:ckJCCdV7hWkkrMeId3WfEhz+4Gyyf6QPwxi/RHIMZ6I= github.com/fasthttp/router v1.5.2/go.mod h1:C8EY53ozOwpONyevc/V7Gr8pqnEjwnkFFqPo1alAGs0= github.com/flf2ko/fasthttp-prometheus v0.1.0 h1:hj4K3TwJ2B7Fe2E7lWE/eb9mtb7gBvwURXr4+iEFoCI= github.com/flf2ko/fasthttp-prometheus v0.1.0/go.mod h1:5tGRWsJeP8ABLYovqPxa5c/zCgnsYUhhC1ivs/Kv/c4= +github.com/go-redis/redismock/v9 v9.2.0 h1:ZrMYQeKPECZPjOj5u9eyOjg8Nnb0BS9lkVIZ6IpsKLw= +github.com/go-redis/redismock/v9 v9.2.0/go.mod h1:18KHfGDK4Y6c2R0H38EUGWAdc7ZQS9gfYxc94k7rWT0= +github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= +github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -33,6 +41,8 @@ github.com/prometheus/common v0.59.1 h1:LXb1quJHWm1P6wq/U824uxYi4Sg0oGvNeUm1z5dJ github.com/prometheus/common v0.59.1/go.mod h1:GpWM7dewqmVYcd7SmRaiWVe9SSqjf0UrwnYnpEZNuT0= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= +github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc= diff --git a/headers/cookie/csrf.go b/headers/cookie/csrf.go new file mode 100644 index 0000000..e6cb640 --- /dev/null +++ b/headers/cookie/csrf.go @@ -0,0 +1,38 @@ +package cookie + +import ( + "time" + + "github.com/valyala/fasthttp" +) + +const ( + CsrfTokenCookieName = "cshtrkcsrf" + CsrfTokenTtl = time.Minute * 10 +) + +type CSRF struct { + Auth Auth + Token string +} + +func ReadCSRFCookie(ctx *fasthttp.RequestCtx) CSRF { + csrf := CSRF{ + Auth: ReadAuthCookie(ctx), + } + + if val := ctx.Request.Header.Cookie(CsrfTokenCookieName); val != nil { + csrf.Token = string(val) + } + + return csrf +} + +func (c CSRF) WriteCookie(ctx *fasthttp.RequestCtx) { + if !c.Auth.IsLogged() { + ctx.Response.Header.SetCookie(newCookie(CsrfTokenCookieName, "", fasthttp.CookieExpireDelete)) + return + } + + ctx.Response.Header.SetCookie(newCookie(CsrfTokenCookieName, c.Token, time.Now().Add(CsrfTokenTtl))) +} diff --git a/headers/cookie/csrf_test.go b/headers/cookie/csrf_test.go new file mode 100644 index 0000000..a6fea74 --- /dev/null +++ b/headers/cookie/csrf_test.go @@ -0,0 +1,51 @@ +package cookie + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" +) + +func TestReadCSRFCookie(t *testing.T) { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetCookie(AccessTokenCookieName, "access_token") + ctx.Request.Header.SetCookie(CsrfTokenCookieName, "csrf_token") + + csrf := ReadCSRFCookie(&ctx) + + assert.Equal(t, "csrf_token", csrf.Token) + assert.Equal(t, true, csrf.Auth.IsLogged()) +} + +func TestWriteCSRFCookie(t *testing.T) { + for name, test := range map[string]struct { + csrf CSRF + expectedToken string + }{ + "Logged": { + csrf: CSRF{ + Auth: Auth{ + AccessToken: "access_token", + }, + Token: "csrf_token", + }, + expectedToken: "csrf_token", + }, + "Guest": { + csrf: CSRF{ + Token: "csrf_token", + }, + expectedToken: "", + }, + } { + t.Run(name, func(t *testing.T) { + ctx := fasthttp.RequestCtx{} + + test.csrf.WriteCookie(&ctx) + + token := ctx.Response.Header.PeekCookie(CsrfTokenCookieName) + assert.Contains(t, string(token), test.expectedToken) + }) + } +} diff --git a/main.go b/main.go index 8ed3fb3..ab8057a 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,12 @@ package main import ( + "context" "log" + "time" prom "github.com/flf2ko/fasthttp-prometheus" + "github.com/redis/go-redis/v9" "github.com/valyala/fasthttp" "github.com/cash-track/gateway/captcha" @@ -13,6 +16,7 @@ import ( "github.com/cash-track/gateway/logger" "github.com/cash-track/gateway/router" apiHandler "github.com/cash-track/gateway/router/api" + csrfHandler "github.com/cash-track/gateway/router/csrf" apiService "github.com/cash-track/gateway/service/api" ) @@ -24,15 +28,22 @@ const ( func main() { config.Global.Load() + redisClient := getRedisClient() + csrf := csrfHandler.NewRedisHandler(redisClient) + r := router.New( apiHandler.NewHttp( config.Global, apiService.NewHttp(retryhttp.NewFastHttpRetryClient(), config.Global), captcha.NewGoogleReCaptchaProvider(retryhttp.NewFastHttpRetryClient(), config.Global), ), + csrf, ) h := prom.NewPrometheus("http").WrapHandler(r.Router) h = headers.Handler(h) + if config.Global.CsrfEnabled { + h = csrf.Handler(h) + } h = headers.CorsHandler(h) h = logger.DebugHandler(h) @@ -68,3 +79,20 @@ func startTls(s *fasthttp.Server) { log.Fatalf("Error in HTTPS server: %v", err) } } + +func getRedisClient() *redis.Client { + client := redis.NewClient(&redis.Options{ + Addr: config.Global.RedisConnection, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + log.Fatalf("Error connecting to redis: %v", err) + } + + log.Printf("Connected to Redis at %s\n", config.Global.RedisConnection) + + return client +} diff --git a/mocks/api_handler_mock.go b/mocks/api_handler_mock.go index 98234ac..c210ad8 100644 --- a/mocks/api_handler_mock.go +++ b/mocks/api_handler_mock.go @@ -20,6 +20,7 @@ import ( type ApiHandlerMock struct { ctrl *gomock.Controller recorder *ApiHandlerMockMockRecorder + isgomock struct{} } // ApiHandlerMockMockRecorder is the mock recorder for ApiHandlerMock. diff --git a/mocks/api_service_mock.go b/mocks/api_service_mock.go index dd88cc1..c76a09f 100644 --- a/mocks/api_service_mock.go +++ b/mocks/api_service_mock.go @@ -20,6 +20,7 @@ import ( type ApiServiceMock struct { ctrl *gomock.Controller recorder *ApiServiceMockMockRecorder + isgomock struct{} } // ApiServiceMockMockRecorder is the mock recorder for ApiServiceMock. diff --git a/mocks/captcha_provider_mock.go b/mocks/captcha_provider_mock.go index 7974d25..bccb55d 100644 --- a/mocks/captcha_provider_mock.go +++ b/mocks/captcha_provider_mock.go @@ -20,6 +20,7 @@ import ( type CaptchaProviderMock struct { ctrl *gomock.Controller recorder *CaptchaProviderMockMockRecorder + isgomock struct{} } // CaptchaProviderMockMockRecorder is the mock recorder for CaptchaProviderMock. diff --git a/mocks/csrf_handler_mock.go b/mocks/csrf_handler_mock.go new file mode 100644 index 0000000..cfeaa60 --- /dev/null +++ b/mocks/csrf_handler_mock.go @@ -0,0 +1,67 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: router/csrf/handler.go +// +// Generated by this command: +// +// mockgen -source=router/csrf/handler.go -package=mocks -destination=mocks/csrf_handler_mock.go -mock_names=Handler=CsrfHandlerMock +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + fasthttp "github.com/valyala/fasthttp" + gomock "go.uber.org/mock/gomock" +) + +// CsrfHandlerMock is a mock of Handler interface. +type CsrfHandlerMock struct { + ctrl *gomock.Controller + recorder *CsrfHandlerMockMockRecorder + isgomock struct{} +} + +// CsrfHandlerMockMockRecorder is the mock recorder for CsrfHandlerMock. +type CsrfHandlerMockMockRecorder struct { + mock *CsrfHandlerMock +} + +// NewCsrfHandlerMock creates a new mock instance. +func NewCsrfHandlerMock(ctrl *gomock.Controller) *CsrfHandlerMock { + mock := &CsrfHandlerMock{ctrl: ctrl} + mock.recorder = &CsrfHandlerMockMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *CsrfHandlerMock) EXPECT() *CsrfHandlerMockMockRecorder { + return m.recorder +} + +// Handler mocks base method. +func (m *CsrfHandlerMock) Handler(h fasthttp.RequestHandler) fasthttp.RequestHandler { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Handler", h) + ret0, _ := ret[0].(fasthttp.RequestHandler) + return ret0 +} + +// Handler indicates an expected call of Handler. +func (mr *CsrfHandlerMockMockRecorder) Handler(h any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handler", reflect.TypeOf((*CsrfHandlerMock)(nil).Handler), h) +} + +// RotateTokenHandler mocks base method. +func (m *CsrfHandlerMock) RotateTokenHandler(ctx *fasthttp.RequestCtx) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RotateTokenHandler", ctx) +} + +// RotateTokenHandler indicates an expected call of RotateTokenHandler. +func (mr *CsrfHandlerMockMockRecorder) RotateTokenHandler(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RotateTokenHandler", reflect.TypeOf((*CsrfHandlerMock)(nil).RotateTokenHandler), ctx) +} diff --git a/mocks/http/client_mock.go b/mocks/http/client_mock.go index 4da60a5..9cd7b3f 100644 --- a/mocks/http/client_mock.go +++ b/mocks/http/client_mock.go @@ -22,6 +22,7 @@ import ( type ClientMock struct { ctrl *gomock.Controller recorder *ClientMockMockRecorder + isgomock struct{} } // ClientMockMockRecorder is the mock recorder for ClientMock. diff --git a/mocks/http_retry_client_mock.go b/mocks/http_retry_client_mock.go index 9e1ef13..11cf5c3 100644 --- a/mocks/http_retry_client_mock.go +++ b/mocks/http_retry_client_mock.go @@ -23,6 +23,7 @@ import ( type HttpRetryClientMock struct { ctrl *gomock.Controller recorder *HttpRetryClientMockMockRecorder + isgomock struct{} } // HttpRetryClientMockMockRecorder is the mock recorder for HttpRetryClientMock. diff --git a/router/csrf/handler.go b/router/csrf/handler.go new file mode 100644 index 0000000..444faf5 --- /dev/null +++ b/router/csrf/handler.go @@ -0,0 +1,10 @@ +package csrf + +import ( + "github.com/valyala/fasthttp" +) + +type Handler interface { + Handler(h fasthttp.RequestHandler) fasthttp.RequestHandler + RotateTokenHandler(ctx *fasthttp.RequestCtx) +} diff --git a/router/csrf/redis_handler.go b/router/csrf/redis_handler.go new file mode 100644 index 0000000..669955b --- /dev/null +++ b/router/csrf/redis_handler.go @@ -0,0 +1,212 @@ +package csrf + +import ( + "context" + "fmt" + "log" + "strconv" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + "github.com/valyala/fasthttp" + + "github.com/cash-track/gateway/headers/cookie" + "github.com/cash-track/gateway/router/response" +) + +const ( + keyPrefix = "CT:csrf" + tokenTtl = time.Minute * 10 +) + +var ( + csrfRequiredForMethods = map[string]bool{ + fasthttp.MethodPost: true, + fasthttp.MethodPut: true, + fasthttp.MethodPatch: true, + fasthttp.MethodDelete: true, + } +) + +type userContext struct { + cookie cookie.CSRF + context string + isValid bool + err error +} + +func newUserContext(cookie cookie.CSRF) userContext { + ctx, err := getUserContextFromAccessToken(cookie.Auth.AccessToken) + userCtx := userContext{ + cookie: cookie, + context: ctx, + isValid: true, + } + + if err != nil { + userCtx.isValid = false + userCtx.err = err + } + + return userCtx +} + +type RedisHandler struct { + client *redis.Client +} + +func NewRedisHandler(client *redis.Client) *RedisHandler { + return &RedisHandler{ + client: client, + } +} + +// Handler will check each request of defined HTTP methods for CSRF token +// and rotate the new CSRF token as the response +func (r *RedisHandler) Handler(h fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + method := string(ctx.Request.Header.Method()) + + if method == fasthttp.MethodOptions { + h(ctx) + return + } + + userCtx := newUserContext(cookie.ReadCSRFCookie(ctx)) + + if err := r.validateCsrfRequest(userCtx, method); err != nil { + log.Printf("Error on validating CSRF token: %v", err) + response.ByErrorAndStatus(err, fasthttp.StatusExpectationFailed).Write(ctx) + return + } + + h(ctx) + + if userCtx.cookie.Auth.IsLogged() { + newToken, err := r.rotate(userCtx) + if err != nil { + log.Printf("Error on rotating CSRF token: %v", err) + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + return + } + + userCtx.cookie.Token = newToken + userCtx.cookie.WriteCookie(ctx) + } + } +} + +func (r *RedisHandler) validateCsrfRequest(ctx userContext, method string) error { + if _, ok := csrfRequiredForMethods[method]; !ok { + return nil + } + + if !ctx.cookie.Auth.IsLogged() { + return nil + } + + if ctx.err != nil { + return fmt.Errorf("unable to verify with invalid user context: %w", ctx.err) + } + + return r.verify(ctx) +} + +// RotateTokenHandler configure CSRF cookie for next request validation +func (r *RedisHandler) RotateTokenHandler(ctx *fasthttp.RequestCtx) { + userCtx := newUserContext(cookie.ReadCSRFCookie(ctx)) + + if !userCtx.cookie.Auth.IsLogged() { + userCtx.cookie.WriteCookie(ctx) + ctx.SetStatusCode(fasthttp.StatusUnauthorized) + return + } + + newToken, err := r.rotate(userCtx) + if err != nil { + log.Printf("Error on rotating CSRF token: %v", err) + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + return + } + + userCtx.cookie.Token = newToken + userCtx.cookie.WriteCookie(ctx) + ctx.SetStatusCode(fasthttp.StatusOK) +} + +func (r *RedisHandler) rotate(ctx userContext) (string, error) { + key := fmt.Sprintf("%s:%s", keyPrefix, ctx.context) + + token := generateNewToken() + + if err := r.client.SetEx(context.Background(), key, token, tokenTtl).Err(); err != nil { + return "", fmt.Errorf("error on writing new token: %w", err) + } + + return token, nil +} + +func (r *RedisHandler) verify(ctx userContext) error { + key := fmt.Sprintf("%s:%s", keyPrefix, ctx.context) + + if cmd := r.client.Get(context.Background(), key); cmd.Err() != nil { + return fmt.Errorf("error on reading token: %w", cmd.Err()) + } else if strings.Compare(ctx.cookie.Token, cmd.Val()) != 0 { + log.Printf("CSRF token is invalid: requested %s stored %s", ctx.cookie.Token, cmd.Val()) + return fmt.Errorf("invalid CSRF token") + } + + return nil +} + +func generateNewToken() string { + token, _ := uuid.NewV7() + return token.String() +} + +func getUserContextFromAccessToken(accessToken string) (string, error) { + defer func() { + if r := recover(); r != nil { + log.Printf("JWT decoding recovered from panic: %v", r) + } + }() + + if accessToken == "" { + return "", fmt.Errorf("access token is empty") + } + + token, _, err := jwt.NewParser().ParseUnverified(accessToken, jwt.MapClaims{}) + if err != nil || token == nil { + return "", fmt.Errorf("could not parse access token") + } + + var claims jwt.MapClaims + if c, ok := token.Claims.(jwt.MapClaims); ok { + claims = c + } + + var userId string + var issuedAt string + + if u, ok := claims["sub"]; ok { + userId = strconv.FormatFloat(u.(float64), 'f', 0, 64) + } else { + return "", fmt.Errorf("could not extract user id from claims") + } + + if i, ok := claims["iat"]; ok { + issuedAt = strconv.FormatFloat(i.(float64), 'f', 0, 64) + } else { + return "", fmt.Errorf("could not extract issued at from claims") + } + + if userId == "" || userId == "0" || issuedAt == "" || issuedAt == "0" { + return "", fmt.Errorf("could not extract user id or issued at from claims") + } + + // include iat claim to allow different clients having different CSRF tokens + return fmt.Sprintf("%s:%s", userId, issuedAt), nil +} diff --git a/router/csrf/redis_handler_test.go b/router/csrf/redis_handler_test.go new file mode 100644 index 0000000..e34a3b5 --- /dev/null +++ b/router/csrf/redis_handler_test.go @@ -0,0 +1,358 @@ +package csrf + +import ( + "errors" + "fmt" + "testing" + + "github.com/go-redis/redismock/v9" + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" + + "github.com/cash-track/gateway/headers/cookie" +) + +func TestHandler(t *testing.T) { + accessToken, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": 123987, + "iat": 987654321, + }).SignedString([]byte("asd")) + + for name, test := range map[string]struct { + request *fasthttp.RequestCtx + setup func(mock redismock.ClientMock) + expectPass bool + expectStatus int + }{ + "TokenValidForPost": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.Header.SetCookie(cookie.CsrfTokenCookieName, "csrf_token") + ctx.Request.Header.SetCookie(cookie.AccessTokenCookieName, accessToken) + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + key := fmt.Sprintf("%s:%d:%d", keyPrefix, 123987, 987654321) + mock.ExpectGet(key).SetVal("csrf_token") + mock.CustomMatch(func(expected, actual []interface{}) error { + assert.NotNil(t, actual) + if s, ok := actual[1].(string); ok { + assert.IsType(t, "", s) + } + return nil + }).ExpectSetEx(key, nil, 0).SetVal("token_1") + }, + expectPass: true, + }, + "TokenInvalidForPost": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.Header.SetCookie(cookie.CsrfTokenCookieName, "csrf_token_invalid") + ctx.Request.Header.SetCookie(cookie.AccessTokenCookieName, accessToken) + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + key := fmt.Sprintf("%s:%d:%d", keyPrefix, 123987, 987654321) + mock.ExpectGet(key).SetVal("csrf_token") + }, + expectPass: false, + expectStatus: fasthttp.StatusExpectationFailed, + }, + "SkippedForOptions": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fasthttp.MethodOptions) + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + }, + expectPass: true, + }, + "ValidationSkippedForGet": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fasthttp.MethodGet) + ctx.Request.Header.SetCookie(cookie.CsrfTokenCookieName, "csrf_token") + ctx.Request.Header.SetCookie(cookie.AccessTokenCookieName, accessToken) + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + key := fmt.Sprintf("%s:%d:%d", keyPrefix, 123987, 987654321) + mock.CustomMatch(func(expected, actual []interface{}) error { + assert.NotNil(t, actual) + if s, ok := actual[1].(string); ok { + assert.IsType(t, "", s) + } + return nil + }).ExpectSetEx(key, nil, 0).SetVal("token_1") + }, + expectPass: true, + }, + "SkippedForGuest": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + }, + expectPass: true, + }, + "FailForInvalidAccessToken": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.Header.SetCookie(cookie.CsrfTokenCookieName, "csrf_token") + ctx.Request.Header.SetCookie(cookie.AccessTokenCookieName, "123") + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + }, + expectPass: false, + }, + "VerifyError": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.Header.SetCookie(cookie.CsrfTokenCookieName, "csrf_token") + ctx.Request.Header.SetCookie(cookie.AccessTokenCookieName, accessToken) + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + key := fmt.Sprintf("%s:%d:%d", keyPrefix, 123987, 987654321) + mock.ExpectGet(key).SetErr(errors.New("broken pipe")) + }, + expectPass: false, + expectStatus: fasthttp.StatusExpectationFailed, + }, + "RotateError": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.Header.SetCookie(cookie.CsrfTokenCookieName, "csrf_token") + ctx.Request.Header.SetCookie(cookie.AccessTokenCookieName, accessToken) + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + key := fmt.Sprintf("%s:%d:%d", keyPrefix, 123987, 987654321) + mock.ExpectGet(key).SetVal("csrf_token") + mock.CustomMatch(func(expected, actual []interface{}) error { + assert.NotNil(t, actual) + if s, ok := actual[1].(string); ok { + assert.IsType(t, "", s) + } + return nil + }).ExpectSetEx(key, nil, 0).SetErr(errors.New("broken pipe")) + }, + expectPass: true, + expectStatus: fasthttp.StatusInternalServerError, + }, + } { + t.Run(name, func(t *testing.T) { + client, mock := redismock.NewClientMock() + + test.setup(mock) + + handlersExecuted := false + + handler := NewRedisHandler(client) + handler.Handler(func(ctx *fasthttp.RequestCtx) { + handlersExecuted = true + })(test.request) + + assert.Equal(t, test.expectPass, handlersExecuted) + assert.NotEqual(t, string(test.request.Response.Header.PeekCookie(cookie.CsrfTokenCookieName)), "csrf_token") + if test.expectStatus > 0 { + assert.Equal(t, test.expectStatus, test.request.Response.StatusCode()) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) + } + }) + } +} + +func TestRotateTokenHandler(t *testing.T) { + accessToken, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": 123987, + "iat": 987654321, + }).SignedString([]byte("asd")) + + for name, test := range map[string]struct { + request *fasthttp.RequestCtx + setup func(mock redismock.ClientMock) + expectRotate bool + expectStatus int + }{ + "Rotate": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetCookie(cookie.CsrfTokenCookieName, "csrf_token") + ctx.Request.Header.SetCookie(cookie.AccessTokenCookieName, accessToken) + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + key := fmt.Sprintf("%s:%d:%d", keyPrefix, 123987, 987654321) + mock.CustomMatch(func(expected, actual []interface{}) error { + assert.NotNil(t, actual) + if s, ok := actual[1].(string); ok { + assert.IsType(t, "", s) + } + return nil + }).ExpectSetEx(key, nil, 0).SetVal("token_1") + }, + expectRotate: true, + expectStatus: fasthttp.StatusOK, + }, + "Guest": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + }, + expectRotate: false, + expectStatus: fasthttp.StatusUnauthorized, + }, + "RotateError": { + request: func() *fasthttp.RequestCtx { + ctx := fasthttp.RequestCtx{} + ctx.Request.Header.SetCookie(cookie.CsrfTokenCookieName, "csrf_token") + ctx.Request.Header.SetCookie(cookie.AccessTokenCookieName, accessToken) + return &ctx + }(), + setup: func(mock redismock.ClientMock) { + key := fmt.Sprintf("%s:%d:%d", keyPrefix, 123987, 987654321) + mock.CustomMatch(func(expected, actual []interface{}) error { + assert.NotNil(t, actual) + if s, ok := actual[1].(string); ok { + assert.IsType(t, "", s) + } + return nil + }).ExpectSetEx(key, nil, 0).SetErr(errors.New("broken pipe")) + }, + expectRotate: false, + expectStatus: fasthttp.StatusInternalServerError, + }, + } { + t.Run(name, func(t *testing.T) { + client, mock := redismock.NewClientMock() + + test.setup(mock) + + handler := NewRedisHandler(client) + handler.RotateTokenHandler(test.request) + + if test.expectRotate { + assert.NotEqual(t, string(test.request.Response.Header.PeekCookie(cookie.CsrfTokenCookieName)), "csrf_token") + } + + assert.Equal(t, test.expectStatus, test.request.Response.StatusCode()) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) + } + }) + } +} + +func TestGetUserContextFromAccessToken(t *testing.T) { + for name, test := range map[string]struct { + token string + expectContext string + expectError bool + expectPanic bool + }{ + "OK": { + token: func() string { + s, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": 123987, + "iat": 987654321, + }).SignedString([]byte("asd")) + return s + }(), + expectContext: "123987:987654321", + }, + "Empty": { + token: "", + expectError: true, + }, + "Invalid": { + token: "not jwt token", + expectError: true, + }, + "NoClaims": { + token: func() string { + s, _ := jwt.New(jwt.SigningMethodHS256).SignedString([]byte("asd")) + return s + }(), + expectError: true, + }, + "NoUserId": { + token: func() string { + s, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iat": 987654321, + }).SignedString([]byte("asd")) + return s + }(), + expectError: true, + }, + "NoIssuedTimestamp": { + token: func() string { + s, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": 123987, + }).SignedString([]byte("asd")) + return s + }(), + expectError: true, + }, + "EmptyUserId": { + token: func() string { + s, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": 0, + "iat": 987654321, + }).SignedString([]byte("asd")) + return s + }(), + expectError: true, + }, + "EmptyIssuedTimestamp": { + token: func() string { + s, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": 123987, + "iat": 0, + }).SignedString([]byte("asd")) + return s + }(), + expectError: true, + }, + "ClaimsPanic": { + token: func() string { + s, _ := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": "", + "iat": "", + }).SignedString([]byte("asd")) + return s + }(), + expectPanic: true, + }, + } { + t.Run(name, func(t *testing.T) { + if test.expectPanic { + _, _ = getUserContextFromAccessToken(test.token) + return + } + ctx, err := getUserContextFromAccessToken(test.token) + if test.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectContext, ctx) + }) + } +} diff --git a/router/healthcheck_test.go b/router/healthcheck_test.go index 239e194..045b6b3 100644 --- a/router/healthcheck_test.go +++ b/router/healthcheck_test.go @@ -13,8 +13,9 @@ import ( func TestLiveHandler(t *testing.T) { ctrl := gomock.NewController(t) - h := mocks.NewApiHandlerMock(ctrl) - r := New(h) + a := mocks.NewApiHandlerMock(ctrl) + c := mocks.NewCsrfHandlerMock(ctrl) + r := New(a, c) ctx := fasthttp.RequestCtx{} @@ -26,9 +27,10 @@ func TestLiveHandler(t *testing.T) { func TestReadyHandler(t *testing.T) { ctrl := gomock.NewController(t) - h := mocks.NewApiHandlerMock(ctrl) - h.EXPECT().Healthcheck().Return(nil) - r := New(h) + a := mocks.NewApiHandlerMock(ctrl) + a.EXPECT().Healthcheck().Return(nil) + c := mocks.NewCsrfHandlerMock(ctrl) + r := New(a, c) ctx := fasthttp.RequestCtx{} @@ -40,9 +42,10 @@ func TestReadyHandler(t *testing.T) { func TestReadyHandlerFail(t *testing.T) { ctrl := gomock.NewController(t) - h := mocks.NewApiHandlerMock(ctrl) - h.EXPECT().Healthcheck().Return(fmt.Errorf("context cancelled")) - r := New(h) + a := mocks.NewApiHandlerMock(ctrl) + a.EXPECT().Healthcheck().Return(fmt.Errorf("context cancelled")) + c := mocks.NewCsrfHandlerMock(ctrl) + r := New(a, c) ctx := fasthttp.RequestCtx{} diff --git a/router/router.go b/router/router.go index b8d284f..9b7278c 100644 --- a/router/router.go +++ b/router/router.go @@ -4,17 +4,20 @@ import ( "github.com/fasthttp/router" "github.com/cash-track/gateway/router/api" + "github.com/cash-track/gateway/router/csrf" ) type Router struct { *router.Router - api api.Handler + api api.Handler + csrf csrf.Handler } -func New(api api.Handler) *Router { +func New(api api.Handler, csrf csrf.Handler) *Router { r := &Router{ Router: router.New(), api: api, + csrf: csrf, } r.register() return r @@ -23,6 +26,7 @@ func New(api api.Handler) *Router { func (r *Router) register() { r.ANY("/live", r.LiveHandler) r.ANY("/ready", r.ReadyHandler) + r.GET("/csrf", r.csrf.RotateTokenHandler) r.POST("/api/auth/login", r.api.AuthSetHandler) r.POST("/api/auth/login/passkey", r.api.AuthSetHandler) diff --git a/router/router_test.go b/router/router_test.go index 713396e..2202f40 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -11,12 +11,13 @@ import ( func TestNew(t *testing.T) { ctrl := gomock.NewController(t) - h := mocks.NewApiHandlerMock(ctrl) - r := New(h) + a := mocks.NewApiHandlerMock(ctrl) + c := mocks.NewCsrfHandlerMock(ctrl) + r := New(a, c) l := r.List() - assert.Len(t, l, 2) + assert.Len(t, l, 3) assert.NotNil(t, l["*"]) assert.Len(t, l["*"], 3) @@ -24,6 +25,10 @@ func TestNew(t *testing.T) { assert.Contains(t, l["*"], "/ready") assert.Contains(t, l["*"], "/api/{path:*}") + assert.NotNil(t, l["GET"]) + assert.Len(t, l["GET"], 1) + assert.Contains(t, l["GET"], "/csrf") + assert.NotNil(t, l["POST"]) assert.Len(t, l["POST"], 6) assert.Contains(t, l["POST"], "/api/auth/login")