From 6f054330d5892f4b6971f5aebd28754d9e12643c Mon Sep 17 00:00:00 2001 From: k1LoW Date: Thu, 4 Apr 2024 15:27:08 +0900 Subject: [PATCH 1/2] Add `donegroup.Go` --- README.md | 24 +++++++++---- donegroup.go | 49 ++++++++++++++++++++++++-- donegroup_test.go | 87 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index c2dd786..452c3f3 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ `donegroup` is a package that provides a graceful cleanup transaction to context.Context when the context is canceled ( **done** ). -> errgroup.Group after <-ctx.Done() = donegroup +> errgroup.Group + <-ctx.Done() = donegroup ## Usage @@ -130,7 +130,7 @@ fmt.Println("main finish") ### [donegroup.Awaiter](https://pkg.go.dev/github.com/k1LoW/donegroup#Awaiter) -In addition to using [donegroup.Cleanup](https://pkg.go.dev/github.com/k1LoW/donegroup#Cleanup) to register a cleanup function after context cancellation, it is possible to use [donegroup.Awaiter](https://pkg.go.dev/github.com/k1LoW/donegroup#Awaiter) to make the execution of an arbitrary process wait after the context has been canceled. +In addition to using [donegroup.Cleanup](https://pkg.go.dev/github.com/k1LoW/donegroup#Cleanup) to register a cleanup function after context cancellation, it is possible to use [donegroup.Awaiter](https://pkg.go.dev/github.com/k1LoW/donegroup#Awaiter) to make the execution of an arbitrary process wait. ``` go ctx, cancel := donegroup.WithCancel(context.Background()) @@ -141,8 +141,7 @@ go func() { log.Fatal(err) return } - <-ctx.Done() - time.Sleep(100 * time.Millisecond) + time.Sleep(1000 * time.Millisecond) fmt.Println("do something") completed() }() @@ -169,11 +168,21 @@ It is also possible to guarantee the execution of a function block using `defer ``` go go func() { defer donegroup.Awaitable(ctx)() - <-ctx.Done() - time.Sleep(100 * time.Millisecond) + time.Sleep(1000 * time.Millisecond) fmt.Println("do something") completed() - } +}() +``` + +### [donegroup.Go](https://pkg.go.dev/github.com/k1LoW/donegroup#Go) + +[donegroup.Go](https://pkg.go.dev/github.com/k1LoW/donegroup#Go) can also make an arbitrary process wait, similar to [donegroup.Awaiter](https://pkg.go.dev/github.com/k1LoW/donegroup#Awaiter). + +``` go +donegroup.Go(func() error { + time.Sleep(1000 * time.Millisecond) + fmt.Println("do something") + return nil }() ``` @@ -203,3 +212,4 @@ defer func() { ``` are equivalent. + diff --git a/donegroup.go b/donegroup.go index 1ea25ae..3437184 100644 --- a/donegroup.go +++ b/donegroup.go @@ -176,7 +176,7 @@ func Awaiter(ctx context.Context) (completed func(), err error) { // AwaiterWithKey returns a function that guarantees execution of the process until it is called. // Note that if the timeout of WaitWithTimeout has passed (or the context of WaitWithContext has canceled), it will not wait. func AwaiterWithKey(ctx context.Context, key any) (completed func(), err error) { - ctxx, completed := context.WithCancel(context.Background()) + ctxx, completed := context.WithCancel(context.Background()) //nolint:govet if err := CleanupWithKey(ctx, key, func(ctxw context.Context) error { for { select { @@ -187,7 +187,7 @@ func AwaiterWithKey(ctx context.Context, key any) (completed func(), err error) } } }); err != nil { - return nil, err + return nil, err //nolint:govet } return completed, nil } @@ -207,3 +207,48 @@ func AwaitableWithKey(ctx context.Context, key any) (completed func()) { } return completed } + +// Go calls the function now asynchronously. +// If an error occurs, it is stored in the doneGroup. +// Note that if the timeout of WaitWithTimeout has passed (or the context of WaitWithContext has canceled), it will not wait. +func Go(ctx context.Context, f func() error) { + GoWithKey(ctx, doneGroupKey, f) +} + +// GoWithKey calls the function now asynchronously. +// If an error occurs, it is stored in the doneGroup. +// Note that if the timeout of WaitWithTimeout has passed (or the context of WaitWithContext has canceled), it will not wait. +func GoWithKey(ctx context.Context, key any, f func() error) { + dg, ok := ctx.Value(key).(*doneGroup) + if !ok { + panic(ErrNotContainDoneGroup) + } + + first := dg.cleanupGroups[0] + first.Go(func() error { + finished := false + go func() { + if err := f(); err != nil { + dg.mu.Lock() + dg.errors = errors.Join(dg.errors, err) + dg.mu.Unlock() + } + finished = true + }() + <-ctx.Done() + for { + if finished { + break + } + if dg.ctxw == nil { + continue + } + select { + case <-dg.ctxw.Done(): + return dg.ctxw.Err() + default: + } + } + return nil + }) +} diff --git a/donegroup_test.go b/donegroup_test.go index d0a7aa6..2bd9711 100644 --- a/donegroup_test.go +++ b/donegroup_test.go @@ -355,6 +355,8 @@ func TestAwaiter(t *testing.T) { t.Parallel() ctx, cancel := WithCancel(context.Background()) + finished := false + go func() { completed, err := Awaiter(ctx) if err != nil { @@ -362,6 +364,7 @@ func TestAwaiter(t *testing.T) { } <-ctx.Done() time.Sleep(20 * time.Millisecond) + finished = true completed() }() @@ -369,6 +372,9 @@ func TestAwaiter(t *testing.T) { cancel() time.Sleep(10 * time.Millisecond) err := WaitWithTimeout(ctx, tt.timeout) + if tt.finished != finished { + t.Errorf("expected finished: %v, got: %v", tt.finished, finished) + } if tt.finished { if err != nil { t.Error(err) @@ -407,16 +413,22 @@ func TestAwaitable(t *testing.T) { t.Parallel() ctx, cancel := WithCancel(context.Background()) + finished := false + go func() { defer Awaitable(ctx)() <-ctx.Done() time.Sleep(20 * time.Millisecond) + finished = true }() defer func() { cancel() time.Sleep(10 * time.Millisecond) err := WaitWithTimeout(ctx, tt.timeout) + if tt.finished != finished { + t.Errorf("expected finished: %v, got: %v", tt.finished, finished) + } if tt.finished { if err != nil { t.Error(err) @@ -507,3 +519,78 @@ func TestCancelWithContext(t *testing.T) { } }() } + +func TestGo(t *testing.T) { + t.Parallel() + tests := []struct { + name string + timeout time.Duration + finished bool + }{ + { + name: "finished", + timeout: 200 * time.Millisecond, + finished: true, + }, + { + name: "not finished", + timeout: 5 * time.Millisecond, + finished: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := WithCancel(context.Background()) + + finished := false + + Go(ctx, func() error { + <-ctx.Done() + time.Sleep(100 * time.Millisecond) + finished = true + return nil + }) + + defer func() { + cancel() + time.Sleep(10 * time.Millisecond) + err := WaitWithTimeout(ctx, tt.timeout) + if tt.finished != finished { + t.Errorf("expected finished: %v, got: %v", tt.finished, finished) + } + if tt.finished { + if err != nil { + t.Error(err) + } + return + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected timeout error: %v", err) + } + }() + }) + } +} + +func TestGoWithError(t *testing.T) { + t.Parallel() + ctx, cancel := WithCancel(context.Background()) + + var errTest = errors.New("test error") + + Go(ctx, func() error { + time.Sleep(10 * time.Millisecond) + return errTest + }) + + defer func() { + cancel() + + err := Wait(ctx) + if !errors.Is(err, errTest) { + t.Errorf("got %v, want %v", err, errTest) + } + }() +} From 1834ccfd2cc0e9f4b9ddf21451250fbcc17da642 Mon Sep 17 00:00:00 2001 From: k1LoW Date: Thu, 4 Apr 2024 15:38:14 +0900 Subject: [PATCH 2/2] Fix race condition --- donegroup.go | 10 +++++++--- donegroup_test.go | 13 +++++++------ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/donegroup.go b/donegroup.go index 3437184..c6629a8 100644 --- a/donegroup.go +++ b/donegroup.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" "time" "golang.org/x/sync/errgroup" @@ -226,23 +227,26 @@ func GoWithKey(ctx context.Context, key any, f func() error) { first := dg.cleanupGroups[0] first.Go(func() error { - finished := false + var finished int32 go func() { if err := f(); err != nil { dg.mu.Lock() dg.errors = errors.Join(dg.errors, err) dg.mu.Unlock() } - finished = true + atomic.AddInt32(&finished, 1) }() <-ctx.Done() for { - if finished { + if atomic.LoadInt32(&finished) > 0 { break } + dg.mu.Lock() if dg.ctxw == nil { + dg.mu.Unlock() continue } + dg.mu.Unlock() select { case <-dg.ctxw.Done(): return dg.ctxw.Err() diff --git a/donegroup_test.go b/donegroup_test.go index 2bd9711..2404f41 100644 --- a/donegroup_test.go +++ b/donegroup_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" "testing" "time" ) @@ -355,7 +356,7 @@ func TestAwaiter(t *testing.T) { t.Parallel() ctx, cancel := WithCancel(context.Background()) - finished := false + var finished int32 go func() { completed, err := Awaiter(ctx) @@ -364,7 +365,7 @@ func TestAwaiter(t *testing.T) { } <-ctx.Done() time.Sleep(20 * time.Millisecond) - finished = true + atomic.AddInt32(&finished, 1) completed() }() @@ -372,7 +373,7 @@ func TestAwaiter(t *testing.T) { cancel() time.Sleep(10 * time.Millisecond) err := WaitWithTimeout(ctx, tt.timeout) - if tt.finished != finished { + if tt.finished != (atomic.LoadInt32(&finished) > 0) { t.Errorf("expected finished: %v, got: %v", tt.finished, finished) } if tt.finished { @@ -413,20 +414,20 @@ func TestAwaitable(t *testing.T) { t.Parallel() ctx, cancel := WithCancel(context.Background()) - finished := false + var finished int32 go func() { defer Awaitable(ctx)() <-ctx.Done() time.Sleep(20 * time.Millisecond) - finished = true + atomic.AddInt32(&finished, 1) }() defer func() { cancel() time.Sleep(10 * time.Millisecond) err := WaitWithTimeout(ctx, tt.timeout) - if tt.finished != finished { + if tt.finished != (atomic.LoadInt32(&finished) > 0) { t.Errorf("expected finished: %v, got: %v", tt.finished, finished) } if tt.finished {