Skip to content

Commit

Permalink
Use sync/atomic
Browse files Browse the repository at this point in the history
  • Loading branch information
k1LoW committed Jun 8, 2024
1 parent 1499301 commit e7e00b6
Showing 1 changed file with 49 additions and 98 deletions.
147 changes: 49 additions & 98 deletions donegroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package donegroup
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
Expand All @@ -13,11 +12,11 @@ func TestDoneGroup(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())

cleanup := false
cleanup := atomic.Bool{}

if err := Cleanup(ctx, func() error {
time.Sleep(10 * time.Millisecond)
cleanup = true
cleanup.Store(true)
return nil
}); err != nil {
t.Error(err)
Expand All @@ -30,12 +29,12 @@ func TestDoneGroup(t *testing.T) {
t.Error(err)
}

if !cleanup {
if !cleanup.Load() {
t.Error("cleanup function not called")
}
}()

cleanup = false
cleanup.Store(false)
}

func TestCleanup(t *testing.T) {
Expand Down Expand Up @@ -111,38 +110,31 @@ func TestWait(t *testing.T) {

func TestNoWait(t *testing.T) {
t.Parallel()
mu := sync.Mutex{}
ctx, cancel := WithCancel(context.Background())

cleanup := false
cleanup := atomic.Bool{}

if err := Cleanup(ctx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
cleanup = true
cleanup.Store(true)
return nil
}); err != nil {
t.Error(err)
}

defer func() {
cancel()
mu.Lock()
if cleanup {
if cleanup.Load() {
t.Error("cleanup function called")
}
mu.Unlock()

time.Sleep(20 * time.Millisecond)
mu.Lock()
if !cleanup {
if !cleanup.Load() {
t.Error("cleanup function not called")
}
mu.Unlock()
}()

cleanup = false
cleanup.Store(false)
}

func TestNoCleanup(t *testing.T) {
Expand All @@ -162,15 +154,12 @@ func TestMultiCleanup(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())

mu := sync.Mutex{}
cleanup := 0
cleanup := atomic.Int64{}

for i := 0; i < 10; i++ {
if err := Cleanup(ctx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
cleanup += 1
cleanup.Add(1)
return nil
}); err != nil {
t.Error(err)
Expand All @@ -184,7 +173,7 @@ func TestMultiCleanup(t *testing.T) {
t.Error(err)
}

if cleanup != 10 {
if cleanup.Load() != 10 {
t.Error("cleanup function not called")
}
}()
Expand All @@ -196,17 +185,14 @@ func TestNestedWithCancel(t *testing.T) {
secondCtx, secondCancel := WithCancel(firstCtx)
thirdCtx, thirdCancel := context.WithCancel(secondCtx) // context.WithCancel

mu := sync.Mutex{}
firstCleanup := 0
secondCleanup := 0
thirdCleanup := 0
firstCleanup := atomic.Int64{}
secondCleanup := atomic.Int64{}
thirdCleanup := atomic.Int64{}

for i := 0; i < 10; i++ {
if err := Cleanup(firstCtx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
firstCleanup += 1
firstCleanup.Add(1)
return nil
}); err != nil {
t.Error(err)
Expand All @@ -216,9 +202,7 @@ func TestNestedWithCancel(t *testing.T) {
for i := 0; i < 5; i++ {
if err := Cleanup(secondCtx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
secondCleanup += 1
secondCleanup.Add(1)
return nil
}); err != nil {
t.Error(err)
Expand All @@ -228,9 +212,7 @@ func TestNestedWithCancel(t *testing.T) {
for i := 0; i < 3; i++ {
if err := Cleanup(thirdCtx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
thirdCleanup += 1
thirdCleanup.Add(1)
return nil
}); err != nil {
t.Error(err)
Expand Down Expand Up @@ -263,17 +245,15 @@ func TestNestedWithCancel(t *testing.T) {
thirdCancel()
<-thirdCtx.Done()

mu.Lock()
if firstCleanup != 0 {
if firstCleanup.Load() != 0 {
t.Error("cleanup function for first called")
}
if secondCleanup != 0 {
if secondCleanup.Load() != 0 {
t.Error("cleanup function for second called")
}
if thirdCleanup != 0 {
if thirdCleanup.Load() != 0 {
t.Error("cleanup function for third called")
}
mu.Unlock()

secondCancel()
<-secondCtx.Done()
Expand All @@ -282,17 +262,15 @@ func TestNestedWithCancel(t *testing.T) {
t.Error(err)
}

mu.Lock()
if thirdCleanup != 3 {
if thirdCleanup.Load() != 3 {
t.Error("cleanup function for third not called")
}
if secondCleanup != 5 {
if secondCleanup.Load() != 5 {
t.Error("cleanup function for second not called")
}
if firstCleanup != 0 {
if firstCleanup.Load() != 0 {
t.Error("cleanup function for first called")
}
mu.Unlock()

firstCancel()
<-firstCtx.Done()
Expand All @@ -301,17 +279,15 @@ func TestNestedWithCancel(t *testing.T) {
t.Error(err)
}

mu.Lock()
if thirdCleanup != 3 {
if thirdCleanup.Load() != 3 {
t.Error("cleanup function for third not called")
}
if secondCleanup != 5 {
if secondCleanup.Load() != 5 {
t.Error("cleanup function for second not called")
}
if firstCleanup != 10 {
if firstCleanup.Load() != 10 {
t.Error("cleanup function for first not called")
}
mu.Unlock()
}()
}

Expand All @@ -320,16 +296,13 @@ func TestRootWaitAll(t *testing.T) {
rootCtx, rootCancel := WithCancel(context.Background())
leafCtx, _ := WithCancel(rootCtx)

mu := sync.Mutex{}
rootCleanup := 0
leafCleanup := 0
rootCleanup := atomic.Int64{}
leafCleanup := atomic.Int64{}

for i := 0; i < 10; i++ {
if err := Cleanup(rootCtx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
rootCleanup += 1
rootCleanup.Add(1)
return nil
}); err != nil {
t.Error(err)
Expand All @@ -339,17 +312,15 @@ func TestRootWaitAll(t *testing.T) {
for i := 0; i < 5; i++ {
if err := Cleanup(leafCtx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
leafCleanup += 1
leafCleanup.Add(1)
return nil
}); err != nil {
t.Error(err)
}
}

defer func() {
if rootCleanup != 0 {
if rootCleanup.Load() != 0 {
t.Error("cleanup function for root called")
}

Expand All @@ -359,11 +330,11 @@ func TestRootWaitAll(t *testing.T) {
t.Error(err)
}

if leafCleanup != 5 {
if leafCleanup.Load() != 5 {
t.Error("cleanup function for leaf not called")
}

if rootCleanup != 10 {
if rootCleanup.Load() != 10 {
t.Error("cleanup function for root not called")
}
}()
Expand Down Expand Up @@ -667,28 +638,23 @@ func TestWithDeadline(t *testing.T) {
t.Parallel()
ctx, _ := WithDeadline(context.Background(), time.Now().Add(5*time.Millisecond))

mu := sync.Mutex{}
cleanup := false
cleanup := atomic.Bool{}

if err := Cleanup(ctx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
cleanup = true
cleanup.Store(true)
return nil
}); err != nil {
t.Error(err)
}

mu.Lock()
cleanup = false
mu.Unlock()
cleanup.Store(false)

if err := Wait(ctx); err != nil {
t.Error(err)
}

if !cleanup {
if !cleanup.Load() {
t.Error("cleanup function not called")
}

Expand All @@ -701,28 +667,23 @@ func TestWithTimeout(t *testing.T) {
t.Parallel()
ctx, _ := WithTimeout(context.Background(), 5*time.Millisecond)

mu := sync.Mutex{}
cleanup := false
cleanup := atomic.Bool{}

if err := Cleanup(ctx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
cleanup = true
cleanup.Store(true)
return nil
}); err != nil {
t.Error(err)
}

mu.Lock()
cleanup = false
mu.Unlock()
cleanup.Store(false)

if err := Wait(ctx); err != nil {
t.Error(err)
}

if !cleanup {
if !cleanup.Load() {
t.Error("cleanup function not called")
}

Expand All @@ -736,28 +697,23 @@ func TestWithTimeoutCause(t *testing.T) {
var errTest = errors.New("test error")
ctx, _ := WithTimeoutCause(context.Background(), 5*time.Millisecond, errTest)

mu := sync.Mutex{}
cleanup := false
cleanup := atomic.Bool{}

if err := Cleanup(ctx, func() error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
cleanup = true
cleanup.Store(true)
return nil
}); err != nil {
t.Error(err)
}

mu.Lock()
cleanup = false
mu.Unlock()
cleanup.Store(false)

if err := Wait(ctx); err != nil {
t.Error(err)
}

if !cleanup {
if !cleanup.Load() {
t.Error("cleanup function not called")
}

Expand Down Expand Up @@ -806,13 +762,10 @@ func TestCancelWithCause(t *testing.T) {
func TestWithoutCancel(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())
mu := sync.Mutex{}
cleanup := false
cleanup := atomic.Bool{}

if err := Cleanup(ctx, func() error {
mu.Lock()
defer mu.Unlock()
cleanup = true
cleanup.Store(true)
return nil
}); err != nil {
t.Error(err)
Expand All @@ -828,9 +781,7 @@ func TestWithoutCancel(t *testing.T) {

time.Sleep(5 * time.Millisecond)

mu.Lock()
defer mu.Unlock()
if !cleanup {
if !cleanup.Load() {
t.Error("cleanup function not called")
}
}

0 comments on commit e7e00b6

Please sign in to comment.