diff --git a/README.md b/README.md index 06a7601..0173762 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ import ( func main() { r := gin.Default() - r.Use(gzip.Gzip(gzip.DefaultCompression)) + r.Use(gzip.Gzip(gzip.DefaultCompression, 0)) r.GET("/ping", func(c *gin.Context) { c.String(200, "pong "+fmt.Sprint(time.Now().Unix())) }) diff --git a/example/example.go b/example/example.go index 24042cf..cd9ee5a 100644 --- a/example/example.go +++ b/example/example.go @@ -10,7 +10,7 @@ import ( func main() { r := gin.Default() - r.Use(gzip.Gzip(gzip.DefaultCompression)) + r.Use(gzip.Gzip(gzip.DefaultCompression, 0)) r.GET("/ping", func(c *gin.Context) { c.String(200, "pong "+fmt.Sprint(time.Now().Unix())) }) diff --git a/gzip.go b/gzip.go index df22715..367f019 100644 --- a/gzip.go +++ b/gzip.go @@ -1,6 +1,7 @@ package gzip import ( + "bytes" "compress/gzip" "fmt" "io/ioutil" @@ -19,7 +20,16 @@ const ( NoCompression = gzip.NoCompression ) -func Gzip(level int) gin.HandlerFunc { +// Gzip returns a Gin handler that implements transparent gzip +// compression of the response. +// +// Gzip expects two parameter: +// - level: the compression level. One of BestCompression, BestSpeed, +// DefaultCompression, or NoCompression +// - minLength: the minimal response length in bytes that is required +// in order to actually compress the response. Disable it by setting it +// to 0. +func Gzip(level, minLength int) gin.HandlerFunc { var gzPool sync.Pool gzPool.New = func() interface{} { gz, err := gzip.NewWriterLevel(ioutil.Discard, level) @@ -37,28 +47,85 @@ func Gzip(level int) gin.HandlerFunc { defer gzPool.Put(gz) gz.Reset(c.Writer) - c.Header("Content-Encoding", "gzip") - c.Header("Vary", "Accept-Encoding") - c.Writer = &gzipWriter{c.Writer, gz} - defer func() { - gz.Close() - c.Header("Content-Length", fmt.Sprint(c.Writer.Size())) - }() + gzWriter := &gzipWriter{ + ResponseWriter: c.Writer, + writer: gz, + minLength: minLength, + } + + // Replace the context writer with a gzip writer + c.Writer = gzWriter + c.Next() + + if gzWriter.compress { + // Just close and flush the gz writer + gz.Close() + } else { + // Discard the gz writer + gz.Reset(ioutil.Discard) + + // Write the buffered data into the original writer + gzWriter.ResponseWriter.Write(gzWriter.buffer.Bytes()) + } + + // Set the content length if it's still possible + c.Header("Content-Length", fmt.Sprint(c.Writer.Size())) } } type gzipWriter struct { gin.ResponseWriter writer *gzip.Writer + + // Buffer to store partial response + buffer bytes.Buffer + + // Minimal length of buffer before content will be compressed + minLength int + + // Whether the response should be compressed + compress bool } func (g *gzipWriter) WriteString(s string) (int, error) { - return g.writer.Write([]byte(s)) + return g.Write([]byte(s)) } -func (g *gzipWriter) Write(data []byte) (int, error) { - return g.writer.Write(data) +func (g *gzipWriter) Write(data []byte) (w int, err error) { + // If the first chunk of data is already bigger than the minimum size, + // set the headers and write directly to the gz writer + if g.compress == false && len(data) >= g.minLength { + g.ResponseWriter.Header().Set("Content-Encoding", "gzip") + g.ResponseWriter.Header().Set("Vary", "Accept-Encoding") + + g.compress = true + } + + if !g.compress { + // Write the data into a buffer + w, err = g.buffer.Write(data) + if err != nil { + return + } + + // If the buffer is bigger than the minimum size, set the headers and write + // the buffered data into the gz writer + if g.buffer.Len() >= g.minLength { + g.ResponseWriter.Header().Set("Content-Encoding", "gzip") + g.ResponseWriter.Header().Set("Vary", "Accept-Encoding") + + _, err = g.writer.Write(g.buffer.Bytes()) + g.compress = true + } + + return + } + + // Write the data into the gz writer + w, err = g.writer.Write(data) + + return } // Fix: https://github.com/mholt/caddy/issues/38 diff --git a/gzip_test.go b/gzip_test.go index 88fd8c6..cba6911 100644 --- a/gzip_test.go +++ b/gzip_test.go @@ -50,18 +50,22 @@ func (c *closeNotifyingRecorder) CloseNotify() <-chan bool { return c.closed } -func newServer() *gin.Engine { +func newServer(minLength int) *gin.Engine { // init reverse proxy server rServer := httptest.NewServer(new(rServer)) target, _ := url.Parse(rServer.URL) rp := httputil.NewSingleHostReverseProxy(target) router := gin.New() - router.Use(Gzip(DefaultCompression)) + router.Use(Gzip(DefaultCompression, minLength)) router.GET("/", func(c *gin.Context) { c.Header("Content-Length", strconv.Itoa(len(testResponse))) c.String(200, testResponse) }) + router.GET("/minlengthbuffer", func(c *gin.Context) { + c.String(200, testResponse) + c.String(200, testResponse) + }) router.Any("/reverse", func(c *gin.Context) { rp.ServeHTTP(c.Writer, c.Request) }) @@ -73,14 +77,14 @@ func TestGzip(t *testing.T) { req.Header.Add("Accept-Encoding", "gzip") w := httptest.NewRecorder() - r := newServer() + r := newServer(0) r.ServeHTTP(w, req) assert.Equal(t, w.Code, 200) assert.Equal(t, w.Header().Get("Content-Encoding"), "gzip") assert.Equal(t, w.Header().Get("Vary"), "Accept-Encoding") assert.NotEqual(t, w.Header().Get("Content-Length"), "0") - assert.NotEqual(t, w.Body.Len(), 19) + assert.NotEqual(t, w.Body.Len(), len(testResponse)) assert.Equal(t, fmt.Sprint(w.Body.Len()), w.Header().Get("Content-Length")) gr, err := gzip.NewReader(w.Body) @@ -91,12 +95,54 @@ func TestGzip(t *testing.T) { assert.Equal(t, string(body), testResponse) } +func TestGzipMinLengthNoBuffer(t *testing.T) { + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Add("Accept-Encoding", "gzip") + + w := httptest.NewRecorder() + r := newServer(100) + r.ServeHTTP(w, req) + + assert.Equal(t, w.Code, 200) + assert.NotEqual(t, w.Header().Get("Content-Encoding"), "gzip") + assert.NotEqual(t, w.Header().Get("Vary"), "Accept-Encoding") + assert.NotEqual(t, w.Header().Get("Content-Length"), "0") + assert.Equal(t, w.Body.Len(), len(testResponse)) + assert.Equal(t, fmt.Sprint(w.Body.Len()), w.Header().Get("Content-Length")) + + body, _ := ioutil.ReadAll(w.Body) + assert.Equal(t, string(body), testResponse) +} + +func TestGzipMinLengthBuffer(t *testing.T) { + req, _ := http.NewRequest("GET", "/minlengthbuffer", nil) + req.Header.Add("Accept-Encoding", "gzip") + + w := httptest.NewRecorder() + r := newServer(len(testResponse) + 1) + r.ServeHTTP(w, req) + + assert.Equal(t, w.Code, 200) + assert.Equal(t, w.Header().Get("Content-Encoding"), "gzip") + assert.Equal(t, w.Header().Get("Vary"), "Accept-Encoding") + assert.NotEqual(t, w.Header().Get("Content-Length"), "0") + assert.NotEqual(t, w.Body.Len(), 2*len(testResponse)) + assert.Equal(t, fmt.Sprint(w.Body.Len()), w.Header().Get("Content-Length")) + + gr, err := gzip.NewReader(w.Body) + assert.NoError(t, err) + defer gr.Close() + + body, _ := ioutil.ReadAll(gr) + assert.Equal(t, string(body), testResponse + testResponse) +} + func TestGzipPNG(t *testing.T) { req, _ := http.NewRequest("GET", "/image.png", nil) req.Header.Add("Accept-Encoding", "gzip") router := gin.New() - router.Use(Gzip(DefaultCompression)) + router.Use(Gzip(DefaultCompression, 0)) router.GET("/image.png", func(c *gin.Context) { c.String(200, "this is a PNG!") }) @@ -114,7 +160,7 @@ func TestNoGzip(t *testing.T) { req, _ := http.NewRequest("GET", "/", nil) w := httptest.NewRecorder() - r := newServer() + r := newServer(0) r.ServeHTTP(w, req) assert.Equal(t, w.Code, 200) @@ -128,14 +174,14 @@ func TestGzipWithReverseProxy(t *testing.T) { req.Header.Add("Accept-Encoding", "gzip") w := newCloseNotifyingRecorder() - r := newServer() + r := newServer(0) r.ServeHTTP(w, req) assert.Equal(t, w.Code, 200) assert.Equal(t, w.Header().Get("Content-Encoding"), "gzip") assert.Equal(t, w.Header().Get("Vary"), "Accept-Encoding") assert.NotEqual(t, w.Header().Get("Content-Length"), "0") - assert.NotEqual(t, w.Body.Len(), 19) + assert.NotEqual(t, w.Body.Len(), len(testReverseResponse)) assert.Equal(t, fmt.Sprint(w.Body.Len()), w.Header().Get("Content-Length")) gr, err := gzip.NewReader(w.Body) @@ -145,3 +191,23 @@ func TestGzipWithReverseProxy(t *testing.T) { body, _ := ioutil.ReadAll(gr) assert.Equal(t, string(body), testReverseResponse) } + +func TestGzipMinLengthWithReverseProxy(t *testing.T) { + req, _ := http.NewRequest("GET", "/reverse", nil) + req.Header.Add("Accept-Encoding", "gzip") + + w := newCloseNotifyingRecorder() + r := newServer(100) + r.ServeHTTP(w, req) + + assert.Equal(t, w.Code, 200) + assert.NotEqual(t, w.Header().Get("Content-Encoding"), "gzip") + assert.NotEqual(t, w.Header().Get("Vary"), "Accept-Encoding") + assert.NotEqual(t, w.Header().Get("Content-Length"), "0") + assert.Equal(t, w.Body.Len(), len(testReverseResponse)) + assert.Equal(t, fmt.Sprint(w.Body.Len()), w.Header().Get("Content-Length")) + + body, _ := ioutil.ReadAll(w.Body) + assert.Equal(t, string(body), testReverseResponse) +} +