From 2ad13d6e2f602fa362de89f899afed25ee1b1e6a Mon Sep 17 00:00:00 2001 From: nikandfor Date: Mon, 18 Dec 2023 00:51:47 +0100 Subject: [PATCH] change API --- README.md | 25 ++++++++++++++----------- batch.go | 43 +++++++++++++++++-------------------------- batch_test.go | 47 +++++++++++++++++++++++------------------------ example_test.go | 24 +++++++++++++++++++++++- 4 files changed, 77 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index 7f37679..e34ee32 100644 --- a/README.md +++ b/README.md @@ -27,17 +27,20 @@ This is all without timeouts, additional goroutines, allocations, and channels. ```go var tx int -b := batch.New(func(ctx context.Context) (interface{}, error) { - // commit tx - return res, err -}) - -// Optional hooks -b.Prepare = func(ctx context.Context) error { tx = 0; return nil } // called in the beginning on a new batch -b.Rollback = func(ctx context.Context, err error) error { return err } // if any worker returned error -b.Panic = func(ctx context.Context, p interface{}) error { // any worker panicked - return batch.PanicError{Panic: p} // returned to other workes - // panicked worker gets the panic back +b := batch.Batch{ + // Required + Commit: func(ctx context.Context) (interface{}, error) { + // commit tx + return res, err + }, + + // Optional hooks + Prepare: func(ctx context.Context) error { tx = 0; return nil }, // called in the beginning on a new batch + Rollback: func(ctx context.Context, err error) error { return err }, // if any worker returned error + Panic: func(ctx context.Context, p interface{}) error { // any worker panicked + return batch.PanicError{Panic: p} // returned to other workes + // panicked worker gets the panic back + }, } // only one of Panic, Rollback, and Commit is called (in respective priority order; panic wins, then error, commit is last) diff --git a/batch.go b/batch.go index 62f98cd..eeac02b 100644 --- a/batch.go +++ b/batch.go @@ -9,17 +9,17 @@ import ( type ( Batch struct { - queue int32 - Prepare func(ctx context.Context) error - Commit func(ctx context.Context) (interface{}, error) + Commit func(ctx context.Context) (interface{}, error) // required Rollback func(ctx context.Context, err error) error Panic func(ctx context.Context, p interface{}) error Limit *Semaphore - sync.Mutex - sync.Cond + queue int32 + + mu sync.Mutex + cond sync.Cond cnt int @@ -33,31 +33,22 @@ type ( } ) -func New(commit func(ctx context.Context) (interface{}, error)) *Batch { - b := &Batch{} - - b.Init(commit) - - return b -} - -func (b *Batch) Init(commit func(ctx context.Context) (interface{}, error)) { - b.Cond.L = &b.Mutex - b.Commit = commit -} - -func (b *Batch) Do(ctx context.Context, f func(ctx context.Context) error) (res interface{}, err error) { +func (b *Batch) Do(ctx context.Context, f func(ctx context.Context) error) (interface{}, error) { defer b.Limit.Exit() b.Limit.Enter() atomic.AddInt32(&b.queue, 1) - defer b.Unlock() - b.Lock() + defer b.mu.Unlock() + b.mu.Lock() + + if b.cond.L == nil { + b.cond.L = &b.mu + } // wait for all goroutines from the previous batch to exit for b.cnt < 0 { - b.Cond.Wait() + b.cond.Wait() } var p, p2 interface{} @@ -84,7 +75,7 @@ func (b *Batch) Do(ctx context.Context, f func(ctx context.Context) error) (res b.cnt++ // count entered if x != 0 { // we are not the last exiting the batch, wait for others - b.Cond.Wait() // so wait for the last one to finish the job + b.cond.Wait() // so wait for the last one to finish the job } else { b.cnt = -b.cnt // set committing mode, no new goroutines allowed to enter @@ -103,9 +94,9 @@ func (b *Batch) Do(ctx context.Context, f func(ctx context.Context) error) (res } b.cnt++ // reset committing mode when everybody left - b.Cond.Broadcast() + b.cond.Broadcast() - res, err = b.res, b.err // return the same result to all the entered + res, err := b.res, b.err // return the same result to all the entered if b.cnt == 0 { // the last turns the lights off b.res, b.err, b.panic = nil, nil, nil @@ -119,7 +110,7 @@ func (b *Batch) Do(ctx context.Context, f func(ctx context.Context) error) (res panic(p) } - return + return res, err } func (b *Batch) catchPanic(f func()) (p interface{}) { diff --git a/batch_test.go b/batch_test.go index b136232..9ba8bc9 100644 --- a/batch_test.go +++ b/batch_test.go @@ -19,28 +19,27 @@ func TestBatch(tb *testing.T) { var commits, rollbacks, panics int var bucket string - b := batch.New(func(ctx context.Context) (interface{}, error) { - commits++ - - return bucket, nil - }) - - b.Prepare = func(ctx context.Context) error { - bucket = "" - - return nil - } - - b.Rollback = func(ctx context.Context, err error) error { - rollbacks++ - - return err - } - - b.Panic = func(ctx context.Context, p interface{}) error { - panics++ - - return batch.PanicError{Panic: p} + b := &batch.Batch{ + Commit: func(ctx context.Context) (interface{}, error) { + commits++ + + return bucket, nil + }, + Prepare: func(ctx context.Context) error { + bucket = "" + + return nil + }, + Rollback: func(ctx context.Context, err error) error { + rollbacks++ + + return err + }, + Panic: func(ctx context.Context, p interface{}) error { + panics++ + + return batch.PanicError{Panic: p} + }, } var fail func() error @@ -154,12 +153,12 @@ func BenchmarkBatch(tb *testing.B) { var commits, sum int - b := batch.New(func(ctx context.Context) (interface{}, error) { + b := batch.Batch{Commit: func(ctx context.Context) (interface{}, error) { commits++ sum = 0 return nil, nil - }) + }} ctx := context.Background() diff --git a/example_test.go b/example_test.go index fe24b02..be9a809 100644 --- a/example_test.go +++ b/example_test.go @@ -24,8 +24,8 @@ type ( func New() *DB { d := &DB{} - d.b.Init(d.commit) d.b.Prepare = d.prepare + d.b.Commit = d.commit return d } @@ -98,3 +98,25 @@ func (d *DB) SaveBatched(ctx context.Context, data int) error { func (tx *Tx) Reset() { tx.updates = tx.updates[:0] } + +func ExampleBatch() { + ctx := context.Background() + svc := New() + + const M = 3 + + var wg sync.WaitGroup + + wg.Add(M) + + for j := 0; j < M; j++ { + go func() { + defer wg.Done() + + err := svc.SaveBatched(ctx, 2) + _ = err + }() + } + + wg.Wait() +}