Skip to content

Commit

Permalink
Fix race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
k1LoW committed Apr 4, 2024
1 parent 6f05433 commit 1834ccf
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
10 changes: 7 additions & 3 deletions donegroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"time"

"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -226,23 +227,26 @@ func GoWithKey(ctx context.Context, key any, f func() error) {

first := dg.cleanupGroups[0]
first.Go(func() error {
finished := false
var finished int32
go func() {
if err := f(); err != nil {
dg.mu.Lock()
dg.errors = errors.Join(dg.errors, err)
dg.mu.Unlock()
}
finished = true
atomic.AddInt32(&finished, 1)
}()
<-ctx.Done()
for {
if finished {
if atomic.LoadInt32(&finished) > 0 {
break
}
dg.mu.Lock()
if dg.ctxw == nil {
dg.mu.Unlock()
continue
}
dg.mu.Unlock()
select {
case <-dg.ctxw.Done():
return dg.ctxw.Err()
Expand Down
13 changes: 7 additions & 6 deletions donegroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -355,7 +356,7 @@ func TestAwaiter(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())

finished := false
var finished int32

go func() {
completed, err := Awaiter(ctx)
Expand All @@ -364,15 +365,15 @@ func TestAwaiter(t *testing.T) {
}
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
finished = true
atomic.AddInt32(&finished, 1)
completed()
}()

defer func() {
cancel()
time.Sleep(10 * time.Millisecond)
err := WaitWithTimeout(ctx, tt.timeout)
if tt.finished != finished {
if tt.finished != (atomic.LoadInt32(&finished) > 0) {
t.Errorf("expected finished: %v, got: %v", tt.finished, finished)
}
if tt.finished {
Expand Down Expand Up @@ -413,20 +414,20 @@ func TestAwaitable(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())

finished := false
var finished int32

go func() {
defer Awaitable(ctx)()
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
finished = true
atomic.AddInt32(&finished, 1)
}()

defer func() {
cancel()
time.Sleep(10 * time.Millisecond)
err := WaitWithTimeout(ctx, tt.timeout)
if tt.finished != finished {
if tt.finished != (atomic.LoadInt32(&finished) > 0) {
t.Errorf("expected finished: %v, got: %v", tt.finished, finished)
}
if tt.finished {
Expand Down

0 comments on commit 1834ccf

Please sign in to comment.