diff --git a/batch3.go b/batch3.go index 742cc06..ce7fd3a 100644 --- a/batch3.go +++ b/batch3.go @@ -133,13 +133,14 @@ func (c *Controller[Res]) commit(ctx context.Context, err error) (Res, error) { again: if err != nil || atomic.LoadInt32(&c.queue) == 0 { c.cnt = -c.cnt - c.ready = true if ep, ok := err.(PanicError); ok { c.err = err + c.ready = true panic(ep.Panic) } else if err != nil { c.err = err + c.ready = true } else { func() { var res Res @@ -147,6 +148,7 @@ again: defer func() { c.res, c.err = res, err + c.ready = true if p := recover(); p != nil { c.err = PanicError{Panic: p} @@ -160,11 +162,15 @@ again: }() } } else { + wait: c.cond.Wait() - if !c.ready { + if c.cnt > 0 { goto again } + if !c.ready { + goto wait + } } res, err := c.res, c.err @@ -260,8 +266,15 @@ func (b *Batch[Res]) Rollback(ctx context.Context, err error) (Res, error) { return b.c.commit(ctx, err) } +func AsPanicError(err error) (PanicError, bool) { + var pe PanicError + + return pe, errors.As(err, &pe) +} + func (e PanicError) Error() string { return fmt.Sprintf("panic: %v", e.Panic) } -func (noCopy) Lock() {} +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} diff --git a/batch3_test.go b/batch3_test.go index 402bd98..0bb4621 100644 --- a/batch3_test.go +++ b/batch3_test.go @@ -266,6 +266,48 @@ func TestBatch(tb *testing.T) { }) } }) + + tb.Run("okAfterAll", func(tb *testing.T) { + const N = 100 + + ctx := context.Background() + + var sum int + + b.Commit = func(ctx context.Context, _ int) (int, error) { + runtime.Gosched() + return sum, nil + } + + var wg sync.WaitGroup + + wg.Add(*jobs) + + for j := 0; j < *jobs; j++ { + go func() { + defer wg.Done() + + for i := 0; i < N; i++ { + func() { + b, idx := b.Enter(true) + defer b.Exit() + + if idx == 0 { + sum = 0 + } + + runtime.Gosched() + sum++ + runtime.Gosched() + + _, _ = b.Commit(ctx) + }() + } + }() + } + + wg.Wait() + }) } func TestBatchNonBlocking(tb *testing.T) { diff --git a/go.mod b/go.mod index 3750a31..cfbd2ff 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module nikand.dev/go/batch go 1.18 + +retract v0.4.0 // has a bug diff --git a/multi.go b/multi.go index f42d577..844a89c 100644 --- a/multi.go +++ b/multi.go @@ -40,7 +40,10 @@ func (c *Multi[Res]) Enter(blocking bool) (b Batch[Res], coach, index int) { for coach := range c.cs { b, idx := c.cs[coach].Enter(false) if idx >= 0 { - return b, coach, idx + return Batch[Res]{ + c: b.c, + state: b.state, + }, coach, idx } }