diff --git a/timeout.go b/timeout.go index ec59143..75165a9 100644 --- a/timeout.go +++ b/timeout.go @@ -70,7 +70,7 @@ func New(opts ...Option) gin.HandlerFunc { for k, vv := range tw.Header() { dst[k] = vv } - tw.ResponseWriter.WriteHeader(tw.code) + if _, err := tw.ResponseWriter.Write(buffer.Bytes()); err != nil { panic(err) } diff --git a/writer.go b/writer.go index eba95ea..4c64fec 100644 --- a/writer.go +++ b/writer.go @@ -37,7 +37,9 @@ func (w *Writer) Write(data []byte) (int, error) { return w.body.Write(data) } -// WriteHeader will write http status code +// WriteHeader sends an HTTP response header with the provided status code. +// If the response writer has already written headers or if a timeout has occurred, +// this method does nothing. func (w *Writer) WriteHeader(code int) { checkWriteHeaderCode(code) if w.timeout || w.wroteHeaders { @@ -48,6 +50,7 @@ func (w *Writer) WriteHeader(code int) { defer w.mu.Unlock() w.writeHeader(code) + w.ResponseWriter.WriteHeader(code) } func (w *Writer) writeHeader(code int) { diff --git a/writer_test.go b/writer_test.go index 614c151..5213597 100644 --- a/writer_test.go +++ b/writer_test.go @@ -2,6 +2,7 @@ package timeout import ( "fmt" + "log" "net/http" "net/http/httptest" "strconv" @@ -57,3 +58,148 @@ func TestWriter_Status(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Equal(t, strconv.Itoa(http.StatusInternalServerError), req.Header.Get("X-Status-Code-MW-Set")) } + +// testNew is a copy of New() with a small change to the timeoutHandler() function. +// ref: https://github.com/gin-contrib/timeout/issues/31 +func testNew(duration time.Duration) gin.HandlerFunc { + return New( + WithTimeout(duration), + WithHandler(func(c *gin.Context) { c.Next() }), + WithResponse(timeoutHandler()), + ) +} + +// timeoutHandler returns a handler that returns a 504 Gateway Timeout error. +func timeoutHandler() gin.HandlerFunc { + gatewayTimeoutErr := struct { + Error string `json:"error"` + }{ + Error: "Timed out.", + } + + return func(c *gin.Context) { + log.Printf("request timed out: [method=%s,path=%s]", + c.Request.Method, c.Request.URL.Path) + c.JSON(http.StatusGatewayTimeout, gatewayTimeoutErr) + } +} + +// TestHTTPStatusCode tests the HTTP status code of the response. +func TestHTTPStatusCode(t *testing.T) { + gin.SetMode(gin.ReleaseMode) + + type testCase struct { + Name string + Method string + Path string + ExpStatusCode int + Handler gin.HandlerFunc + } + + var ( + cases = []testCase{ + { + Name: "Plain text (200)", + Method: http.MethodGet, + Path: "/me", + ExpStatusCode: http.StatusOK, + Handler: func(ctx *gin.Context) { + ctx.String(http.StatusOK, "I'm text!") + }, + }, + { + Name: "Plain text (201)", + Method: http.MethodGet, + Path: "/me", + ExpStatusCode: http.StatusCreated, + Handler: func(ctx *gin.Context) { + ctx.String(http.StatusCreated, "I'm created!") + }, + }, + { + Name: "Plain text (204)", + Method: http.MethodGet, + Path: "/me", + ExpStatusCode: http.StatusNoContent, + Handler: func(ctx *gin.Context) { + ctx.String(http.StatusNoContent, "") + }, + }, + { + Name: "Plain text (400)", + Method: http.MethodGet, + Path: "/me", + ExpStatusCode: http.StatusBadRequest, + Handler: func(ctx *gin.Context) { + ctx.String(http.StatusBadRequest, "") + }, + }, + { + Name: "JSON (200)", + Method: http.MethodGet, + Path: "/me", + ExpStatusCode: http.StatusOK, + Handler: func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, gin.H{"field": "value"}) + }, + }, + { + Name: "JSON (201)", + Method: http.MethodGet, + Path: "/me", + ExpStatusCode: http.StatusCreated, + Handler: func(ctx *gin.Context) { + ctx.JSON(http.StatusCreated, gin.H{"field": "value"}) + }, + }, + { + Name: "JSON (204)", + Method: http.MethodGet, + Path: "/me", + ExpStatusCode: http.StatusNoContent, + Handler: func(ctx *gin.Context) { + ctx.JSON(http.StatusNoContent, nil) + }, + }, + { + Name: "JSON (400)", + Method: http.MethodGet, + Path: "/me", + ExpStatusCode: http.StatusBadRequest, + Handler: func(ctx *gin.Context) { + ctx.JSON(http.StatusBadRequest, nil) + }, + }, + { + Name: "No reply", + Method: http.MethodGet, + Path: "/me", + ExpStatusCode: http.StatusOK, + Handler: func(ctx *gin.Context) {}, + }, + } + + initCase = func(c testCase) (*http.Request, *httptest.ResponseRecorder) { + return httptest.NewRequest(c.Method, c.Path, nil), httptest.NewRecorder() + } + ) + + for i := range cases { + t.Run(cases[i].Name, func(tt *testing.T) { + tt.Logf("Test case [%s]", cases[i].Name) + + router := gin.Default() + + router.Use(testNew(1 * time.Second)) + router.GET("/*root", cases[i].Handler) + + req, resp := initCase(cases[i]) + router.ServeHTTP(resp, req) + + if resp.Code != cases[i].ExpStatusCode { + tt.Errorf("response is different from expected:\nexp: >>>%d<<<\ngot: >>>%d<<<", + cases[i].ExpStatusCode, resp.Code) + } + }) + } +}