Skip to content

Commit

Permalink
low level api
Browse files Browse the repository at this point in the history
  • Loading branch information
nikandfor committed Dec 20, 2023
1 parent c130a1f commit 084250d
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 54 deletions.
33 changes: 31 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ This is all without timeouts, additional goroutines, allocations, and channels.

## Usage

```go
```
var tx int
bc := batch.Controller{
Expand All @@ -42,6 +42,10 @@ for j := 0; j < N; j++ {
b := bc.Enter()
defer b.Exit()
if b.Index() == 0 { // we are first in batch, reset it
tx = 0
}
tx++ // add work to the batch
res, err := b.Commit(ctx)
Expand All @@ -55,6 +59,31 @@ for j := 0; j < N; j++ {
}
```

Batch is error and panic proof which means error the user code can return error or panic in any place,
Lower level API allows queue up in advance, before actually entering batch.
This can be used instead of waiting for timeout for other workers to come.
Instead workers declare itself and now they may be a bit late.

```
b := bc.Batch()
defer b.Exit() // should be called with defer to outlive panics
b.QueueUp() // now we will be waited for.
// prepare data
x := 3
b.Enter() // enter syncronized section
// add data to a common batch
tx += x
_, _ = b.Commit(ctx)
// we are still in syncronized section until Exit is called
```

`Controller.Enter` is a shortcut for `Controller.Batch, Batch.QueueUp, Batch.Enter`.

Batch is error and panic proof which means the user code can return error or panic in any place,
but as soon as all workers left the batch its state is restored.
But not the external state, it's callers responsibility to keep it consistent.
168 changes: 116 additions & 52 deletions batch3.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ type (
}

Batch struct {
noCopy noCopy

c *Controller
i int

triggered bool
noCopy noCopy
state int
}

PanicError struct {
Expand All @@ -39,12 +38,37 @@ type (
noCopy struct{}
)

const (
stateNew = iota
stateQueued
stateEntered
stateCommitted
stateExited
)

var (
ErrRollback = errors.New("rollback")
)

// Enter enters a batch.
func (c *Controller) Enter() Batch {
atomic.AddInt32(&c.queue, 1)

return Batch{
c: c,
i: c.enter(),
state: stateEntered,
}
}

func (c *Controller) Batch() Batch {
return Batch{
c: c,
i: -1,
}
}

func (c *Controller) enter() int {
c.mu.Lock()

if c.cond.L == nil {
Expand All @@ -58,29 +82,94 @@ func (c *Controller) Enter() Batch {
atomic.AddInt32(&c.queue, -1)
c.cnt++

return Batch{
c: c,
i: int(c.cnt) - 1,
return int(c.cnt) - 1
}

func (c *Controller) exit() {
c.cnt++

if c.cnt == 0 {
c.res, c.err = nil, nil
}

c.mu.Unlock()
c.cond.Broadcast()
}

func (c *Controller) commit(ctx context.Context, err error) (interface{}, error) {
if err != nil || atomic.LoadInt32(&c.queue) == 0 {
c.cnt = -c.cnt

if ep, ok := err.(PanicError); ok {
c.err = err
panic(ep.Panic)
}

if err != nil {
c.err = err
} else {
func() {
defer func() {
if p := recover(); p != nil {
c.err = PanicError{Panic: p}
}
}()

c.res, c.err = c.Commit(ctx)
}()
}
} else {
c.cond.Wait()
}

res, err := c.res, c.err

return res, err
}

func (b *Batch) QueueUp() {
if b.state != stateQueued-1 {
panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback")
}

b.state = stateQueued

atomic.AddInt32(&b.c.queue, 1)
}

func (b *Batch) Enter() {
if b.state != stateEntered-1 {
panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback")
}

b.state = stateEntered

b.i = b.c.enter()
}

func (b *Batch) Index() int {
return b.i
}

func (b *Batch) Exit() {
defer func() {
b.c.cnt++
switch b.state {
case stateNew:
return
case stateQueued:
atomic.AddInt32(&b.c.queue, -1)
return
case stateEntered, stateCommitted:
case stateExited:
panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback")
}

if b.c.cnt == 0 {
b.c.res, b.c.err = nil, nil
}
defer func() {
b.c.exit()

b.c.mu.Unlock()
b.c.cond.Broadcast()
b.state = stateExited
}()

if b.triggered {
if b.state == stateCommitted {
return
}

Expand All @@ -91,56 +180,31 @@ func (b *Batch) Exit() {
err = PanicError{Panic: p}
}

_, _ = b.commit(nil, err)
_, _ = b.c.commit(context.Background(), err)
}

func (b *Batch) Commit(ctx context.Context) (interface{}, error) {
return b.commit(ctx, nil)
}

func (b *Batch) Rollback(ctx context.Context, err error) (interface{}, error) {
if err == nil {
err = ErrRollback
if b.state != stateCommitted-1 {
panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback")
}

return b.commit(ctx, err)
b.state = stateCommitted

return b.c.commit(ctx, nil)
}

func (b *Batch) commit(ctx context.Context, err error) (interface{}, error) {
if b.triggered {
panic("usage: Enter -> defer Exit -> optional Commit with err or nil")
func (b *Batch) Rollback(ctx context.Context, err error) (interface{}, error) {
if b.state != stateCommitted-1 {
panic("usage: QueueUp -> Enter -> defer Exit -> Commit/Rollback")
}

b.triggered = true

if err != nil || atomic.LoadInt32(&b.c.queue) == 0 {
b.c.cnt = -b.c.cnt

if ep, ok := err.(PanicError); ok {
b.c.err = err
panic(ep.Panic)
}

if err != nil {
b.c.err = err
} else {
func() {
defer func() {
if p := recover(); p != nil {
b.c.err = PanicError{Panic: p}
}
}()
b.state = stateCommitted

b.c.res, b.c.err = b.c.Commit(ctx)
}()
}
} else {
b.c.cond.Wait()
if err == nil {
err = ErrRollback
}

res, err := b.c.res, b.c.err

return res, err
return b.c.commit(ctx, err)
}

func (e PanicError) Error() string {
Expand Down
80 changes: 80 additions & 0 deletions batch3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,84 @@ func TestBatch(tb *testing.T) {

reached2 = true
})

tb.Run("LowerAPI", func(tb *testing.T) {
b.Commit = func(ctx context.Context) (interface{}, error) {
return nil, nil
}

b := b.Batch()

b.QueueUp()

defer b.Exit()
b.Enter()

_, err := b.Commit(context.Background())
if err != nil {
tb.Errorf("commit: %v", err)
}
})

tb.Run("LowerAPIMisuse", func(tb *testing.T) {
b.Commit = func(ctx context.Context) (interface{}, error) {
return nil, nil
}

type testCase struct {
Name string
SkipQueue bool
DoubleQueue bool
NoEnter bool
CommitRollback bool
DoubleExit bool
}

for _, tc := range []testCase{
{Name: "SkipQueueUp", SkipQueue: true},
{Name: "DoubleQueue", DoubleQueue: true},
{Name: "NoEnter", NoEnter: true},
{Name: "CommitRollback", CommitRollback: true},
{Name: "DoubleExit", DoubleExit: true},
} {
tc := tc

tb.Run(tc.Name, func(tb *testing.T) {
defer func() {
p := recover()
if p == nil {
tb.Errorf("expected panic")
}
}()

b := b.Batch()
defer b.Exit()

if !tc.SkipQueue {
b.QueueUp()
}
if tc.DoubleQueue {
b.QueueUp()
}

if !tc.NoEnter {
b.Enter()
}

_, err := b.Commit(context.Background())
if err != nil {
tb.Errorf("commit: %v", err)
}

if tc.CommitRollback {
_, err = b.Rollback(context.Background(), nil)
_ = err
}

if tc.DoubleExit {
b.Exit()
}
})
}
})
}

0 comments on commit 084250d

Please sign in to comment.