Skip to content

Commit

Permalink
Merge branch 'v3' into dependabot/go_modules/github.com/gorilla/webso…
Browse files Browse the repository at this point in the history
…cket-1.5.1
  • Loading branch information
raphael authored Dec 2, 2023
2 parents fc2b4dc + 1dc06f4 commit 54b0456
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions http/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"regexp"
"sync"

chi "github.com/go-chi/chi/v5"
)
Expand Down Expand Up @@ -70,6 +71,10 @@ type (
// mux is the default Muxer implementation.
mux struct {
chi.Router
// protect access to middlewares and handlers
mu sync.Mutex
// middlewares to be registered before handlers
middlewares []func(http.Handler) http.Handler
// wildcards maps a method and a pattern to the name of the wildcard
// this is needed because chi does not expose the name of the wildcard
wildcards map[string]string
Expand All @@ -78,21 +83,32 @@ type (

// NewMuxer returns a Muxer implementation based on a Chi router.
func NewMuxer() ResolverMuxer {
r := chi.NewRouter()
r.NotFound(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := context.WithValue(req.Context(), AcceptTypeKey, req.Header.Get("Accept"))
enc := ResponseEncoder(ctx, w)
w.WriteHeader(http.StatusNotFound)
enc.Encode(NewErrorResponse(ctx, fmt.Errorf("404 page not found"))) // nolint:errcheck
}))
return &mux{Router: r, wildcards: make(map[string]string)}
return &mux{
Router: chi.NewRouter(),
wildcards: make(map[string]string),
middlewares: []func(http.Handler) http.Handler{},
}
}

// wildPath matches a wildcard path segment.
var wildPath = regexp.MustCompile(`/{\*([a-zA-Z0-9_]+)}`)

// Handle registers the handler function for the given method and pattern.
func (m *mux) Handle(method, pattern string, handler http.HandlerFunc) {
m.mu.Lock()
defer m.mu.Unlock()
if m.middlewares != nil {
for _, middleware := range m.middlewares {
m.Router.Use(middleware)
}
m.NotFound(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := context.WithValue(req.Context(), AcceptTypeKey, req.Header.Get("Accept"))
enc := ResponseEncoder(ctx, w)
w.WriteHeader(http.StatusNotFound)
enc.Encode(NewErrorResponse(ctx, fmt.Errorf("404 page not found"))) // nolint:errcheck
}))
m.middlewares = nil
}
if wildcards := wildPath.FindStringSubmatch(pattern); len(wildcards) > 0 {
if len(wildcards) > 2 {
panic("too many wildcards")
Expand Down Expand Up @@ -136,6 +152,12 @@ func unescape(s string) string {
// Use appends a middleware to the list of middlewares to be applied
// downstream the Muxer.
func (m *mux) Use(f func(http.Handler) http.Handler) {
m.mu.Lock()
defer m.mu.Unlock()
if m.middlewares != nil {
m.middlewares = append(m.middlewares, f)
return
}
m.Router.Use(f)
}

Expand Down

0 comments on commit 54b0456

Please sign in to comment.