diff --git a/gzip.go b/gzip.go index 957fc92..82d1483 100644 --- a/gzip.go +++ b/gzip.go @@ -4,13 +4,14 @@ import ( "bufio" "compress/gzip" "fmt" + "github.com/NYTimes/gziphandler/writer" + "github.com/NYTimes/gziphandler/writer/stdlib" "io" "mime" "net" "net/http" "strconv" "strings" - "sync" ) const ( @@ -36,48 +37,15 @@ const ( DefaultMinSize = 1400 ) -// gzipWriterPools stores a sync.Pool for each compression level for reuse of -// gzip.Writers. Use poolIndex to covert a compression level to an index into -// gzipWriterPools. -var gzipWriterPools [gzip.BestCompression - gzip.BestSpeed + 2]*sync.Pool - -func init() { - for i := gzip.BestSpeed; i <= gzip.BestCompression; i++ { - addLevelPool(i) - } - addLevelPool(gzip.DefaultCompression) -} - -// poolIndex maps a compression level to its index into gzipWriterPools. It -// assumes that level is a valid gzip compression level. -func poolIndex(level int) int { - // gzip.DefaultCompression == -1, so we need to treat it special. - if level == gzip.DefaultCompression { - return gzip.BestCompression - gzip.BestSpeed + 1 - } - return level - gzip.BestSpeed -} - -func addLevelPool(level int) { - gzipWriterPools[poolIndex(level)] = &sync.Pool{ - New: func() interface{} { - // NewWriterLevel only returns error on a bad level, we are guaranteeing - // that this will be a valid level so it is okay to ignore the returned - // error. - w, _ := gzip.NewWriterLevel(nil, level) - return w - }, - } -} - // GzipResponseWriter provides an http.ResponseWriter interface, which gzips // bytes before writing them to the underlying response. This doesn't close the // writers, so don't forget to do that. // It can be configured to skip response smaller than minSize. type GzipResponseWriter struct { http.ResponseWriter - index int // Index for gzipWriterPools. - gw *gzip.Writer + level int + gwFactory writer.GzipWriterFactory + gw writer.GzipWriter code int // Saves the WriteHeader value. @@ -217,9 +185,7 @@ func (w *GzipResponseWriter) WriteHeader(code int) { func (w *GzipResponseWriter) init() { // Bytes written during ServeHTTP are redirected to this gzip writer // before being written to the underlying response. - gzw := gzipWriterPools[w.index].Get().(*gzip.Writer) - gzw.Reset(w.ResponseWriter) - w.gw = gzw + w.gw = w.gwFactory(w.ResponseWriter, w.level) } // Close will close the gzip.Writer and will put it back in the gzipWriterPool. @@ -239,7 +205,6 @@ func (w *GzipResponseWriter) Close() error { } err := w.gw.Close() - gzipWriterPools[w.index].Put(w.gw) w.gw = nil return err } @@ -305,8 +270,9 @@ func NewGzipLevelAndMinSize(level, minSize int) (func(http.Handler) http.Handler func GzipHandlerWithOpts(opts ...option) (func(http.Handler) http.Handler, error) { c := &config{ - level: gzip.DefaultCompression, - minSize: DefaultMinSize, + level: gzip.DefaultCompression, + minSize: DefaultMinSize, + newWriter: stdlib.NewWriter, } for _, o := range opts { @@ -318,14 +284,13 @@ func GzipHandlerWithOpts(opts ...option) (func(http.Handler) http.Handler, error } return func(h http.Handler) http.Handler { - index := poolIndex(c.level) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add(vary, acceptEncoding) if acceptsGzip(r) { gw := &GzipResponseWriter{ ResponseWriter: w, - index: index, + gwFactory: c.newWriter, + level: c.level, minSize: c.minSize, contentTypes: c.contentTypes, } @@ -378,6 +343,7 @@ func (pct parsedContentType) equals(mediaType string, params map[string]string) type config struct { minSize int level int + newWriter writer.GzipWriterFactory contentTypes []parsedContentType } @@ -407,6 +373,16 @@ func CompressionLevel(level int) option { } } +// Implementation changes the implementation of GzipWriter +// +// The default implementation is writer/stdlib/NewWriter +// which is backed by standard library's compress/zlib +func Implementation(writer writer.GzipWriterFactory) option { + return func(c *config) { + c.newWriter = writer + } +} + // ContentTypes specifies a list of content types to compare // the Content-Type header to before compressing. If none // match, the response will be returned as-is. diff --git a/gzip_test.go b/gzip_test.go index bed7f52..69d1ae3 100644 --- a/gzip_test.go +++ b/gzip_test.go @@ -321,31 +321,6 @@ func TestGzipHandlerMinSize(t *testing.T) { } } -func TestGzipDoubleClose(t *testing.T) { - // reset the pool for the default compression so we can make sure duplicates - // aren't added back by double close - addLevelPool(gzip.DefaultCompression) - - handler := GzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // call close here and it'll get called again interally by - // NewGzipLevelHandler's handler defer - w.Write([]byte("test")) - w.(io.Closer).Close() - })) - - r := httptest.NewRequest("GET", "/", nil) - r.Header.Set("Accept-Encoding", "gzip") - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - - // the second close shouldn't have added the same writer - // so we pull out 2 writers from the pool and make sure they're different - w1 := gzipWriterPools[poolIndex(gzip.DefaultCompression)].Get() - w2 := gzipWriterPools[poolIndex(gzip.DefaultCompression)].Get() - // assert.NotEqual looks at the value and not the address, so we use regular == - assert.False(t, w1 == w2) -} - type panicOnSecondWriteHeaderWriter struct { http.ResponseWriter headerWritten bool diff --git a/writer/interface.go b/writer/interface.go new file mode 100644 index 0000000..de4f041 --- /dev/null +++ b/writer/interface.go @@ -0,0 +1,11 @@ +package writer + +import "io" + +type GzipWriter interface { + Close() error + Flush() error + Write(p []byte) (int, error) +} + +type GzipWriterFactory = func(writer io.Writer, level int) GzipWriter diff --git a/writer/stdlib/stdlib.go b/writer/stdlib/stdlib.go new file mode 100644 index 0000000..65db479 --- /dev/null +++ b/writer/stdlib/stdlib.go @@ -0,0 +1,68 @@ +package stdlib + +import ( + "compress/gzip" + "github.com/NYTimes/gziphandler/writer" + "io" + "sync" +) + +// gzipWriterPools stores a sync.Pool for each compression level for reuse of +// gzip.Writers. Use poolIndex to covert a compression level to an index into +// gzipWriterPools. +var gzipWriterPools [gzip.BestCompression - gzip.BestSpeed + 2]*sync.Pool + +func init() { + for i := gzip.BestSpeed; i <= gzip.BestCompression; i++ { + addLevelPool(i) + } + addLevelPool(gzip.DefaultCompression) +} + +// poolIndex maps a compression level to its index into gzipWriterPools. It +// assumes that level is a valid gzip compression level. +func poolIndex(level int) int { + // gzip.DefaultCompression == -1, so we need to treat it special. + if level == gzip.DefaultCompression { + return gzip.BestCompression - gzip.BestSpeed + 1 + } + return level - gzip.BestSpeed +} + +func addLevelPool(level int) { + gzipWriterPools[poolIndex(level)] = &sync.Pool{ + New: func() interface{} { + // NewWriterLevel only returns error on a bad level, we are guaranteeing + // that this will be a valid level so it is okay to ignore the returned + // error. + w, _ := gzip.NewWriterLevel(nil, level) + return w + }, + } +} + +type pooledWriter struct { + *gzip.Writer + index int +} + +func (pw *pooledWriter) Close() error { + err := pw.Writer.Close() + gzipWriterPools[pw.index].Put(pw.Writer) + pw.Writer = nil + return err +} + +func NewWriter(w io.Writer, level int) writer.GzipWriter { + index := poolIndex(level) + gzw := gzipWriterPools[index].Get().(*gzip.Writer) + gzw.Reset(w) + return &pooledWriter{ + Writer: gzw, + index: index, + } +} + +func ImplementationInfo() string { + return "compress/zlib" +} diff --git a/writer/stdlib/stdlib_test.go b/writer/stdlib/stdlib_test.go new file mode 100644 index 0000000..ebfd9d2 --- /dev/null +++ b/writer/stdlib/stdlib_test.go @@ -0,0 +1,25 @@ +package stdlib + +import ( + "bytes" + "compress/gzip" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGzipDoubleClose(t *testing.T) { + // reset the pool for the default compression so we can make sure duplicates + // aren't added back by double close + addLevelPool(gzip.DefaultCompression) + + w := bytes.NewBufferString("") + writer := NewWriter(w, gzip.DefaultCompression) + writer.Close() + + // the second close shouldn't have added the same writer + // so we pull out 2 writers from the pool and make sure they're different + w1 := gzipWriterPools[poolIndex(gzip.DefaultCompression)].Get() + w2 := gzipWriterPools[poolIndex(gzip.DefaultCompression)].Get() + // assert.NotEqual looks at the value and not the address, so we use regular == + assert.False(t, w1 == w2) +}