Skip to content

Commit

Permalink
Ensure ReadStream with direct=true isn't racy
Browse files Browse the repository at this point in the history
Partially addresses #17. Big thanks to @jonboulle for spotting that
head-slapper.
  • Loading branch information
peterbourgon committed Dec 11, 2014
1 parent fd6bc2e commit 129c9e4
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
14 changes: 9 additions & 5 deletions diskv.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (d *Diskv) Read(key string) ([]byte, error) {
// ReadStream will use the cached value. Otherwise, it will return a handle to
// the file on disk, and cache the data on read.
//
// If direct is true, ReadStream will always delete any cached value for the
// If direct is true, ReadStream will lazily delete any cached value for the
// key, and return a direct handle to the file on disk.
//
// If compression is enabled, ReadStream taps into the io.Reader stream prior
Expand All @@ -185,16 +185,20 @@ func (d *Diskv) ReadStream(key string, direct bool) (io.ReadCloser, error) {
defer d.RUnlock()

if val, ok := d.cache[key]; ok {
if direct {
d.cacheSize -= uint64(len(val))
delete(d.cache, key)
} else {
if !direct {
buf := bytes.NewBuffer(val)
if d.Compression != nil {
return d.Compression.Reader(buf)
}
return ioutil.NopCloser(buf), nil
}

go func() {
d.Lock()
defer d.Unlock()
d.cacheSize -= uint64(len(val))
delete(d.cache, key)
}()
}

return d.read(key)
Expand Down
47 changes: 47 additions & 0 deletions issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package diskv
import (
"bytes"
"io/ioutil"
"sync"
"testing"
"time"
)
Expand Down Expand Up @@ -72,3 +73,49 @@ func TestIssue2B(t *testing.T) {
}
t.Logf("ReadStream('abc') returned error: %v", err)
}

// Ensure ReadStream with direct=true isn't racy.
func TestIssue17(t *testing.T) {
var (
basePath = "test-data"
)

dWrite := New(Options{
BasePath: basePath,
CacheSizeMax: 0,
})
defer dWrite.EraseAll()

dRead := New(Options{
BasePath: basePath,
CacheSizeMax: 50,
})

cases := map[string]string{
"a": `1234567890`,
"b": `2345678901`,
"c": `3456789012`,
"d": `4567890123`,
"e": `5678901234`,
}

for k, v := range cases {
if err := dWrite.Write(k, []byte(v)); err != nil {
t.Fatalf("during write: %s", err)
}
dRead.Read(k) // ensure it's added to cache
}

var wg sync.WaitGroup
start := make(chan struct{})
for k, v := range cases {
wg.Add(1)
go func(k, v string) {
<-start
dRead.ReadStream(k, true)
wg.Done()
}(k, v)
}
close(start)
wg.Wait()
}

0 comments on commit 129c9e4

Please sign in to comment.