Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace RedirectSlashes with smartRedirectSlashes #3385

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions http/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"regexp"

chi "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)

type (
Expand Down Expand Up @@ -80,9 +79,8 @@ type (
// NewMuxer returns a Muxer implementation based on a Chi router.
func NewMuxer() ResolverMuxer {
r := chi.NewRouter()
// RedirectSlashes must be mounted at the top level of the router.
// See. https://github.com/go-chi/chi/blob/1129e362d6cce6e3805e3bc8dfbaeb34b5129789/middleware/strip_test.go#L105-L107
r.Use(middleware.RedirectSlashes)
// smartRedirectSlashes must be mounted at the top level of the router.
r.Use(smartRedirectSlashes)
r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := context.WithValue(req.Context(), AcceptTypeKey, req.Header.Get("Accept"))
enc := ResponseEncoder(ctx, w)
Expand All @@ -92,6 +90,53 @@ func NewMuxer() ResolverMuxer {
return &mux{Router: r, wildcards: make(map[string]string)}
}

// smartRedirectSlashes is a middleware that matches the request path with
// patterns added to the router and redirects it.
// If a pattern is added to the router with a trailing slash, any matches on
// that pattern without a trailing slash will be redirected to the version with
// the slash. If a pattern does not have a trailing slash, matches on that
// pattern with a trailing slash will be redirected to the version without.
func smartRedirectSlashes(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
rctx := chi.RouteContext(r.Context())
if rctx != nil {
var path string
if rctx.RoutePath != "" {
path = rctx.RoutePath
} else {
path = r.URL.Path
}
var method string
if rctx.RouteMethod != "" {
method = rctx.RouteMethod
} else {
method = r.Method
}
if len(path) > 1 {
if rctx.Routes != nil {
if !rctx.Routes.Match(chi.NewRouteContext(), method, path) {
if path[len(path)-1] == '/' {
path = path[:len(path)-1]
} else {
path += "/"
}
if rctx.Routes.Match(chi.NewRouteContext(), method, path) {
if r.URL.RawQuery != "" {
path = fmt.Sprintf("%s?%s", path, r.URL.RawQuery)
}
redirectURL := fmt.Sprintf("//%s%s", r.Host, path)
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
return
}
}
}
}
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}

// wildPath matches a wildcard path segment.
var wildPath = regexp.MustCompile(`/{\*([a-zA-Z0-9_]+)}`)

Expand Down
27 changes: 21 additions & 6 deletions http/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -201,6 +203,12 @@ func TestResolvePattern(t *testing.T) {
URL: "/users",
Expected: "/users",
},
{
Name: "query",
Patterns: []string{"/users"},
URL: "/users?name=foo",
Expected: "/users",
},
}

for _, c := range cases {
Expand All @@ -227,15 +235,22 @@ func TestResolvePattern(t *testing.T) {
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
assert.True(t, called)
// Make sure the URL with a trailing slash is redirected.
// Test the URL with a trailing slash.
called = false // Reset.
req, _ = http.NewRequest("GET", c.URL+"/", nil)
w = httptest.NewRecorder()
mux.ServeHTTP(w, req)
assert.Equal(t, http.StatusMovedPermanently, w.Code)
req, _ = http.NewRequest("GET", w.Header().Get("Location"), nil)
u, err := url.Parse(c.URL)
assert.NoError(t, err)
u.Path += "/"
req, _ = http.NewRequest("GET", u.String(), nil)
w = httptest.NewRecorder()
mux.ServeHTTP(w, req)
if strings.Contains(c.Name, "wildcard") {
assert.Equal(t, http.StatusOK, w.Code)
} else {
assert.Equal(t, http.StatusMovedPermanently, w.Code)
req, _ = http.NewRequest("GET", w.Header().Get("Location"), nil)
w = httptest.NewRecorder()
mux.ServeHTTP(w, req)
}
assert.True(t, called)
})
}
Expand Down