From 2e0b381ac76b136d1aed6bf820fbf52d95961bf4 Mon Sep 17 00:00:00 2001 From: Taichi Sasaki Date: Fri, 6 Oct 2023 12:11:41 +0900 Subject: [PATCH] Replace RedirectSlashes with smartRedirectSlashes --- http/mux.go | 53 ++++++++++++++++++++++++++++++++++++++++++++---- http/mux_test.go | 27 ++++++++++++++++++------ 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/http/mux.go b/http/mux.go index 146027172d..f0e3114cbe 100644 --- a/http/mux.go +++ b/http/mux.go @@ -8,7 +8,6 @@ import ( "regexp" chi "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" ) type ( @@ -80,9 +79,8 @@ type ( // NewMuxer returns a Muxer implementation based on a Chi router. func NewMuxer() ResolverMuxer { r := chi.NewRouter() - // RedirectSlashes must be mounted at the top level of the router. - // See. https://github.com/go-chi/chi/blob/1129e362d6cce6e3805e3bc8dfbaeb34b5129789/middleware/strip_test.go#L105-L107 - r.Use(middleware.RedirectSlashes) + // smartRedirectSlashes must be mounted at the top level of the router. + r.Use(smartRedirectSlashes) r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { ctx := context.WithValue(req.Context(), AcceptTypeKey, req.Header.Get("Accept")) enc := ResponseEncoder(ctx, w) @@ -92,6 +90,53 @@ func NewMuxer() ResolverMuxer { return &mux{Router: r, wildcards: make(map[string]string)} } +// smartRedirectSlashes is a middleware that matches the request path with +// patterns added to the router and redirects it. +// If a pattern is added to the router with a trailing slash, any matches on +// that pattern without a trailing slash will be redirected to the version with +// the slash. If a pattern does not have a trailing slash, matches on that +// pattern with a trailing slash will be redirected to the version without. +func smartRedirectSlashes(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + rctx := chi.RouteContext(r.Context()) + if rctx != nil { + var path string + if rctx.RoutePath != "" { + path = rctx.RoutePath + } else { + path = r.URL.Path + } + var method string + if rctx.RouteMethod != "" { + method = rctx.RouteMethod + } else { + method = r.Method + } + if len(path) > 1 { + if rctx.Routes != nil { + if !rctx.Routes.Match(chi.NewRouteContext(), method, path) { + if path[len(path)-1] == '/' { + path = path[:len(path)-1] + } else { + path += "/" + } + if rctx.Routes.Match(chi.NewRouteContext(), method, path) { + if r.URL.RawQuery != "" { + path = fmt.Sprintf("%s?%s", path, r.URL.RawQuery) + } + redirectURL := fmt.Sprintf("//%s%s", r.Host, path) + http.Redirect(w, r, redirectURL, http.StatusMovedPermanently) + return + } + } + } + } + } + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + // wildPath matches a wildcard path segment. var wildPath = regexp.MustCompile(`/{\*([a-zA-Z0-9_]+)}`) diff --git a/http/mux_test.go b/http/mux_test.go index 94f59781db..3877c5ab38 100644 --- a/http/mux_test.go +++ b/http/mux_test.go @@ -4,6 +4,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -201,6 +203,12 @@ func TestResolvePattern(t *testing.T) { URL: "/users", Expected: "/users", }, + { + Name: "query", + Patterns: []string{"/users"}, + URL: "/users?name=foo", + Expected: "/users", + }, } for _, c := range cases { @@ -227,15 +235,22 @@ func TestResolvePattern(t *testing.T) { w := httptest.NewRecorder() mux.ServeHTTP(w, req) assert.True(t, called) - // Make sure the URL with a trailing slash is redirected. + // Test the URL with a trailing slash. called = false // Reset. - req, _ = http.NewRequest("GET", c.URL+"/", nil) - w = httptest.NewRecorder() - mux.ServeHTTP(w, req) - assert.Equal(t, http.StatusMovedPermanently, w.Code) - req, _ = http.NewRequest("GET", w.Header().Get("Location"), nil) + u, err := url.Parse(c.URL) + assert.NoError(t, err) + u.Path += "/" + req, _ = http.NewRequest("GET", u.String(), nil) w = httptest.NewRecorder() mux.ServeHTTP(w, req) + if strings.Contains(c.Name, "wildcard") { + assert.Equal(t, http.StatusOK, w.Code) + } else { + assert.Equal(t, http.StatusMovedPermanently, w.Code) + req, _ = http.NewRequest("GET", w.Header().Get("Location"), nil) + w = httptest.NewRecorder() + mux.ServeHTTP(w, req) + } assert.True(t, called) }) }