Skip to content

Commit

Permalink
fix(response): refactor timeout and writer functions and tests (#64)
Browse files Browse the repository at this point in the history
- Modify the `New` function in `timeout.go` to include a new line of code
- Delete the `WriteHeader` function in `writer.go`
- Add a new line of code to the `WriteHeader` function in `writer.go`
- Add a new test case to the `TestWriter_Status` function in `writer_test.go`
- Add a new function `testNew` to `writer_test.go`
- Add a new function `timeoutHandler` to `writer_test.go`
- Add a new test case to the `TestHTTPStatusCode` function in `writer_test.go`

fixed by @jeff-lyu

ref: #52
fixed #31
  • Loading branch information
appleboy authored Nov 25, 2023
1 parent 7452411 commit f338d36
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 2 deletions.
2 changes: 1 addition & 1 deletion timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func New(opts ...Option) gin.HandlerFunc {
for k, vv := range tw.Header() {
dst[k] = vv
}
tw.ResponseWriter.WriteHeader(tw.code)

if _, err := tw.ResponseWriter.Write(buffer.Bytes()); err != nil {
panic(err)
}
Expand Down
5 changes: 4 additions & 1 deletion writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ func (w *Writer) Write(data []byte) (int, error) {
return w.body.Write(data)
}

// WriteHeader will write http status code
// WriteHeader sends an HTTP response header with the provided status code.
// If the response writer has already written headers or if a timeout has occurred,
// this method does nothing.
func (w *Writer) WriteHeader(code int) {
checkWriteHeaderCode(code)
if w.timeout || w.wroteHeaders {
Expand All @@ -48,6 +50,7 @@ func (w *Writer) WriteHeader(code int) {
defer w.mu.Unlock()

w.writeHeader(code)
w.ResponseWriter.WriteHeader(code)
}

func (w *Writer) writeHeader(code int) {
Expand Down
146 changes: 146 additions & 0 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package timeout

import (
"fmt"
"log"
"net/http"
"net/http/httptest"
"strconv"
Expand Down Expand Up @@ -57,3 +58,148 @@ func TestWriter_Status(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, strconv.Itoa(http.StatusInternalServerError), req.Header.Get("X-Status-Code-MW-Set"))
}

// testNew is a copy of New() with a small change to the timeoutHandler() function.
// ref: https://github.com/gin-contrib/timeout/issues/31
func testNew(duration time.Duration) gin.HandlerFunc {
return New(
WithTimeout(duration),
WithHandler(func(c *gin.Context) { c.Next() }),
WithResponse(timeoutHandler()),
)
}

// timeoutHandler returns a handler that returns a 504 Gateway Timeout error.
func timeoutHandler() gin.HandlerFunc {
gatewayTimeoutErr := struct {
Error string `json:"error"`
}{
Error: "Timed out.",
}

return func(c *gin.Context) {
log.Printf("request timed out: [method=%s,path=%s]",
c.Request.Method, c.Request.URL.Path)
c.JSON(http.StatusGatewayTimeout, gatewayTimeoutErr)
}
}

// TestHTTPStatusCode tests the HTTP status code of the response.
func TestHTTPStatusCode(t *testing.T) {
gin.SetMode(gin.ReleaseMode)

type testCase struct {
Name string
Method string
Path string
ExpStatusCode int
Handler gin.HandlerFunc
}

var (
cases = []testCase{
{
Name: "Plain text (200)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusOK,
Handler: func(ctx *gin.Context) {
ctx.String(http.StatusOK, "I'm text!")
},
},
{
Name: "Plain text (201)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusCreated,
Handler: func(ctx *gin.Context) {
ctx.String(http.StatusCreated, "I'm created!")
},
},
{
Name: "Plain text (204)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusNoContent,
Handler: func(ctx *gin.Context) {
ctx.String(http.StatusNoContent, "")
},
},
{
Name: "Plain text (400)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusBadRequest,
Handler: func(ctx *gin.Context) {
ctx.String(http.StatusBadRequest, "")
},
},
{
Name: "JSON (200)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusOK,
Handler: func(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"field": "value"})
},
},
{
Name: "JSON (201)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusCreated,
Handler: func(ctx *gin.Context) {
ctx.JSON(http.StatusCreated, gin.H{"field": "value"})
},
},
{
Name: "JSON (204)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusNoContent,
Handler: func(ctx *gin.Context) {
ctx.JSON(http.StatusNoContent, nil)
},
},
{
Name: "JSON (400)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusBadRequest,
Handler: func(ctx *gin.Context) {
ctx.JSON(http.StatusBadRequest, nil)
},
},
{
Name: "No reply",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusOK,
Handler: func(ctx *gin.Context) {},
},
}

initCase = func(c testCase) (*http.Request, *httptest.ResponseRecorder) {
return httptest.NewRequest(c.Method, c.Path, nil), httptest.NewRecorder()
}
)

for i := range cases {
t.Run(cases[i].Name, func(tt *testing.T) {
tt.Logf("Test case [%s]", cases[i].Name)

router := gin.Default()

router.Use(testNew(1 * time.Second))
router.GET("/*root", cases[i].Handler)

req, resp := initCase(cases[i])
router.ServeHTTP(resp, req)

if resp.Code != cases[i].ExpStatusCode {
tt.Errorf("response is different from expected:\nexp: >>>%d<<<\ngot: >>>%d<<<",
cases[i].ExpStatusCode, resp.Code)
}
})
}
}

0 comments on commit f338d36

Please sign in to comment.