From 129c9e4ad4cd6e9791fbcf69e07d44106b572ae2 Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Thu, 11 Dec 2014 20:52:17 +0100 Subject: [PATCH] Ensure ReadStream with direct=true isn't racy Partially addresses #17. Big thanks to @jonboulle for spotting that head-slapper. --- diskv.go | 14 +++++++++----- issues_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/diskv.go b/diskv.go index 5a83cae..a0c91ba 100644 --- a/diskv.go +++ b/diskv.go @@ -175,7 +175,7 @@ func (d *Diskv) Read(key string) ([]byte, error) { // ReadStream will use the cached value. Otherwise, it will return a handle to // the file on disk, and cache the data on read. // -// If direct is true, ReadStream will always delete any cached value for the +// If direct is true, ReadStream will lazily delete any cached value for the // key, and return a direct handle to the file on disk. // // If compression is enabled, ReadStream taps into the io.Reader stream prior @@ -185,16 +185,20 @@ func (d *Diskv) ReadStream(key string, direct bool) (io.ReadCloser, error) { defer d.RUnlock() if val, ok := d.cache[key]; ok { - if direct { - d.cacheSize -= uint64(len(val)) - delete(d.cache, key) - } else { + if !direct { buf := bytes.NewBuffer(val) if d.Compression != nil { return d.Compression.Reader(buf) } return ioutil.NopCloser(buf), nil } + + go func() { + d.Lock() + defer d.Unlock() + d.cacheSize -= uint64(len(val)) + delete(d.cache, key) + }() } return d.read(key) diff --git a/issues_test.go b/issues_test.go index a1ba9e6..0b0b109 100644 --- a/issues_test.go +++ b/issues_test.go @@ -3,6 +3,7 @@ package diskv import ( "bytes" "io/ioutil" + "sync" "testing" "time" ) @@ -72,3 +73,49 @@ func TestIssue2B(t *testing.T) { } t.Logf("ReadStream('abc') returned error: %v", err) } + +// Ensure ReadStream with direct=true isn't racy. +func TestIssue17(t *testing.T) { + var ( + basePath = "test-data" + ) + + dWrite := New(Options{ + BasePath: basePath, + CacheSizeMax: 0, + }) + defer dWrite.EraseAll() + + dRead := New(Options{ + BasePath: basePath, + CacheSizeMax: 50, + }) + + cases := map[string]string{ + "a": `1234567890`, + "b": `2345678901`, + "c": `3456789012`, + "d": `4567890123`, + "e": `5678901234`, + } + + for k, v := range cases { + if err := dWrite.Write(k, []byte(v)); err != nil { + t.Fatalf("during write: %s", err) + } + dRead.Read(k) // ensure it's added to cache + } + + var wg sync.WaitGroup + start := make(chan struct{}) + for k, v := range cases { + wg.Add(1) + go func(k, v string) { + <-start + dRead.ReadStream(k, true) + wg.Done() + }(k, v) + } + close(start) + wg.Wait() +}