From ae2e9b17679962991b40a356dbec360aed5c4cb1 Mon Sep 17 00:00:00 2001 From: Ben Tranter Date: Sat, 30 Dec 2023 12:24:36 -0500 Subject: [PATCH] seatbelt: add option to skip csrf checks --- example/cmd/main.go | 10 ++++--- seatbelt.go | 15 +++++++++++ seatbelt_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 3 deletions(-) diff --git a/example/cmd/main.go b/example/cmd/main.go index 8befa6c..d11f208 100644 --- a/example/cmd/main.go +++ b/example/cmd/main.go @@ -9,9 +9,10 @@ import ( func main() { app := seatbelt.New(seatbelt.Option{ - TemplateDir: "templates", - Reload: true, - LocaleDir: "locales", + TemplateDir: "templates", + Reload: true, + LocaleDir: "locales", + SkipCSRFPaths: []string{"/api"}, }) app.Use(func(fn func(ctx *seatbelt.Context) error) func(*seatbelt.Context) error { @@ -26,6 +27,9 @@ func main() { app.Get("/", func(c *seatbelt.Context) error { return c.Render("index", nil) }) + app.Post("/api", func(c *seatbelt.Context) error { + return c.JSON(201, map[string]string{"message": "Hello, world!"}) + }) app.Get("/rendertostring", func(c *seatbelt.Context) error { page := c.RenderToBytes("index", nil) _, err := c.Response().Write([]byte(page)) diff --git a/seatbelt.go b/seatbelt.go index 2ad35e0..ee0d3e6 100644 --- a/seatbelt.go +++ b/seatbelt.go @@ -305,6 +305,10 @@ type Option struct { // SkipServeFiles does not automatically serve static files from the // project's /public directory when set to true. Default is false. SkipServeFiles bool + + // SkipCSRFPaths is used to skip the CSRF validation to POST, PUT, PATCH, + // DELETE, etc requests to paths that match one of the given paths. + SkipCSRFPaths []string } // setDefaults sets the default values for Seatbelt options. @@ -408,6 +412,17 @@ func New(opts ...Option) *App { // Initialize the underlying chi mux so that we can setup our default // middleware stack. mux := chi.NewRouter() + mux.Use(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, skipPath := range opt.SkipCSRFPaths { + if strings.HasPrefix(r.URL.Path, skipPath) { + r = csrf.UnsafeSkipCheck(r) + + } + } + h.ServeHTTP(w, r) + }) + }) mux.Use(csrf.Protect(signingKey, csrf.Path("/"))) sess := session.New(signingKey, session.Options{ diff --git a/seatbelt_test.go b/seatbelt_test.go index 4dd6b47..a0f13e5 100644 --- a/seatbelt_test.go +++ b/seatbelt_test.go @@ -73,3 +73,69 @@ func TestSubRouter(t *testing.T) { } }) } + +func TestCSRFSkipPaths(t *testing.T) { + app := New(Option{ + SkipCSRFPaths: []string{"/api", "/skip-me"}, + }) + + app.Get("/", func(c *Context) error { + return c.JSON(200, map[string]string{"message": "ok"}) + }) + app.Post("/", func(c *Context) error { + return c.NoContent() + }) + app.Post("/api", func(c *Context) error { + return c.NoContent() + }) + app.Put("/skip-me/test", func(c *Context) error { + return c.NoContent() + }) + + srv := httptest.NewServer(app) + defer srv.Close() + + cases := []struct { + path string + method string + status int + }{ + { + path: "/", + method: http.MethodGet, + status: 200, + }, + { + path: "/", + method: http.MethodPost, + status: 403, + }, + { + path: "/api", + method: http.MethodPost, + status: 204, + }, + { + path: "/skip-me/test", + method: http.MethodPut, + status: 204, + }, + } + + for _, c := range cases { + t.Run(c.method+" "+c.path, func(t *testing.T) { + req, err := http.NewRequest(c.method, srv.URL+c.path, nil) + if err != nil { + t.Fatal(err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != c.status { + t.Fatalf("expected %d but got %d", c.status, resp.StatusCode) + } + }) + } +}