diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index bd2c0bce4c..beed212610 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -37,4 +37,4 @@ jobs: uses: golangci/golangci-lint-action@v6 with: # NOTE: Keep this in sync with the version from .golangci.yml - version: v1.62.0 + version: v1.62.2 diff --git a/.github/workflows/markdown.yml b/.github/workflows/markdown.yml index a015149c22..cf3575a947 100644 --- a/.github/workflows/markdown.yml +++ b/.github/workflows/markdown.yml @@ -15,7 +15,7 @@ jobs: uses: actions/checkout@v4 - name: Run markdownlint-cli2 - uses: DavidAnson/markdownlint-cli2-action@v18 + uses: DavidAnson/markdownlint-cli2-action@v19 with: globs: | **/*.md diff --git a/Makefile b/Makefile index 4b348cd574..669b3fbee4 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ markdown: ## lint: 🚨 Run lint checks .PHONY: lint lint: - go run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.62.0 run ./... + go run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.62.2 run ./... ## test: 🚦 Execute all tests .PHONY: test diff --git a/app.go b/app.go index 5e5475b5f1..3810a8ec0c 100644 --- a/app.go +++ b/app.go @@ -616,6 +616,10 @@ func (app *App) handleTrustedProxy(ipAddress string) { // Note: It doesn't allow adding new methods, only customizing exist methods. func (app *App) NewCtxFunc(function func(app *App) CustomCtx) { app.newCtxFunc = function + + if app.server != nil { + app.server.Handler = app.customRequestHandler + } } // RegisterCustomConstraint allows to register custom constraint. @@ -868,7 +872,11 @@ func (app *App) Config() Config { func (app *App) Handler() fasthttp.RequestHandler { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476 // prepare the server for the start app.startupProcess() - return app.requestHandler + + if app.newCtxFunc != nil { + return app.customRequestHandler + } + return app.defaultRequestHandler } // Stack returns the raw router stack. @@ -1057,7 +1065,11 @@ func (app *App) init() *App { } // fasthttp server settings - app.server.Handler = app.requestHandler + if app.newCtxFunc != nil { + app.server.Handler = app.customRequestHandler + } else { + app.server.Handler = app.defaultRequestHandler + } app.server.Name = app.config.ServerHeader app.server.Concurrency = app.config.Concurrency app.server.NoDefaultDate = app.config.DisableDefaultDate diff --git a/app_test.go b/app_test.go index a99796a2c1..8455ded86e 100644 --- a/app_test.go +++ b/app_test.go @@ -581,32 +581,51 @@ func Test_App_Use_StrictRouting(t *testing.T) { func Test_App_Add_Method_Test(t *testing.T) { t.Parallel() - defer func() { - if err := recover(); err != nil { - require.Equal(t, "add: invalid http method JANE\n", fmt.Sprintf("%v", err)) - } - }() methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here app := New(Config{ RequestMethods: methods, }) - app.Add([]string{"JOHN"}, "/doe", testEmptyHandler) + app.Add([]string{"JOHN"}, "/john", testEmptyHandler) - resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil)) + resp, err := app.Test(httptest.NewRequest("JOHN", "/john", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest(MethodGet, "/doe", nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/john", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusMethodNotAllowed, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest("UNKNOWN", "/doe", nil)) + resp, err = app.Test(httptest.NewRequest("UNKNOWN", "/john", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotImplemented, resp.StatusCode, "Status code") - app.Add([]string{"JANE"}, "/doe", testEmptyHandler) + // Add a new method + require.Panics(t, func() { + app.Add([]string{"JANE"}, "/jane", testEmptyHandler) + }) +} + +func Test_App_All_Method_Test(t *testing.T) { + t.Parallel() + + methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here + app := New(Config{ + RequestMethods: methods, + }) + + // Add a new method with All + app.All("/doe", testEmptyHandler) + + resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") + + // Add a new method + require.Panics(t, func() { + app.Add([]string{"JANE"}, "/jane", testEmptyHandler) + }) } // go test -run Test_App_GETOnly diff --git a/binder/mapping.go b/binder/mapping.go index 29ba2a25b0..70cb9cbc2d 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -128,7 +128,6 @@ func parseToMap(ptr any, data map[string][]string) error { newMap[k] = "" continue } - newMap[k] = v[len(v)-1] } default: diff --git a/ctx_interface_gen.go b/ctx_interface_gen.go index d7f8bbc615..cc48576efb 100644 --- a/ctx_interface_gen.go +++ b/ctx_interface_gen.go @@ -350,5 +350,8 @@ type Ctx interface { setIndexRoute(route int) setMatched(matched bool) setRoute(route *Route) + // Drop closes the underlying connection without sending any response headers or body. + // This can be useful for silently terminating client connections, such as in DDoS mitigation + // or when blocking access to sensitive endpoints. Drop() error } diff --git a/ctx_test.go b/ctx_test.go index 88b617eb5b..eb81876e37 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -127,6 +127,35 @@ func Test_Ctx_CustomCtx(t *testing.T) { require.Equal(t, "prefix_v3", string(body)) } +// go test -run Test_Ctx_CustomCtx +func Test_Ctx_CustomCtx_and_Method(t *testing.T) { + t.Parallel() + + // Create app with custom request methods + methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here + app := New(Config{ + RequestMethods: methods, + }) + + // Create custom context + app.NewCtxFunc(func(app *App) CustomCtx { + return &customCtx{ + DefaultCtx: *NewDefaultCtx(app), + } + }) + + // Add route with custom method + app.Add([]string{"JOHN"}, "/doe", testEmptyHandler) + resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") + + // Add a new method + require.Panics(t, func() { + app.Add([]string{"JANE"}, "/jane", testEmptyHandler) + }) +} + // go test -run Test_Ctx_Accepts_EmptyAccept func Test_Ctx_Accepts_EmptyAccept(t *testing.T) { t.Parallel() diff --git a/docs/middleware/session.md b/docs/middleware/session.md index ff73ff6094..b175ac9e3f 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -2,7 +2,7 @@ id: session --- -# Session Middleware for [Fiber](https://github.com/gofiber/fiber) +# Session The `session` middleware provides session management for Fiber applications, utilizing the [Storage](https://github.com/gofiber/storage) package for multi-database support via a unified interface. By default, session data is stored in memory, but custom storage options are easily configurable (see examples below). diff --git a/middleware/idempotency/locker.go b/middleware/idempotency/locker.go index 2c3348b8f3..f24db382a5 100644 --- a/middleware/idempotency/locker.go +++ b/middleware/idempotency/locker.go @@ -10,42 +10,58 @@ type Locker interface { Unlock(key string) error } +type countedLock struct { + mu sync.Mutex + locked int +} + type MemoryLock struct { - keys map[string]*sync.Mutex + keys map[string]*countedLock mu sync.Mutex } func (l *MemoryLock) Lock(key string) error { l.mu.Lock() - mu, ok := l.keys[key] + lock, ok := l.keys[key] if !ok { - mu = new(sync.Mutex) - l.keys[key] = mu + lock = new(countedLock) + l.keys[key] = lock } + lock.locked++ l.mu.Unlock() - mu.Lock() + lock.mu.Lock() return nil } func (l *MemoryLock) Unlock(key string) error { l.mu.Lock() - mu, ok := l.keys[key] - l.mu.Unlock() + lock, ok := l.keys[key] if !ok { // This happens if we try to unlock an unknown key + l.mu.Unlock() return nil } + l.mu.Unlock() - mu.Unlock() + lock.mu.Unlock() + + l.mu.Lock() + lock.locked-- + if lock.locked <= 0 { + // This happens if countedLock is used to Lock and Unlock the same number of times + // So, we can delete the key to prevent memory leak + delete(l.keys, key) + } + l.mu.Unlock() return nil } func NewMemoryLock() *MemoryLock { return &MemoryLock{ - keys: make(map[string]*sync.Mutex), + keys: make(map[string]*countedLock), } } diff --git a/middleware/idempotency/locker_test.go b/middleware/idempotency/locker_test.go index 3b4a3ca78a..81da15d3bf 100644 --- a/middleware/idempotency/locker_test.go +++ b/middleware/idempotency/locker_test.go @@ -1,6 +1,8 @@ package idempotency_test import ( + "strconv" + "sync/atomic" "testing" "time" @@ -59,3 +61,67 @@ func Test_MemoryLock(t *testing.T) { require.NoError(t, err) } } + +func Benchmark_MemoryLock(b *testing.B) { + keys := make([]string, b.N) + for i := range keys { + keys[i] = strconv.Itoa(i) + } + + lock := idempotency.NewMemoryLock() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := keys[i] + if err := lock.Lock(key); err != nil { + b.Fatal(err) + } + if err := lock.Unlock(key); err != nil { + b.Fatal(err) + } + } +} + +func Benchmark_MemoryLock_Parallel(b *testing.B) { + // In order to prevent using repeated keys I pre-allocate keys + keys := make([]string, 1_000_000) + for i := range keys { + keys[i] = strconv.Itoa(i) + } + + b.Run("UniqueKeys", func(b *testing.B) { + lock := idempotency.NewMemoryLock() + var keyI atomic.Int32 + b.RunParallel(func(p *testing.PB) { + for p.Next() { + i := int(keyI.Add(1)) % len(keys) + key := keys[i] + if err := lock.Lock(key); err != nil { + b.Fatal(err) + } + if err := lock.Unlock(key); err != nil { + b.Fatal(err) + } + } + }) + }) + + b.Run("RepeatedKeys", func(b *testing.B) { + lock := idempotency.NewMemoryLock() + var keyI atomic.Int32 + b.RunParallel(func(p *testing.PB) { + for p.Next() { + // Division by 3 ensures that index will be repreated exactly 3 times + i := int(keyI.Add(1)) / 3 % len(keys) + key := keys[i] + if err := lock.Lock(key); err != nil { + b.Fatal(err) + } + if err := lock.Unlock(key); err != nil { + b.Fatal(err) + } + } + }) + }) +} diff --git a/path.go b/path.go index 00105d5cc0..282073ec04 100644 --- a/path.go +++ b/path.go @@ -620,10 +620,16 @@ func GetTrimmedParam(param string) string { // RemoveEscapeChar remove escape characters func RemoveEscapeChar(word string) string { - if strings.IndexByte(word, escapeChar) != -1 { - return strings.ReplaceAll(word, string(escapeChar), "") + b := []byte(word) + dst := 0 + for src := 0; src < len(b); src++ { + if b[src] == '\\' { + continue + } + b[dst] = b[src] + dst++ } - return word + return string(b[:dst]) } func getParamConstraintType(constraintPart string) TypeConstraint { diff --git a/router.go b/router.go index 2091cfc6cb..9612da170b 100644 --- a/router.go +++ b/router.go @@ -5,11 +5,11 @@ package fiber import ( + "bytes" "errors" "fmt" "html" "sort" - "strings" "sync/atomic" "github.com/gofiber/utils/v2" @@ -65,10 +65,12 @@ type Route struct { func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool { // root detectionPath check - if r.root && detectionPath == "/" { + if r.root && len(detectionPath) == 1 && detectionPath[0] == '/' { return true - // '*' wildcard matches any detectionPath - } else if r.star { + } + + // '*' wildcard matches any detectionPath + if r.star { if len(path) > 1 { params[0] = path[1:] } else { @@ -76,24 +78,32 @@ func (r *Route) match(detectionPath, path string, params *[maxParams]string) boo } return true } - // Does this route have parameters + + // Does this route have parameters? if len(r.Params) > 0 { - // Match params - if match := r.routeParser.getMatch(detectionPath, path, params, r.use); match { - // Get params from the path detectionPath - return match + // Match params using precomputed routeParser + if r.routeParser.getMatch(detectionPath, path, params, r.use) { + return true } } - // Is this route a Middleware? + + // Middleware route? if r.use { - // Single slash will match or detectionPath prefix - if r.root || strings.HasPrefix(detectionPath, r.path) { + // Single slash or prefix match + plen := len(r.path) + if r.root { + // If r.root is '/', it matches everything starting at '/' + if len(detectionPath) > 0 && detectionPath[0] == '/' { + return true + } + } else if len(detectionPath) >= plen && detectionPath[:plen] == r.path { return true } - // Check for a simple detectionPath match - } else if len(r.path) == len(detectionPath) && r.path == detectionPath { + } else if len(r.path) == len(detectionPath) && detectionPath == r.path { + // Check exact match return true } + // No match return false } @@ -201,44 +211,63 @@ func (app *App) next(c *DefaultCtx) (bool, error) { return false, err } -func (app *App) requestHandler(rctx *fasthttp.RequestCtx) { - // Handler for default ctxs - var c CustomCtx - var ok bool - if app.newCtxFunc != nil { - c, ok = app.AcquireCtx(rctx).(CustomCtx) - if !ok { - panic(errors.New("requestHandler: failed to type-assert to CustomCtx")) - } - } else { - c, ok = app.AcquireCtx(rctx).(*DefaultCtx) - if !ok { - panic(errors.New("requestHandler: failed to type-assert to *DefaultCtx")) - } +func (app *App) defaultRequestHandler(rctx *fasthttp.RequestCtx) { + // Acquire DefaultCtx from the pool + ctx, ok := app.AcquireCtx(rctx).(*DefaultCtx) + if !ok { + panic(errors.New("requestHandler: failed to type-assert to *DefaultCtx")) } - defer app.ReleaseCtx(c) - // handle invalid http method directly - if app.methodInt(c.Method()) == -1 { - _ = c.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil + defer app.ReleaseCtx(ctx) + + // Check if the HTTP method is valid + if ctx.methodINT == -1 { + _ = ctx.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil return } - // check flash messages - if strings.Contains(utils.UnsafeString(c.Request().Header.RawHeaders()), FlashCookieName) { - c.Redirect().parseAndClearFlashMessages() + // Optional: Check flash messages + rawHeaders := ctx.Request().Header.RawHeaders() + if len(rawHeaders) > 0 && bytes.Contains(rawHeaders, []byte(FlashCookieName)) { + ctx.Redirect().parseAndClearFlashMessages() } - // Find match in stack - var err error - if app.newCtxFunc != nil { - _, err = app.nextCustom(c) - } else { - _, err = app.next(c.(*DefaultCtx)) //nolint:errcheck // It is fine to ignore the error here + // Attempt to match a route and execute the chain + _, err := app.next(ctx) + if err != nil { + if catch := ctx.App().ErrorHandler(ctx, err); catch != nil { + _ = ctx.SendStatus(StatusInternalServerError) //nolint:errcheck // Always return nil + } + // TODO: Do we need to return here? } +} + +func (app *App) customRequestHandler(rctx *fasthttp.RequestCtx) { + // Acquire CustomCtx from the pool + ctx, ok := app.AcquireCtx(rctx).(CustomCtx) + if !ok { + panic(errors.New("requestHandler: failed to type-assert to CustomCtx")) + } + + defer app.ReleaseCtx(ctx) + + // Check if the HTTP method is valid + if app.methodInt(ctx.Method()) == -1 { + _ = ctx.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil + return + } + + // Optional: Check flash messages + rawHeaders := ctx.Request().Header.RawHeaders() + if len(rawHeaders) > 0 && bytes.Contains(rawHeaders, []byte(FlashCookieName)) { + ctx.Redirect().parseAndClearFlashMessages() + } + + // Attempt to match a route and execute the chain + _, err := app.nextCustom(ctx) if err != nil { - if catch := c.App().ErrorHandler(c, err); catch != nil { - _ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here + if catch := ctx.App().ErrorHandler(ctx, err); catch != nil { + _ = ctx.SendStatus(StatusInternalServerError) //nolint:errcheck // Always return nil } // TODO: Do we need to return here? } @@ -295,68 +324,56 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler handlers = append(handlers, handler) } + // Precompute path normalization ONCE + if pathRaw == "" { + pathRaw = "/" + } + if pathRaw[0] != '/' { + pathRaw = "/" + pathRaw + } + pathPretty := pathRaw + if !app.config.CaseSensitive { + pathPretty = utils.ToLower(pathPretty) + } + if !app.config.StrictRouting && len(pathPretty) > 1 { + pathPretty = utils.TrimRight(pathPretty, '/') + } + pathClean := RemoveEscapeChar(pathPretty) + + parsedRaw := parseRoute(pathRaw, app.customConstraints...) + parsedPretty := parseRoute(pathPretty, app.customConstraints...) + for _, method := range methods { - // Uppercase HTTP methods method = utils.ToUpper(method) - // Check if the HTTP method is valid unless it's USE if method != methodUse && app.methodInt(method) == -1 { panic(fmt.Sprintf("add: invalid http method %s\n", method)) } - // is mounted app + isMount := group != nil && group.app != app - // A route requires atleast one ctx handler if len(handlers) == 0 && !isMount { panic(fmt.Sprintf("missing handler/middleware in route: %s\n", pathRaw)) } - // Cannot have an empty path - if pathRaw == "" { - pathRaw = "/" - } - // Path always start with a '/' - if pathRaw[0] != '/' { - pathRaw = "/" + pathRaw - } - // Create a stripped path in case-sensitive / trailing slashes - pathPretty := pathRaw - // Case-sensitive routing, all to lowercase - if !app.config.CaseSensitive { - pathPretty = utils.ToLower(pathPretty) - } - // Strict routing, remove trailing slashes - if !app.config.StrictRouting && len(pathPretty) > 1 { - pathPretty = utils.TrimRight(pathPretty, '/') - } - // Is layer a middleware? + isUse := method == methodUse - // Is path a direct wildcard? - isStar := pathPretty == "/*" - // Is path a root slash? - isRoot := pathPretty == "/" - // Parse path parameters - parsedRaw := parseRoute(pathRaw, app.customConstraints...) - parsedPretty := parseRoute(pathPretty, app.customConstraints...) - - // Create route metadata without pointer + isStar := pathClean == "/*" + isRoot := pathClean == "/" + route := Route{ - // Router booleans use: isUse, mount: isMount, star: isStar, root: isRoot, - // Path data - path: RemoveEscapeChar(pathPretty), + path: pathClean, routeParser: parsedPretty, Params: parsedRaw.params, + group: group, - // Group data - group: group, - - // Public data Path: pathRaw, Method: method, Handlers: handlers, } + // Increment global handler count atomic.AddUint32(&app.handlersCount, uint32(len(handlers))) //nolint:gosec // Not a concern diff --git a/router_test.go b/router_test.go index 5509039c66..fe5b3429e0 100644 --- a/router_test.go +++ b/router_test.go @@ -591,6 +591,29 @@ func Benchmark_Router_Next_Default(b *testing.B) { } } +// go test -benchmem -run=^$ -bench ^Benchmark_Router_Next_Default_Parallel$ github.com/gofiber/fiber/v3 -count=1 +func Benchmark_Router_Next_Default_Parallel(b *testing.B) { + app := New() + app.Get("/", func(_ Ctx) error { + return nil + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + fctx := &fasthttp.RequestCtx{} + fctx.Request.Header.SetMethod(MethodGet) + fctx.Request.SetRequestURI("/") + + for pb.Next() { + h(fctx) + } + }) +} + // go test -v ./... -run=^$ -bench=Benchmark_Route_Match -benchmem -count=4 func Benchmark_Route_Match(b *testing.B) { var match bool