diff --git a/http3/response_writer.go b/http3/response_writer.go index d158b33ad46..8638ec57779 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -112,23 +112,24 @@ func (w *responseWriter) WriteHeader(status int) { } } +func (w *responseWriter) sniffContentType(p []byte) { + // If no content type, apply sniffing algorithm to body. + // We can't use `w.header.Get` here since if the Content-Type was set to nil, we shouldn't do sniffing. + _, haveType := w.header["Content-Type"] + + // If the Transfer-Encoding or Content-Encoding was set and is non-blank, + // we shouldn't sniff the body. + hasTE := w.header.Get("Transfer-Encoding") != "" + hasCE := w.header.Get("Content-Encoding") != "" + if !hasCE && !haveType && !hasTE && len(p) > 0 { + w.header.Set("Content-Type", http.DetectContentType(p)) + } +} + func (w *responseWriter) Write(p []byte) (int, error) { bodyAllowed := bodyAllowedForStatus(w.status) if !w.headerComplete { - // If body is not allowed, we don't need to (and we can't) sniff the content type. - if bodyAllowed { - // If no content type, apply sniffing algorithm to body. - // We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing. - _, haveType := w.header["Content-Type"] - - // If the Transfer-Encoding or Content-Encoding was set and is non-blank, - // we shouldn't sniff the body. - hasTE := w.header.Get("Transfer-Encoding") != "" - hasCE := w.header.Get("Content-Encoding") != "" - if !hasCE && !haveType && !hasTE && len(p) > 0 { - w.header.Set("Content-Type", http.DetectContentType(p)) - } - } + w.sniffContentType(p) w.WriteHeader(http.StatusOK) bodyAllowed = true } @@ -158,6 +159,7 @@ func (w *responseWriter) Write(p []byte) (int, error) { func (w *responseWriter) doWrite(p []byte) (int, error) { if !w.headerWritten { + w.sniffContentType(w.smallResponseBuf) if err := w.writeHeader(w.status); err != nil { return 0, maybeReplaceError(err) } diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index f49e07308cf..f68e5aceed4 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -137,9 +137,10 @@ var _ = Describe("Response Writer", func() { // According to the spec, headers sent in the informational response must also be included in the final response fields = decodeHeader(strBuf) - Expect(fields).To(HaveLen(3)) + Expect(fields).To(HaveLen(4)) Expect(fields).To(HaveKeyWithValue(":status", []string{"200"})) Expect(fields).To(HaveKey("date")) + Expect(fields).To(HaveKey("content-type")) Expect(fields).To(HaveKeyWithValue("link", []string{"; rel=preload; as=style", "; rel=preload; as=script"})) Expect(getData(strBuf)).To(Equal([]byte("foobar"))) diff --git a/http3/server_test.go b/http3/server_test.go index efe6d1a129a..1a30e9f74a6 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -207,6 +207,26 @@ var _ = Describe("Server", func() { Expect(hfs).To(HaveLen(4)) }) + It("sets Content-Type when WriteHeader is called but response is not flushed", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("")) + }) + + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + str.EXPECT().Close() + + s.handleRequest(conn, str, nil, qpackDecoder) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"404"})) + Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"})) + Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"})) + }) + It("not sets Content-Length when the handler flushes to the client", func() { s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("foobar"))