Skip to content

Commit

Permalink
Use writeTimeout middleware to set write deadlines
Browse files Browse the repository at this point in the history
  • Loading branch information
sevein committed Jun 22, 2023
1 parent 926d49c commit b104f60
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 16 deletions.
21 changes: 5 additions & 16 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func HTTPServer(
collectionErrorHandler := errorHandler(logger, "Collection error.")
var collectionServer *collectionsvr.Server = collectionsvr.New(collectionEndpoints, mux, dec, enc, collectionErrorHandler, nil, websocketUpgrader, nil)
// Intercept request in Download endpoint so we can serve the file directly.
collectionServer.Download = colsvc.HTTPDownload(mux, dec)
collectionServer.Download = writeTimeout(colsvc.HTTPDownload(mux, dec), 0)

Check warning on line 71 in internal/api/api.go

View check run for this annotation

Codecov / codecov/patch

internal/api/api.go#L71

Added line #L71 was not covered by tests
collectionsvr.Mount(mux, collectionServer)

// Swagger service.
Expand All @@ -90,12 +90,10 @@ func HTTPServer(
}

return &http.Server{
Addr: config.Listen,
Handler: handler,
ReadTimeout: time.Second * 5,
// WriteTimeout is set to 0 because we have streaming endpoints.
// https://github.com/golang/go/issues/16100#issuecomment-285573480
WriteTimeout: 0,
Addr: config.Listen,
Handler: handler,
ReadTimeout: time.Second * 5,
WriteTimeout: time.Second * 5,

Check warning on line 96 in internal/api/api.go

View check run for this annotation

Codecov / codecov/patch

internal/api/api.go#L93-L96

Added lines #L93 - L96 were not covered by tests
IdleTimeout: time.Second * 120,
}
}
Expand Down Expand Up @@ -127,15 +125,6 @@ func errorHandler(logger logr.Logger, msg string) func(context.Context, http.Res
}
}

func versionHeaderMiddleware(version string) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Enduro-Version", version)
h.ServeHTTP(w, r)
})
}
}

func sameOriginChecker(logger logr.Logger) func(r *http.Request) bool {
return func(r *http.Request) bool {
origin := r.Header["Origin"]
Expand Down
29 changes: 29 additions & 0 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package api

import (
"net/http"
"time"
)

func versionHeaderMiddleware(version string) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Enduro-Version", version)
h.ServeHTTP(w, r)
})

Check warning on line 13 in internal/api/middleware.go

View check run for this annotation

Codecov / codecov/patch

internal/api/middleware.go#L9-L13

Added lines #L9 - L13 were not covered by tests
}
}

// writeTimeout sets the write deadline for writing the response. A zero value
// means no timeout.
func writeTimeout(h http.Handler, timeout time.Duration) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rc := http.NewResponseController(w)
var deadline time.Time
if timeout != 0 {
deadline = time.Now().Add(timeout)
}
_ = rc.SetWriteDeadline(deadline)
h.ServeHTTP(w, r)
})
}
40 changes: 40 additions & 0 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package api

import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"

"gotest.tools/v3/assert"
)

func TestWriteTimeout(t *testing.T) {
t.Parallel()

h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Microsecond * 100)
w.Write([]byte("Hi there!"))
})

t.Run("Sets a write timeout", func(t *testing.T) {
ts := httptest.NewServer(writeTimeout(h, time.Microsecond))
defer ts.Close()

_, err := ts.Client().Get(ts.URL)
assert.ErrorIs(t, err, io.EOF)
})

t.Run("Sets an unlimited write timeout", func(t *testing.T) {
ts := httptest.NewServer(writeTimeout(h, 0))
defer ts.Close()

resp, err := ts.Client().Get(ts.URL)
assert.NilError(t, err)

blob, err := io.ReadAll(resp.Body)
assert.NilError(t, err)
assert.Equal(t, string(blob), "Hi there!")
})
}

0 comments on commit b104f60

Please sign in to comment.