From 084250d2dd78026c0544a2459feb542e8148f575 Mon Sep 17 00:00:00 2001 From: nikandfor Date: Thu, 21 Dec 2023 00:12:33 +0100 Subject: [PATCH] low level api --- README.md | 33 +++++++++- batch3.go | 168 ++++++++++++++++++++++++++++++++++--------------- batch3_test.go | 80 +++++++++++++++++++++++ 3 files changed, 227 insertions(+), 54 deletions(-) diff --git a/README.md b/README.md index b0c8a04..98da042 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ This is all without timeouts, additional goroutines, allocations, and channels. ## Usage -```go +``` var tx int bc := batch.Controller{ @@ -42,6 +42,10 @@ for j := 0; j < N; j++ { b := bc.Enter() defer b.Exit() + if b.Index() == 0 { // we are first in batch, reset it + tx = 0 + } + tx++ // add work to the batch res, err := b.Commit(ctx) @@ -55,6 +59,31 @@ for j := 0; j < N; j++ { } ``` -Batch is error and panic proof which means error the user code can return error or panic in any place, +Lower level API allows queue up in advance, before actually entering batch. +This can be used instead of waiting for timeout for other workers to come. +Instead workers declare itself and now they may be a bit late. + +``` +b := bc.Batch() +defer b.Exit() // should be called with defer to outlive panics + +b.QueueUp() // now we will be waited for. + +// prepare data +x := 3 + +b.Enter() // enter syncronized section + +// add data to a common batch +tx += x + +_, _ = b.Commit(ctx) + +// we are still in syncronized section until Exit is called +``` + +`Controller.Enter` is a shortcut for `Controller.Batch, Batch.QueueUp, Batch.Enter`. + +Batch is error and panic proof which means the user code can return error or panic in any place, but as soon as all workers left the batch its state is restored. But not the external state, it's callers responsibility to keep it consistent. diff --git a/batch3.go b/batch3.go index 9c42584..186c116 100644 --- a/batch3.go +++ b/batch3.go @@ -24,12 +24,11 @@ type ( } Batch struct { - noCopy noCopy - c *Controller i int - triggered bool + noCopy noCopy + state int } PanicError struct { @@ -39,12 +38,37 @@ type ( noCopy struct{} ) +const ( + stateNew = iota + stateQueued + stateEntered + stateCommitted + stateExited +) + var ( ErrRollback = errors.New("rollback") ) +// Enter enters a batch. func (c *Controller) Enter() Batch { atomic.AddInt32(&c.queue, 1) + + return Batch{ + c: c, + i: c.enter(), + state: stateEntered, + } +} + +func (c *Controller) Batch() Batch { + return Batch{ + c: c, + i: -1, + } +} + +func (c *Controller) enter() int { c.mu.Lock() if c.cond.L == nil { @@ -58,10 +82,69 @@ func (c *Controller) Enter() Batch { atomic.AddInt32(&c.queue, -1) c.cnt++ - return Batch{ - c: c, - i: int(c.cnt) - 1, + return int(c.cnt) - 1 +} + +func (c *Controller) exit() { + c.cnt++ + + if c.cnt == 0 { + c.res, c.err = nil, nil + } + + c.mu.Unlock() + c.cond.Broadcast() +} + +func (c *Controller) commit(ctx context.Context, err error) (interface{}, error) { + if err != nil || atomic.LoadInt32(&c.queue) == 0 { + c.cnt = -c.cnt + + if ep, ok := err.(PanicError); ok { + c.err = err + panic(ep.Panic) + } + + if err != nil { + c.err = err + } else { + func() { + defer func() { + if p := recover(); p != nil { + c.err = PanicError{Panic: p} + } + }() + + c.res, c.err = c.Commit(ctx) + }() + } + } else { + c.cond.Wait() + } + + res, err := c.res, c.err + + return res, err +} + +func (b *Batch) QueueUp() { + if b.state != stateQueued-1 { + panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback") } + + b.state = stateQueued + + atomic.AddInt32(&b.c.queue, 1) +} + +func (b *Batch) Enter() { + if b.state != stateEntered-1 { + panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback") + } + + b.state = stateEntered + + b.i = b.c.enter() } func (b *Batch) Index() int { @@ -69,18 +152,24 @@ func (b *Batch) Index() int { } func (b *Batch) Exit() { - defer func() { - b.c.cnt++ + switch b.state { + case stateNew: + return + case stateQueued: + atomic.AddInt32(&b.c.queue, -1) + return + case stateEntered, stateCommitted: + case stateExited: + panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback") + } - if b.c.cnt == 0 { - b.c.res, b.c.err = nil, nil - } + defer func() { + b.c.exit() - b.c.mu.Unlock() - b.c.cond.Broadcast() + b.state = stateExited }() - if b.triggered { + if b.state == stateCommitted { return } @@ -91,56 +180,31 @@ func (b *Batch) Exit() { err = PanicError{Panic: p} } - _, _ = b.commit(nil, err) + _, _ = b.c.commit(context.Background(), err) } func (b *Batch) Commit(ctx context.Context) (interface{}, error) { - return b.commit(ctx, nil) -} - -func (b *Batch) Rollback(ctx context.Context, err error) (interface{}, error) { - if err == nil { - err = ErrRollback + if b.state != stateCommitted-1 { + panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback") } - return b.commit(ctx, err) + b.state = stateCommitted + + return b.c.commit(ctx, nil) } -func (b *Batch) commit(ctx context.Context, err error) (interface{}, error) { - if b.triggered { - panic("usage: Enter -> defer Exit -> optional Commit with err or nil") +func (b *Batch) Rollback(ctx context.Context, err error) (interface{}, error) { + if b.state != stateCommitted-1 { + panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback") } - b.triggered = true - - if err != nil || atomic.LoadInt32(&b.c.queue) == 0 { - b.c.cnt = -b.c.cnt - - if ep, ok := err.(PanicError); ok { - b.c.err = err - panic(ep.Panic) - } - - if err != nil { - b.c.err = err - } else { - func() { - defer func() { - if p := recover(); p != nil { - b.c.err = PanicError{Panic: p} - } - }() + b.state = stateCommitted - b.c.res, b.c.err = b.c.Commit(ctx) - }() - } - } else { - b.c.cond.Wait() + if err == nil { + err = ErrRollback } - res, err := b.c.res, b.c.err - - return res, err + return b.c.commit(ctx, err) } func (e PanicError) Error() string { diff --git a/batch3_test.go b/batch3_test.go index 2fe158c..796783d 100644 --- a/batch3_test.go +++ b/batch3_test.go @@ -159,4 +159,84 @@ func TestBatch(tb *testing.T) { reached2 = true }) + + tb.Run("LowerAPI", func(tb *testing.T) { + b.Commit = func(ctx context.Context) (interface{}, error) { + return nil, nil + } + + b := b.Batch() + + b.QueueUp() + + defer b.Exit() + b.Enter() + + _, err := b.Commit(context.Background()) + if err != nil { + tb.Errorf("commit: %v", err) + } + }) + + tb.Run("LowerAPIMisuse", func(tb *testing.T) { + b.Commit = func(ctx context.Context) (interface{}, error) { + return nil, nil + } + + type testCase struct { + Name string + SkipQueue bool + DoubleQueue bool + NoEnter bool + CommitRollback bool + DoubleExit bool + } + + for _, tc := range []testCase{ + {Name: "SkipQueueUp", SkipQueue: true}, + {Name: "DoubleQueue", DoubleQueue: true}, + {Name: "NoEnter", NoEnter: true}, + {Name: "CommitRollback", CommitRollback: true}, + {Name: "DoubleExit", DoubleExit: true}, + } { + tc := tc + + tb.Run(tc.Name, func(tb *testing.T) { + defer func() { + p := recover() + if p == nil { + tb.Errorf("expected panic") + } + }() + + b := b.Batch() + defer b.Exit() + + if !tc.SkipQueue { + b.QueueUp() + } + if tc.DoubleQueue { + b.QueueUp() + } + + if !tc.NoEnter { + b.Enter() + } + + _, err := b.Commit(context.Background()) + if err != nil { + tb.Errorf("commit: %v", err) + } + + if tc.CommitRollback { + _, err = b.Rollback(context.Background(), nil) + _ = err + } + + if tc.DoubleExit { + b.Exit() + } + }) + } + }) }