Skip to content

Commit

Permalink
Resolve route patterns with default mux
Browse files Browse the repository at this point in the history
Fix issue with `Vars` when multiple routes use wild cards.
  • Loading branch information
raphael committed Sep 22, 2023
1 parent b86a916 commit 92d3900
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 6 deletions.
45 changes: 39 additions & 6 deletions http/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,32 @@ type (
Use(func(http.Handler) http.Handler)
}

// ResolverMuxer is a MiddlewareMuxer that can resolve the route pattern used
// to register the handler for the given request.
ResolverMuxer interface {
MiddlewareMuxer
ResolvePattern(*http.Request) string
}

// mux is the default Muxer implementation.
mux struct {
chi.Router
wildcard string
// wildcards maps a method and a pattern to the name of the wildcard
// this is needed because chi does not expose the name of the wildcard
wildcards map[string]string
}
)

// NewMuxer returns a Muxer implementation based on a Chi router.
func NewMuxer() MiddlewareMuxer {
func NewMuxer() ResolverMuxer {
r := chi.NewRouter()
r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := context.WithValue(req.Context(), AcceptTypeKey, req.Header.Get("Accept"))
enc := ResponseEncoder(ctx, w)
w.WriteHeader(http.StatusNotFound)
enc.Encode(NewErrorResponse(ctx, fmt.Errorf("404 page not found"))) // nolint:errcheck
}))
return &mux{Router: r}
return &mux{Router: r, wildcards: make(map[string]string)}
}

// wildPath matches a wildcard path segment.
Expand All @@ -87,22 +96,24 @@ func (m *mux) Handle(method, pattern string, handler http.HandlerFunc) {
if len(wildcards) > 2 {
panic("too many wildcards")
}
m.wildcard = wildcards[1]
pattern = wildPath.ReplaceAllString(pattern, "/*")
m.wildcards[method+"::"+pattern] = wildcards[1]
}
m.Method(method, pattern, handler)
}

// Vars extracts the path variables from the request context.
func (m *mux) Vars(r *http.Request) map[string]string {
params := chi.RouteContext(r.Context()).URLParams
rctx := chi.RouteContext(r.Context())
params := rctx.URLParams
if len(params.Keys) == 0 {
return nil
}
vars := make(map[string]string, len(params.Keys))
for i, k := range params.Keys {
if k == "*" {
vars[m.wildcard] = params.Values[i]
wildcard := m.wildcards[r.Method+"::"+rctx.RoutePattern()]
vars[wildcard] = params.Values[i]
continue
}
vars[k] = params.Values[i]
Expand All @@ -115,3 +126,25 @@ func (m *mux) Vars(r *http.Request) map[string]string {
func (m *mux) Use(f func(http.Handler) http.Handler) {
m.Router.Use(f)
}

// ResolvePattern returns the route pattern used to register the handler for the
// given method and path.
func (m *mux) ResolvePattern(r *http.Request) string {
ctx := chi.RouteContext(r.Context())
if ctx.RoutePattern() != "" {
return m.resolveWildcard(r.Method, ctx.RoutePattern())
}
if !m.Router.Match(ctx, r.Method, r.URL.Path) {
return ""
}
return m.resolveWildcard(r.Method, ctx.RoutePattern())
}

// resolveWildcard returns the route pattern with the wildcard replaced by the
// name of the wildcard.
func (m *mux) resolveWildcard(method, pattern string) string {
if wildcard, ok := m.wildcards[method+"::"+pattern]; ok {
return pattern[:len(pattern)-2] + "/{*" + wildcard + "}"
}
return pattern
}
85 changes: 85 additions & 0 deletions http/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,88 @@ func TestVars(t *testing.T) {
})
}
}

func TestResolvePattern(t *testing.T) {
cases := []struct {
Name string
Patterns []string
URL string
Expected string
}{
{
Name: "simple",
Patterns: []string{"/users/{id}"},
URL: "/users/123",
Expected: "/users/{id}",
},
{
Name: "multiple",
Patterns: []string{"/users/{id}/posts/{post_id}"},
URL: "/users/123/posts/456",
Expected: "/users/{id}/posts/{post_id}",
},
{
Name: "two patterns",
Patterns: []string{"/users/{id}/posts/{post_id}", "/users/{id}/posts/{post_id}/comments/{comment_id}"},
URL: "/users/123/posts/456",
Expected: "/users/{id}/posts/{post_id}",
},
{
Name: "two patterns deep",
Patterns: []string{"/users/{id}/posts/{post_id}", "/users/{id}/posts/{post_id}/comments/{comment_id}"},
URL: "/users/123/posts/456/comments/789",
Expected: "/users/{id}/posts/{post_id}/comments/{comment_id}",
},
{
Name: "wildcard",
Patterns: []string{"/users/{id}/posts/{*post_id}"},
URL: "/users/123/posts/456/789",
Expected: "/users/{id}/posts/{*post_id}",
},
{
Name: "two wildcards",
Patterns: []string{"/users/{id}/posts/{*post_id}", "/users/{id}/posts/{post_id}/comments/{*comment_id}"},
URL: "/users/123/posts/456/789",
Expected: "/users/{id}/posts/{*post_id}",
},
{
Name: "two wildcards deep",
Patterns: []string{"/users/{id}/posts/{*post_id}", "/users/{id}/posts/{post_id}/comments/{*comment_id}"},
URL: "/users/123/posts/456/comments/abc",
Expected: "/users/{id}/posts/{post_id}/comments/{*comment_id}",
},
{
Name: "no var",
Patterns: []string{"/users"},
URL: "/users",
Expected: "/users",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
var called bool
mux := NewMuxer()
handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
pattern := mux.ResolvePattern(r)
assert.Equal(t, c.Expected, pattern)
called = true
})
// Make sure resolver works with middlewares.
handler = func(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pattern := mux.ResolvePattern(r)
assert.Equal(t, c.Expected, pattern)
next.ServeHTTP(w, r)
})
}(handler)
for _, p := range c.Patterns {
mux.Handle("GET", p, handler)
}
req, _ := http.NewRequest("GET", c.URL, nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
assert.True(t, called)
})
}
}

0 comments on commit 92d3900

Please sign in to comment.