diff --git a/concurrency/worker.go b/concurrency/worker.go index 179791efe..72acc7dd0 100644 --- a/concurrency/worker.go +++ b/concurrency/worker.go @@ -22,15 +22,22 @@ func NewReusableGoroutinesPool(size int) *ReusableGoroutinesPool { } type ReusableGoroutinesPool struct { - jobsMu sync.Mutex + jobsMu sync.RWMutex + closed bool jobs chan func() } // Go will run the given function in a worker of the pool. // If all workers are busy, Go() will spawn a new goroutine to run the workload. func (p *ReusableGoroutinesPool) Go(f func()) { - p.jobsMu.Lock() - defer p.jobsMu.Unlock() + p.jobsMu.RLock() + defer p.jobsMu.RUnlock() + + // If the pool is closed, run the function in a new goroutine. + if p.closed { + go f() + return + } select { case p.jobs <- f: @@ -46,5 +53,6 @@ func (p *ReusableGoroutinesPool) Go(f func()) { func (p *ReusableGoroutinesPool) Close() { p.jobsMu.Lock() defer p.jobsMu.Unlock() + p.closed = true close(p.jobs) } diff --git a/concurrency/worker_test.go b/concurrency/worker_test.go index 1dab55928..c8ceef904 100644 --- a/concurrency/worker_test.go +++ b/concurrency/worker_test.go @@ -62,10 +62,10 @@ func TestReusableGoroutinesPool(t *testing.T) { t.Fatalf("expected %d goroutines after closing, got %d", 0, countGoroutines()) } +// TestReusableGoroutinesPool_Race tests that Close() and Go() can be called concurrently. func TestReusableGoroutinesPool_Race(t *testing.T) { w := NewReusableGoroutinesPool(2) - var panicked bool var runCountAtomic atomic.Int32 const maxMsgCount = 10 @@ -73,11 +73,6 @@ func TestReusableGoroutinesPool_Race(t *testing.T) { testWG.Add(1) go func() { defer testWG.Done() - defer func() { - if r := recover(); r != nil { - panicked = true - } - }() for i := 0; i < maxMsgCount; i++ { w.Go(func() { runCountAtomic.Add(1) @@ -90,7 +85,5 @@ func TestReusableGoroutinesPool_Race(t *testing.T) { testWG.Wait() // wait for the test to finish runCt := int(runCountAtomic.Load()) - require.NotZero(t, runCt, "expected at least one run") - require.Less(t, runCt, 10, "expected less than 10 runs") - require.True(t, panicked, "expected panic") + require.Equal(t, runCt, 10, "expected all functions to run") }