Skip to content

Commit

Permalink
Add SmartRedirectSlashes to http/middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
tchssk committed Oct 15, 2023
1 parent 6f3aa71 commit 8e83930
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
59 changes: 59 additions & 0 deletions http/middleware/chi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package middleware

import (
"fmt"
"net/http"

"github.com/go-chi/chi/v5"
)

// SmartRedirectSlashes is a middleware that matches the request path with
// patterns added to the router and redirects it.
//
// If a pattern is added to the router with a trailing slash, any matches on
// that pattern without a trailing slash will be redirected to the version with
// the slash. If a pattern does not have a trailing slash, matches on that
// pattern with a trailing slash will be redirected to the version without.
//
// This middleware depends on chi, so it needs to be mounted on chi's router.
// It make the router behavior similar to httptreemux.
func SmartRedirectSlashes(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
rctx := chi.RouteContext(r.Context())
if rctx != nil {
var path string
if rctx.RoutePath != "" {
path = rctx.RoutePath
} else {
path = r.URL.Path
}
var method string
if rctx.RouteMethod != "" {
method = rctx.RouteMethod
} else {
method = r.Method
}
if len(path) > 1 {
if rctx.Routes != nil {
if !rctx.Routes.Match(chi.NewRouteContext(), method, path) {
if path[len(path)-1] == '/' {
path = path[:len(path)-1]
} else {
path += "/"
}
if rctx.Routes.Match(chi.NewRouteContext(), method, path) {
if r.URL.RawQuery != "" {
path = fmt.Sprintf("%s?%s", path, r.URL.RawQuery)
}
redirectURL := fmt.Sprintf("//%s%s", r.Host, path)
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
return
}
}
}
}
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
58 changes: 58 additions & 0 deletions http/middleware/chi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
goahttp "goa.design/goa/v3/http"
)

func TestSmartRedirectSlashes(t *testing.T) {
cases := []struct {
Pattern string
URL string
Status int
Location string
}{
{"/users", "/users", http.StatusOK, ""},
{"/users", "/users/", http.StatusMovedPermanently, "/users"},
{"/users/", "/users/", http.StatusOK, ""},
{"/users/", "/users", http.StatusMovedPermanently, "/users/"},
{"/users/{id}", "/users/123", http.StatusOK, ""},
{"/users/{id}", "/users/123/", http.StatusMovedPermanently, "/users/123"},
{"/users/{id}/", "/users/123/", http.StatusOK, ""},
{"/users/{id}/", "/users/123", http.StatusMovedPermanently, "/users/123/"},
{"/users/{id}/posts/{post_id}", "/users/123/posts/456", http.StatusOK, ""},
{"/users/{id}/posts/{post_id}", "/users/123/posts/456/", http.StatusMovedPermanently, "/users/123/posts/456"},
{"/users/{id}/posts/{post_id}/", "/users/123/posts/456/", http.StatusOK, ""},
{"/users/{id}/posts/{post_id}/", "/users/123/posts/456", http.StatusMovedPermanently, "/users/123/posts/456/"},
{"/users/{id}/posts/{*post_id}", "/users/123/posts/456/789", http.StatusOK, ""},
{"/users/{id}/posts/{*post_id}", "/users/123/posts/456/789/", http.StatusOK, ""},
{"/users", "/users?name=foo", http.StatusOK, ""},
{"/users", "/users/?name=foo", http.StatusMovedPermanently, "/users?name=foo"},
{"/users/", "/users/?name=foo", http.StatusOK, ""},
{"/users/", "/users?name=foo", http.StatusMovedPermanently, "/users/?name=foo"},
}

for _, c := range cases {
t.Run(c.Pattern, func(t *testing.T) {
var called bool
mux := goahttp.NewMuxer()
mux.Use(SmartRedirectSlashes)
handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
called = true
})
mux.Handle("GET", c.Pattern, handler)
req, _ := http.NewRequest("GET", c.URL, nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
assert.Equal(t, c.Status, w.Code)
assert.Equal(t, w.Code == http.StatusOK, called)
if w.Code == http.StatusMovedPermanently {
assert.Equal(t, c.Location, w.Header().Get("Location"))
}
})
}
}

0 comments on commit 8e83930

Please sign in to comment.