diff --git a/http/mux.go b/http/mux.go index abf536816b..924030cf8c 100644 --- a/http/mux.go +++ b/http/mux.go @@ -59,15 +59,24 @@ type ( Use(func(http.Handler) http.Handler) } + // ResolverMuxer is a MiddlewareMuxer that can resolve the route pattern used + // to register the handler for the given request. + ResolverMuxer interface { + MiddlewareMuxer + ResolvePattern(*http.Request) string + } + // mux is the default Muxer implementation. mux struct { chi.Router - wildcard string + // wildcards maps a method and a pattern to the name of the wildcard + // this is needed because chi does not expose the name of the wildcard + wildcards map[string]string } ) // NewMuxer returns a Muxer implementation based on a Chi router. -func NewMuxer() MiddlewareMuxer { +func NewMuxer() ResolverMuxer { r := chi.NewRouter() r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { ctx := context.WithValue(req.Context(), AcceptTypeKey, req.Header.Get("Accept")) @@ -75,7 +84,7 @@ func NewMuxer() MiddlewareMuxer { w.WriteHeader(http.StatusNotFound) enc.Encode(NewErrorResponse(ctx, fmt.Errorf("404 page not found"))) // nolint:errcheck })) - return &mux{Router: r} + return &mux{Router: r, wildcards: make(map[string]string)} } // wildPath matches a wildcard path segment. @@ -87,22 +96,24 @@ func (m *mux) Handle(method, pattern string, handler http.HandlerFunc) { if len(wildcards) > 2 { panic("too many wildcards") } - m.wildcard = wildcards[1] pattern = wildPath.ReplaceAllString(pattern, "/*") + m.wildcards[method+"::"+pattern] = wildcards[1] } m.Method(method, pattern, handler) } // Vars extracts the path variables from the request context. func (m *mux) Vars(r *http.Request) map[string]string { - params := chi.RouteContext(r.Context()).URLParams + rctx := chi.RouteContext(r.Context()) + params := rctx.URLParams if len(params.Keys) == 0 { return nil } vars := make(map[string]string, len(params.Keys)) for i, k := range params.Keys { if k == "*" { - vars[m.wildcard] = params.Values[i] + wildcard := m.wildcards[r.Method+"::"+rctx.RoutePattern()] + vars[wildcard] = params.Values[i] continue } vars[k] = params.Values[i] @@ -115,3 +126,25 @@ func (m *mux) Vars(r *http.Request) map[string]string { func (m *mux) Use(f func(http.Handler) http.Handler) { m.Router.Use(f) } + +// ResolvePattern returns the route pattern used to register the handler for the +// given method and path. +func (m *mux) ResolvePattern(r *http.Request) string { + ctx := chi.RouteContext(r.Context()) + if ctx.RoutePattern() != "" { + return m.resolveWildcard(r.Method, ctx.RoutePattern()) + } + if !m.Router.Match(ctx, r.Method, r.URL.Path) { + return "" + } + return m.resolveWildcard(r.Method, ctx.RoutePattern()) +} + +// resolveWildcard returns the route pattern with the wildcard replaced by the +// name of the wildcard. +func (m *mux) resolveWildcard(method, pattern string) string { + if wildcard, ok := m.wildcards[method+"::"+pattern]; ok { + return pattern[:len(pattern)-2] + "/{*" + wildcard + "}" + } + return pattern +} diff --git a/http/mux_test.go b/http/mux_test.go index b12870de31..d8b8c569c4 100644 --- a/http/mux_test.go +++ b/http/mux_test.go @@ -128,3 +128,88 @@ func TestVars(t *testing.T) { }) } } + +func TestResolvePattern(t *testing.T) { + cases := []struct { + Name string + Patterns []string + URL string + Expected string + }{ + { + Name: "simple", + Patterns: []string{"/users/{id}"}, + URL: "/users/123", + Expected: "/users/{id}", + }, + { + Name: "multiple", + Patterns: []string{"/users/{id}/posts/{post_id}"}, + URL: "/users/123/posts/456", + Expected: "/users/{id}/posts/{post_id}", + }, + { + Name: "two patterns", + Patterns: []string{"/users/{id}/posts/{post_id}", "/users/{id}/posts/{post_id}/comments/{comment_id}"}, + URL: "/users/123/posts/456", + Expected: "/users/{id}/posts/{post_id}", + }, + { + Name: "two patterns deep", + Patterns: []string{"/users/{id}/posts/{post_id}", "/users/{id}/posts/{post_id}/comments/{comment_id}"}, + URL: "/users/123/posts/456/comments/789", + Expected: "/users/{id}/posts/{post_id}/comments/{comment_id}", + }, + { + Name: "wildcard", + Patterns: []string{"/users/{id}/posts/{*post_id}"}, + URL: "/users/123/posts/456/789", + Expected: "/users/{id}/posts/{*post_id}", + }, + { + Name: "two wildcards", + Patterns: []string{"/users/{id}/posts/{*post_id}", "/users/{id}/posts/{post_id}/comments/{*comment_id}"}, + URL: "/users/123/posts/456/789", + Expected: "/users/{id}/posts/{*post_id}", + }, + { + Name: "two wildcards deep", + Patterns: []string{"/users/{id}/posts/{*post_id}", "/users/{id}/posts/{post_id}/comments/{*comment_id}"}, + URL: "/users/123/posts/456/comments/abc", + Expected: "/users/{id}/posts/{post_id}/comments/{*comment_id}", + }, + { + Name: "no var", + Patterns: []string{"/users"}, + URL: "/users", + Expected: "/users", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + var called bool + mux := NewMuxer() + handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + pattern := mux.ResolvePattern(r) + assert.Equal(t, c.Expected, pattern) + called = true + }) + // Make sure resolver works with middlewares. + handler = func(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pattern := mux.ResolvePattern(r) + assert.Equal(t, c.Expected, pattern) + next.ServeHTTP(w, r) + }) + }(handler) + for _, p := range c.Patterns { + mux.Handle("GET", p, handler) + } + req, _ := http.NewRequest("GET", c.URL, nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + assert.True(t, called) + }) + } +}