diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index d417730416..7186345c8a 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -12,15 +12,17 @@ import ( ) var ( - ErrTokenNotFound = errors.New("csrf token not found") - ErrTokenInvalid = errors.New("csrf token invalid") - ErrRefererNotFound = errors.New("referer not supplied") - ErrRefererInvalid = errors.New("referer invalid") - ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin") - ErrOriginInvalid = errors.New("origin invalid") - ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin") - errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user - dummyValue = []byte{'+'} + ErrTokenNotFound = errors.New("csrf token not found") + ErrTokenInvalid = errors.New("csrf token invalid") + ErrRefererNotFound = errors.New("referer not supplied") + ErrRefererInvalid = errors.New("referer invalid") + ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin") + ErrOriginInvalid = errors.New("origin invalid") + ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin") + ErrStorageRetrievalFailed = errors.New("unable to retrieve data from CSRF storage") + + errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user + dummyValue = []byte{'+'} ) // Handler for CSRF middleware @@ -103,10 +105,12 @@ func New(config ...Config) fiber.Handler { switch c.Method() { case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace: cookieToken := c.Cookies(cfg.CookieName) - if cookieToken != "" { - raw := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) - + raw, err := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager) + if err != nil { + println("hereee+" + err.Error()) + return cfg.ErrorHandler(c, err) + } if raw != nil { token = cookieToken // Token is valid, safe to set it } @@ -149,14 +153,18 @@ func New(config ...Config) fiber.Handler { return cfg.ErrorHandler(c, ErrTokenInvalid) } - raw := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager) + raw, err := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager) + if err != nil { + + return cfg.ErrorHandler(c, err) + } else if raw == nil { - if raw == nil { // If token is not in storage, expire the cookie expireCSRFCookie(c, cfg) // and return an error - return cfg.ErrorHandler(c, ErrTokenNotFound) + return cfg.ErrorHandler(c, ErrTokenInvalid) } + if cfg.SingleUseToken { // If token is single use, delete it from storage deleteTokenFromStorage(c, extractedToken, cfg, sessionManager, storageManager) @@ -210,7 +218,7 @@ func HandlerFromContext(c fiber.Ctx) *Handler { // getRawFromStorage returns the raw value from the storage for the given token // returns nil if the token does not exist, is expired or is invalid -func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) []byte { +func getRawFromStorage(c fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) ([]byte, error) { if cfg.Session != nil { return sessionManager.getRaw(c, token, dummyValue) } diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 82252549bd..34e6065a54 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -1,6 +1,7 @@ package csrf import ( + "fmt" "net/http/httptest" "strings" "testing" @@ -1263,7 +1264,6 @@ func Test_CSRF_Cookie_Injection_Exploit(t *testing.T) { ctx.Request.SetRequestURI("/") h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) - token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Exploit CSRF token we just injected ctx.Request.Reset() @@ -1509,3 +1509,67 @@ func Test_CSRF_FromContextMethods_Invalid(t *testing.T) { require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } + +type mockStorage struct{} + +func (m *mockStorage) Get(key string) ([]byte, error) { + return nil, fmt.Errorf("not found") +} + +func (m *mockStorage) Set(key string, val []byte, exp time.Duration) error { + return nil +} + +func (m *mockStorage) Delete(key string) error { + return nil +} + +func (m *mockStorage) Reset() error { + return nil +} + +func (m *mockStorage) Close() error { + return nil +} + +func Test_NotGetTokenInSessionStorage(t *testing.T) { + t.Parallel() + + errHandler := func(c fiber.Ctx, err error) error { + require.Equal(t, ErrStorageRetrievalFailed.Error(), err.Error()) + return c.Status(419).Send([]byte(err.Error())) + } + + // &session.Store{}.Storage.Set(ConfigDefault.CookieName, "fiber", 300) + + app := fiber.New() + app.Use(New(Config{ + ErrorHandler: errHandler, + Session: &session.Store{ + Config: session.Config{ + Storage: &mockStorage{}, + KeyGenerator: ConfigDefault.KeyGenerator, + KeyLookup: ConfigDefault.KeyLookup, + Expiration: ConfigDefault.Expiration, + CookieSameSite: "Lax", + }, + }, + })) + + app.Post("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetCookie(ConfigDefault.CookieName, "fiber") + h(ctx) + + require.Equal(t, 419, ctx.Response.StatusCode()) + require.Equal(t, "invalid CSRF token", string(ctx.Response.Body())) + +} diff --git a/middleware/csrf/session_manager.go b/middleware/csrf/session_manager.go index 3bbf173a26..d927215af4 100644 --- a/middleware/csrf/session_manager.go +++ b/middleware/csrf/session_manager.go @@ -1,6 +1,7 @@ package csrf import ( + "fmt" "time" "github.com/gofiber/fiber/v3" @@ -26,20 +27,25 @@ func newSessionManager(s *session.Store, k string) *sessionManager { } // get token from session -func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte { +func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) ([]byte, error) { sess, err := m.session.Get(c) if err != nil { - return nil + return nil, ErrStorageRetrievalFailed } + + fmt.Println("key: ", sess) + token, ok := sess.Get(m.key).(Token) - if ok { - if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) { - return nil - } - return token.Raw + fmt.Println("key: ", token, ok) + if !ok { + return nil, ErrTokenInvalid + } + + if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) { + return nil, ErrTokenInvalid } - return nil + return token.Raw, nil } // set token in session diff --git a/middleware/csrf/storage_manager.go b/middleware/csrf/storage_manager.go index 4d3c26420a..76ad38cb84 100644 --- a/middleware/csrf/storage_manager.go +++ b/middleware/csrf/storage_manager.go @@ -1,11 +1,13 @@ package csrf import ( + "fmt" "sync" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/memory" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) @@ -41,20 +43,35 @@ func newStorageManager(storage fiber.Storage) *storageManager { } // get raw data from storage or memory -func (m *storageManager) getRaw(key string) []byte { - var raw []byte +func (m *storageManager) getRaw(key string) ([]byte, error) { + var ( + raw []byte + err error + ) if m.storage != nil { - raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Do not ignore error + raw, err = m.storage.Get(key) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrStorageRetrievalFailed, err.Error()) + } } else { - raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Do not ignore error + var ok bool + raw, ok = m.memory.Get(key).([]byte) + if !ok { + return nil, ErrStorageRetrievalFailed + } } - return raw + + return raw, nil } // set data to storage or memory func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) { if m.storage != nil { - _ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Do not ignore error + err := m.storage.Set(key, raw, exp) + if err != nil { + log.Warnf("csrf: failed to save session in storage: %s", err.Error()) + return + } } else { // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here m.memory.Set(utils.CopyString(key), raw, exp) @@ -64,7 +81,11 @@ func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) { // delete data from storage or memory func (m *storageManager) delRaw(key string) { if m.storage != nil { - _ = m.storage.Delete(key) //nolint:errcheck // TODO: Do not ignore error + err := m.storage.Delete(key) + if err != nil { + log.Warnf("csrf: failed to delete session in storage: %s", err.Error()) + return + } } else { m.memory.Delete(key) }