Skip to content

Commit

Permalink
Merge pull request #271 from sjbarag/let-before-funcs-modify-response…
Browse files Browse the repository at this point in the history
…-status-code

fix: let ResponseWriter.Before() callbacks change status
  • Loading branch information
jszwedko authored Mar 11, 2023
2 parents 0d1d16b + 83566e9 commit dc8359a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
34 changes: 29 additions & 5 deletions response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,28 @@ func NewResponseWriter(rw http.ResponseWriter) ResponseWriter {

type responseWriter struct {
http.ResponseWriter
status int
size int
beforeFuncs []beforeFunc
pendingStatus int
status int
size int
beforeFuncs []beforeFunc
callingBefores bool
}

func (rw *responseWriter) WriteHeader(s int) {
if rw.Written() {
return
}
rw.status = s

rw.pendingStatus = s
rw.callBefore()

// Any of the rw.beforeFuncs may have written a header,
// so check again to see if any work is necessary.
if rw.Written() {
return
}

rw.status = s
rw.ResponseWriter.WriteHeader(s)
}

Expand Down Expand Up @@ -74,7 +85,11 @@ func (rw *responseWriter) ReadFrom(r io.Reader) (n int64, err error) {
}

func (rw *responseWriter) Status() int {
return rw.status
if rw.Written() {
return rw.status
}

return rw.pendingStatus
}

func (rw *responseWriter) Size() int {
Expand All @@ -90,6 +105,15 @@ func (rw *responseWriter) Before(before func(ResponseWriter)) {
}

func (rw *responseWriter) callBefore() {
// Don't recursively call before() functions, to avoid infinite looping if
// one of them calls rw.WriteHeader again.
if rw.callingBefores {
return
}

rw.callingBefores = true
defer func() { rw.callingBefores = false }()

for i := len(rw.beforeFuncs) - 1; i >= 0; i-- {
rw.beforeFuncs[i](rw)
}
Expand Down
28 changes: 28 additions & 0 deletions response_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,34 @@ func TestResponseWriterBeforeFuncHasAccessToStatus(t *testing.T) {
expect(t, status, http.StatusCreated)
}

func TestResponseWriterBeforeFuncCanChangeStatus(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)

// Always respond with 200.
rw.Before(func(w ResponseWriter) {
w.WriteHeader(http.StatusOK)
})

rw.WriteHeader(http.StatusBadRequest)
expect(t, rec.Code, http.StatusOK)
}

func TestResponseWriterBeforeFuncChangesStatusMultipleTimes(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)

rw.Before(func(w ResponseWriter) {
w.WriteHeader(http.StatusInternalServerError)
})
rw.Before(func(w ResponseWriter) {
w.WriteHeader(http.StatusNotFound)
})

rw.WriteHeader(http.StatusOK)
expect(t, rec.Code, http.StatusNotFound)
}

func TestResponseWriterWritingString(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)
Expand Down

0 comments on commit dc8359a

Please sign in to comment.