Skip to content

Commit

Permalink
Replace RedirectSlashes with smartRedirectSlashes
Browse files Browse the repository at this point in the history
  • Loading branch information
tchssk committed Oct 6, 2023
1 parent 48d61c3 commit 2e0b381
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 10 deletions.
53 changes: 49 additions & 4 deletions http/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"regexp"

chi "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)

type (
Expand Down Expand Up @@ -80,9 +79,8 @@ type (
// NewMuxer returns a Muxer implementation based on a Chi router.
func NewMuxer() ResolverMuxer {
r := chi.NewRouter()
// RedirectSlashes must be mounted at the top level of the router.
// See. https://github.com/go-chi/chi/blob/1129e362d6cce6e3805e3bc8dfbaeb34b5129789/middleware/strip_test.go#L105-L107
r.Use(middleware.RedirectSlashes)
// smartRedirectSlashes must be mounted at the top level of the router.
r.Use(smartRedirectSlashes)
r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := context.WithValue(req.Context(), AcceptTypeKey, req.Header.Get("Accept"))
enc := ResponseEncoder(ctx, w)
Expand All @@ -92,6 +90,53 @@ func NewMuxer() ResolverMuxer {
return &mux{Router: r, wildcards: make(map[string]string)}
}

// smartRedirectSlashes is a middleware that matches the request path with
// patterns added to the router and redirects it.
// If a pattern is added to the router with a trailing slash, any matches on
// that pattern without a trailing slash will be redirected to the version with
// the slash. If a pattern does not have a trailing slash, matches on that
// pattern with a trailing slash will be redirected to the version without.
func smartRedirectSlashes(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
rctx := chi.RouteContext(r.Context())
if rctx != nil {
var path string
if rctx.RoutePath != "" {
path = rctx.RoutePath
} else {
path = r.URL.Path
}
var method string
if rctx.RouteMethod != "" {
method = rctx.RouteMethod
} else {
method = r.Method
}
if len(path) > 1 {
if rctx.Routes != nil {
if !rctx.Routes.Match(chi.NewRouteContext(), method, path) {
if path[len(path)-1] == '/' {
path = path[:len(path)-1]
} else {
path += "/"
}
if rctx.Routes.Match(chi.NewRouteContext(), method, path) {
if r.URL.RawQuery != "" {
path = fmt.Sprintf("%s?%s", path, r.URL.RawQuery)
}
redirectURL := fmt.Sprintf("//%s%s", r.Host, path)
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
return
}
}
}
}
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}

// wildPath matches a wildcard path segment.
var wildPath = regexp.MustCompile(`/{\*([a-zA-Z0-9_]+)}`)

Expand Down
27 changes: 21 additions & 6 deletions http/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -201,6 +203,12 @@ func TestResolvePattern(t *testing.T) {
URL: "/users",
Expected: "/users",
},
{
Name: "query",
Patterns: []string{"/users"},
URL: "/users?name=foo",
Expected: "/users",
},
}

for _, c := range cases {
Expand All @@ -227,15 +235,22 @@ func TestResolvePattern(t *testing.T) {
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
assert.True(t, called)
// Make sure the URL with a trailing slash is redirected.
// Test the URL with a trailing slash.
called = false // Reset.
req, _ = http.NewRequest("GET", c.URL+"/", nil)
w = httptest.NewRecorder()
mux.ServeHTTP(w, req)
assert.Equal(t, http.StatusMovedPermanently, w.Code)
req, _ = http.NewRequest("GET", w.Header().Get("Location"), nil)
u, err := url.Parse(c.URL)
assert.NoError(t, err)
u.Path += "/"
req, _ = http.NewRequest("GET", u.String(), nil)
w = httptest.NewRecorder()
mux.ServeHTTP(w, req)
if strings.Contains(c.Name, "wildcard") {
assert.Equal(t, http.StatusOK, w.Code)
} else {
assert.Equal(t, http.StatusMovedPermanently, w.Code)
req, _ = http.NewRequest("GET", w.Header().Get("Location"), nil)
w = httptest.NewRecorder()
mux.ServeHTTP(w, req)
}
assert.True(t, called)
})
}
Expand Down

0 comments on commit 2e0b381

Please sign in to comment.