From cf9b0bc9a06e65912851642ae6c57337a06b0cbe Mon Sep 17 00:00:00 2001 From: Julien Duchesne Date: Fri, 11 Oct 2024 17:14:21 -0400 Subject: [PATCH] `ReusableGoroutinesPool`: Fix datarace on `Close` (#607) --- concurrency/worker.go | 26 +++++++++++++++++++++++--- concurrency/worker_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/concurrency/worker.go b/concurrency/worker.go index f40f03348..72acc7dd0 100644 --- a/concurrency/worker.go +++ b/concurrency/worker.go @@ -1,5 +1,9 @@ package concurrency +import ( + "sync" +) + // NewReusableGoroutinesPool creates a new worker pool with the given size. // These workers will run the workloads passed through Go() calls. // If all workers are busy, Go() will spawn a new goroutine to run the workload. @@ -18,12 +22,23 @@ func NewReusableGoroutinesPool(size int) *ReusableGoroutinesPool { } type ReusableGoroutinesPool struct { - jobs chan func() + jobsMu sync.RWMutex + closed bool + jobs chan func() } // Go will run the given function in a worker of the pool. // If all workers are busy, Go() will spawn a new goroutine to run the workload. func (p *ReusableGoroutinesPool) Go(f func()) { + p.jobsMu.RLock() + defer p.jobsMu.RUnlock() + + // If the pool is closed, run the function in a new goroutine. + if p.closed { + go f() + return + } + select { case p.jobs <- f: default: @@ -32,7 +47,12 @@ func (p *ReusableGoroutinesPool) Go(f func()) { } // Close stops the workers of the pool. -// No new Do() calls should be performed after calling Close(). +// No new Go() calls should be performed after calling Close(). // Close does NOT wait for all jobs to finish, it is the caller's responsibility to ensure that in the provided workloads. // Close is intended to be used in tests to ensure that no goroutines are leaked. -func (p *ReusableGoroutinesPool) Close() { close(p.jobs) } +func (p *ReusableGoroutinesPool) Close() { + p.jobsMu.Lock() + defer p.jobsMu.Unlock() + p.closed = true + close(p.jobs) +} diff --git a/concurrency/worker_test.go b/concurrency/worker_test.go index 338062055..c8ceef904 100644 --- a/concurrency/worker_test.go +++ b/concurrency/worker_test.go @@ -4,10 +4,12 @@ import ( "regexp" "runtime" "strings" + "sync" "testing" "time" "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) func TestReusableGoroutinesPool(t *testing.T) { @@ -59,3 +61,29 @@ func TestReusableGoroutinesPool(t *testing.T) { } t.Fatalf("expected %d goroutines after closing, got %d", 0, countGoroutines()) } + +// TestReusableGoroutinesPool_Race tests that Close() and Go() can be called concurrently. +func TestReusableGoroutinesPool_Race(t *testing.T) { + w := NewReusableGoroutinesPool(2) + + var runCountAtomic atomic.Int32 + const maxMsgCount = 10 + + var testWG sync.WaitGroup + testWG.Add(1) + go func() { + defer testWG.Done() + for i := 0; i < maxMsgCount; i++ { + w.Go(func() { + runCountAtomic.Add(1) + }) + time.Sleep(10 * time.Millisecond) + } + }() + time.Sleep(10 * time.Millisecond) + w.Close() // close the pool + testWG.Wait() // wait for the test to finish + + runCt := int(runCountAtomic.Load()) + require.Equal(t, runCt, 10, "expected all functions to run") +}