diff --git a/donegroup.go b/donegroup.go index 3437184..c6629a8 100644 --- a/donegroup.go +++ b/donegroup.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" "time" "golang.org/x/sync/errgroup" @@ -226,23 +227,26 @@ func GoWithKey(ctx context.Context, key any, f func() error) { first := dg.cleanupGroups[0] first.Go(func() error { - finished := false + var finished int32 go func() { if err := f(); err != nil { dg.mu.Lock() dg.errors = errors.Join(dg.errors, err) dg.mu.Unlock() } - finished = true + atomic.AddInt32(&finished, 1) }() <-ctx.Done() for { - if finished { + if atomic.LoadInt32(&finished) > 0 { break } + dg.mu.Lock() if dg.ctxw == nil { + dg.mu.Unlock() continue } + dg.mu.Unlock() select { case <-dg.ctxw.Done(): return dg.ctxw.Err() diff --git a/donegroup_test.go b/donegroup_test.go index 2bd9711..2404f41 100644 --- a/donegroup_test.go +++ b/donegroup_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" "testing" "time" ) @@ -355,7 +356,7 @@ func TestAwaiter(t *testing.T) { t.Parallel() ctx, cancel := WithCancel(context.Background()) - finished := false + var finished int32 go func() { completed, err := Awaiter(ctx) @@ -364,7 +365,7 @@ func TestAwaiter(t *testing.T) { } <-ctx.Done() time.Sleep(20 * time.Millisecond) - finished = true + atomic.AddInt32(&finished, 1) completed() }() @@ -372,7 +373,7 @@ func TestAwaiter(t *testing.T) { cancel() time.Sleep(10 * time.Millisecond) err := WaitWithTimeout(ctx, tt.timeout) - if tt.finished != finished { + if tt.finished != (atomic.LoadInt32(&finished) > 0) { t.Errorf("expected finished: %v, got: %v", tt.finished, finished) } if tt.finished { @@ -413,20 +414,20 @@ func TestAwaitable(t *testing.T) { t.Parallel() ctx, cancel := WithCancel(context.Background()) - finished := false + var finished int32 go func() { defer Awaitable(ctx)() <-ctx.Done() time.Sleep(20 * time.Millisecond) - finished = true + atomic.AddInt32(&finished, 1) }() defer func() { cancel() time.Sleep(10 * time.Millisecond) err := WaitWithTimeout(ctx, tt.timeout) - if tt.finished != finished { + if tt.finished != (atomic.LoadInt32(&finished) > 0) { t.Errorf("expected finished: %v, got: %v", tt.finished, finished) } if tt.finished {