Skip to content

Commit

Permalink
Merge pull request #23 from k1LoW/collect-errors
Browse files Browse the repository at this point in the history
Add `donegroup.Go`
  • Loading branch information
k1LoW authored Apr 4, 2024
2 parents b2b37d2 + 1834ccf commit 972ceca
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 9 deletions.
24 changes: 17 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

`donegroup` is a package that provides a graceful cleanup transaction to context.Context when the context is canceled ( **done** ).

> errgroup.Group after <-ctx.Done() = donegroup
> errgroup.Group + <-ctx.Done() = donegroup
## Usage

Expand Down Expand Up @@ -130,7 +130,7 @@ fmt.Println("main finish")

### [donegroup.Awaiter](https://pkg.go.dev/github.com/k1LoW/donegroup#Awaiter)

In addition to using [donegroup.Cleanup](https://pkg.go.dev/github.com/k1LoW/donegroup#Cleanup) to register a cleanup function after context cancellation, it is possible to use [donegroup.Awaiter](https://pkg.go.dev/github.com/k1LoW/donegroup#Awaiter) to make the execution of an arbitrary process wait after the context has been canceled.
In addition to using [donegroup.Cleanup](https://pkg.go.dev/github.com/k1LoW/donegroup#Cleanup) to register a cleanup function after context cancellation, it is possible to use [donegroup.Awaiter](https://pkg.go.dev/github.com/k1LoW/donegroup#Awaiter) to make the execution of an arbitrary process wait.

``` go
ctx, cancel := donegroup.WithCancel(context.Background())
Expand All @@ -141,8 +141,7 @@ go func() {
log.Fatal(err)
return
}
<-ctx.Done()
time.Sleep(100 * time.Millisecond)
time.Sleep(1000 * time.Millisecond)
fmt.Println("do something")
completed()
}()
Expand All @@ -169,11 +168,21 @@ It is also possible to guarantee the execution of a function block using `defer
``` go
go func() {
defer donegroup.Awaitable(ctx)()
<-ctx.Done()
time.Sleep(100 * time.Millisecond)
time.Sleep(1000 * time.Millisecond)
fmt.Println("do something")
completed()
}
}()
```

### [donegroup.Go](https://pkg.go.dev/github.com/k1LoW/donegroup#Go)

[donegroup.Go](https://pkg.go.dev/github.com/k1LoW/donegroup#Go) can also make an arbitrary process wait, similar to [donegroup.Awaiter](https://pkg.go.dev/github.com/k1LoW/donegroup#Awaiter).

``` go
donegroup.Go(func() error {
time.Sleep(1000 * time.Millisecond)
fmt.Println("do something")
return nil
}()
```

Expand Down Expand Up @@ -203,3 +212,4 @@ defer func() {
```

are equivalent.

53 changes: 51 additions & 2 deletions donegroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"time"

"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -176,7 +177,7 @@ func Awaiter(ctx context.Context) (completed func(), err error) {
// AwaiterWithKey returns a function that guarantees execution of the process until it is called.
// Note that if the timeout of WaitWithTimeout has passed (or the context of WaitWithContext has canceled), it will not wait.
func AwaiterWithKey(ctx context.Context, key any) (completed func(), err error) {
ctxx, completed := context.WithCancel(context.Background())
ctxx, completed := context.WithCancel(context.Background()) //nolint:govet
if err := CleanupWithKey(ctx, key, func(ctxw context.Context) error {
for {
select {
Expand All @@ -187,7 +188,7 @@ func AwaiterWithKey(ctx context.Context, key any) (completed func(), err error)
}
}
}); err != nil {
return nil, err
return nil, err //nolint:govet
}
return completed, nil
}
Expand All @@ -207,3 +208,51 @@ func AwaitableWithKey(ctx context.Context, key any) (completed func()) {
}
return completed
}

// Go calls the function now asynchronously.
// If an error occurs, it is stored in the doneGroup.
// Note that if the timeout of WaitWithTimeout has passed (or the context of WaitWithContext has canceled), it will not wait.
func Go(ctx context.Context, f func() error) {
GoWithKey(ctx, doneGroupKey, f)
}

// GoWithKey calls the function now asynchronously.
// If an error occurs, it is stored in the doneGroup.
// Note that if the timeout of WaitWithTimeout has passed (or the context of WaitWithContext has canceled), it will not wait.
func GoWithKey(ctx context.Context, key any, f func() error) {
dg, ok := ctx.Value(key).(*doneGroup)
if !ok {
panic(ErrNotContainDoneGroup)
}

first := dg.cleanupGroups[0]
first.Go(func() error {
var finished int32
go func() {
if err := f(); err != nil {
dg.mu.Lock()
dg.errors = errors.Join(dg.errors, err)
dg.mu.Unlock()
}
atomic.AddInt32(&finished, 1)
}()
<-ctx.Done()
for {
if atomic.LoadInt32(&finished) > 0 {
break
}
dg.mu.Lock()
if dg.ctxw == nil {
dg.mu.Unlock()
continue
}
dg.mu.Unlock()
select {
case <-dg.ctxw.Done():
return dg.ctxw.Err()
default:
}
}
return nil
})
}
88 changes: 88 additions & 0 deletions donegroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -355,20 +356,26 @@ func TestAwaiter(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())

var finished int32

go func() {
completed, err := Awaiter(ctx)
if err != nil {
t.Error(err)
}
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
atomic.AddInt32(&finished, 1)
completed()
}()

defer func() {
cancel()
time.Sleep(10 * time.Millisecond)
err := WaitWithTimeout(ctx, tt.timeout)
if tt.finished != (atomic.LoadInt32(&finished) > 0) {
t.Errorf("expected finished: %v, got: %v", tt.finished, finished)
}
if tt.finished {
if err != nil {
t.Error(err)
Expand Down Expand Up @@ -407,16 +414,22 @@ func TestAwaitable(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())

var finished int32

go func() {
defer Awaitable(ctx)()
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
atomic.AddInt32(&finished, 1)
}()

defer func() {
cancel()
time.Sleep(10 * time.Millisecond)
err := WaitWithTimeout(ctx, tt.timeout)
if tt.finished != (atomic.LoadInt32(&finished) > 0) {
t.Errorf("expected finished: %v, got: %v", tt.finished, finished)
}
if tt.finished {
if err != nil {
t.Error(err)
Expand Down Expand Up @@ -507,3 +520,78 @@ func TestCancelWithContext(t *testing.T) {
}
}()
}

func TestGo(t *testing.T) {
t.Parallel()
tests := []struct {
name string
timeout time.Duration
finished bool
}{
{
name: "finished",
timeout: 200 * time.Millisecond,
finished: true,
},
{
name: "not finished",
timeout: 5 * time.Millisecond,
finished: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())

finished := false

Go(ctx, func() error {
<-ctx.Done()
time.Sleep(100 * time.Millisecond)
finished = true
return nil
})

defer func() {
cancel()
time.Sleep(10 * time.Millisecond)
err := WaitWithTimeout(ctx, tt.timeout)
if tt.finished != finished {
t.Errorf("expected finished: %v, got: %v", tt.finished, finished)
}
if tt.finished {
if err != nil {
t.Error(err)
}
return
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("expected timeout error: %v", err)
}
}()
})
}
}

func TestGoWithError(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())

var errTest = errors.New("test error")

Go(ctx, func() error {
time.Sleep(10 * time.Millisecond)
return errTest
})

defer func() {
cancel()

err := Wait(ctx)
if !errors.Is(err, errTest) {
t.Errorf("got %v, want %v", err, errTest)
}
}()
}

0 comments on commit 972ceca

Please sign in to comment.