diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go index ff2c2ee4..319797e2 100644 --- a/singleflight/singleflight.go +++ b/singleflight/singleflight.go @@ -25,6 +25,8 @@ type call struct { wg sync.WaitGroup val interface{} err error + // true if call has completed; guarded by (Once)Group.mu + complete bool } // Group represents a class of work and forms a namespace in which @@ -62,3 +64,37 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, err return c.val, c.err } + +// OnceGroup is like Group, but caches the results of calls. +type OnceGroup struct { + mu sync.Mutex // protects m + m map[string]*call // lazily initialized +} + +func (g *OnceGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + if c.complete { + g.mu.Unlock() + return c.val, c.err + } + g.mu.Unlock() + c.wg.Wait() + return c.val, c.err + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + c.val, c.err = fn() + g.mu.Lock() + c.complete = true + g.mu.Unlock() + c.wg.Done() + + return c.val, c.err +} diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go index 47b4d3dc..95f0002b 100644 --- a/singleflight/singleflight_test.go +++ b/singleflight/singleflight_test.go @@ -83,3 +83,114 @@ func TestDoDupSuppress(t *testing.T) { t.Errorf("number of calls = %d; want 1", got) } } + +func TestDoCalledTwice(t *testing.T) { + var g Group + c := make(chan string) + var calls int32 + fn := func() (interface{}, error) { + atomic.AddInt32(&calls, 1) + return <-c, nil + } + + const n = 10 + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + v, err := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + } + if v.(string) != "bar" { + t.Errorf("got %q; want %q", v, "bar") + } + wg.Done() + }() + } + time.Sleep(100 * time.Millisecond) // let goroutines above block + c <- "bar" + wg.Wait() + go func() { + // call one more time; fn() should get called a second time + v, err := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + } + if v.(string) != "bar" { + t.Errorf("got %q; want %q", v, "bar") + } + }() + c <- "bar" + if got := atomic.LoadInt32(&calls); got != 2 { + t.Errorf("number of calls = %d; want 2", got) + } +} + +func TestOnceDo(t *testing.T) { + var g OnceGroup + v, err := g.Do("key", func() (interface{}, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestOnceDoErr(t *testing.T) { + var g OnceGroup + someErr := errors.New("Some error") + v, err := g.Do("key", func() (interface{}, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr", err) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestOnceDoDupSuppress(t *testing.T) { + var g OnceGroup + c := make(chan string) + var calls int32 + fn := func() (interface{}, error) { + atomic.AddInt32(&calls, 1) + return <-c, nil + } + + const n = 10 + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + v, err := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + } + if v.(string) != "bar" { + t.Errorf("got %q; want %q", v, "bar") + } + wg.Done() + }() + } + time.Sleep(100 * time.Millisecond) // let goroutines above block + c <- "bar" + wg.Wait() + // one more time after every goroutine has completed - should return the + // same result instantly. + v, err := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + } + if v.(string) != "bar" { + t.Errorf("got %q; want %q", v, "bar") + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("number of calls = %d; want 1", got) + } +}