From b104f60e20b4e12f86b374f4f21af555d6c558c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Garc=C3=ADa=20Crespo?= Date: Thu, 22 Jun 2023 18:35:45 +0000 Subject: [PATCH] Use writeTimeout middleware to set write deadlines --- internal/api/api.go | 21 +++++------------ internal/api/middleware.go | 29 ++++++++++++++++++++++++ internal/api/middleware_test.go | 40 +++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 16 deletions(-) create mode 100644 internal/api/middleware.go create mode 100644 internal/api/middleware_test.go diff --git a/internal/api/api.go b/internal/api/api.go index 83f03474..7cdc1fa4 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -68,7 +68,7 @@ func HTTPServer( collectionErrorHandler := errorHandler(logger, "Collection error.") var collectionServer *collectionsvr.Server = collectionsvr.New(collectionEndpoints, mux, dec, enc, collectionErrorHandler, nil, websocketUpgrader, nil) // Intercept request in Download endpoint so we can serve the file directly. - collectionServer.Download = colsvc.HTTPDownload(mux, dec) + collectionServer.Download = writeTimeout(colsvc.HTTPDownload(mux, dec), 0) collectionsvr.Mount(mux, collectionServer) // Swagger service. @@ -90,12 +90,10 @@ func HTTPServer( } return &http.Server{ - Addr: config.Listen, - Handler: handler, - ReadTimeout: time.Second * 5, - // WriteTimeout is set to 0 because we have streaming endpoints. - // https://github.com/golang/go/issues/16100#issuecomment-285573480 - WriteTimeout: 0, + Addr: config.Listen, + Handler: handler, + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, IdleTimeout: time.Second * 120, } } @@ -127,15 +125,6 @@ func errorHandler(logger logr.Logger, msg string) func(context.Context, http.Res } } -func versionHeaderMiddleware(version string) func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Enduro-Version", version) - h.ServeHTTP(w, r) - }) - } -} - func sameOriginChecker(logger logr.Logger) func(r *http.Request) bool { return func(r *http.Request) bool { origin := r.Header["Origin"] diff --git a/internal/api/middleware.go b/internal/api/middleware.go new file mode 100644 index 00000000..f54a6d31 --- /dev/null +++ b/internal/api/middleware.go @@ -0,0 +1,29 @@ +package api + +import ( + "net/http" + "time" +) + +func versionHeaderMiddleware(version string) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Enduro-Version", version) + h.ServeHTTP(w, r) + }) + } +} + +// writeTimeout sets the write deadline for writing the response. A zero value +// means no timeout. +func writeTimeout(h http.Handler, timeout time.Duration) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rc := http.NewResponseController(w) + var deadline time.Time + if timeout != 0 { + deadline = time.Now().Add(timeout) + } + _ = rc.SetWriteDeadline(deadline) + h.ServeHTTP(w, r) + }) +} diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go new file mode 100644 index 00000000..d47fc18f --- /dev/null +++ b/internal/api/middleware_test.go @@ -0,0 +1,40 @@ +package api + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "gotest.tools/v3/assert" +) + +func TestWriteTimeout(t *testing.T) { + t.Parallel() + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Microsecond * 100) + w.Write([]byte("Hi there!")) + }) + + t.Run("Sets a write timeout", func(t *testing.T) { + ts := httptest.NewServer(writeTimeout(h, time.Microsecond)) + defer ts.Close() + + _, err := ts.Client().Get(ts.URL) + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("Sets an unlimited write timeout", func(t *testing.T) { + ts := httptest.NewServer(writeTimeout(h, 0)) + defer ts.Close() + + resp, err := ts.Client().Get(ts.URL) + assert.NilError(t, err) + + blob, err := io.ReadAll(resp.Body) + assert.NilError(t, err) + assert.Equal(t, string(blob), "Hi there!") + }) +}