Skip to content

Commit

Permalink
Merge pull request #11 from KyberNetwork/ft/batcher-flush
Browse files Browse the repository at this point in the history
feat: batcher can now be flushed manually
  • Loading branch information
NgoKimPhu authored Aug 21, 2024
2 parents 4260526 + 253651b commit 7f09211
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 22 deletions.
21 changes: 21 additions & 0 deletions batcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ func (c *ChanTask[R]) Resolve(ret R, err error) {
type Batcher[T BatchableTask[R], R any] interface {
// Batch submits a BatchableTask to the batcher.
Batch(task T)
// Flush executes tasks currently waiting in queue immediately.
Flush()
// Close should stop Batch from being called and clean up any background resources.
Close()
}
Expand All @@ -108,6 +110,7 @@ type ChanBatcher[T BatchableTask[R], R any] struct {
batchCfg BatchCfg
batchFn BatchFn[T]
taskCh chan T
flushCh chan struct{}
closed atomic.Bool
}

Expand All @@ -117,6 +120,7 @@ func NewChanBatcher[T BatchableTask[R], R any](batchCfg BatchCfg, batchFn BatchF
batchCfg: batchCfg,
batchFn: batchFn,
taskCh: make(chan T, 16*batchCnt),
flushCh: make(chan struct{}, 1),
}
go chanBatcher.worker()
return chanBatcher
Expand All @@ -131,6 +135,16 @@ func (b *ChanBatcher[T, R]) Batch(task T) {
}
}

// Flush executes tasks currently waiting in queue immediately.
func (b *ChanBatcher[T, R]) Flush() {
if !b.closed.Load() {
select {
case b.flushCh <- struct{}{}:
default:
}
}
}

// Close closes this chanBatcher to prevents Batch-ing new BatchableTask's and tell the worker goroutine to finish up.
func (b *ChanBatcher[_, _]) Close() {
if !b.closed.Swap(true) {
Expand Down Expand Up @@ -185,6 +199,13 @@ func (b *ChanBatcher[T, R]) worker() {
klog.Debugf(tasks[0].Ctx(), "ChanBatcher.worker|timer|%d tasks", len(tasks))
go b.batchFnWithRecover(tasks)
tasks = tasks[:0:0]
case <-b.flushCh:
if len(tasks) == 0 {
break
}
klog.Debugf(tasks[0].Ctx(), "ChanBatcher.worker|flush|%d tasks", len(tasks))
go b.batchFnWithRecover(tasks)
tasks = tasks[:0:0]
case task, ok := <-b.taskCh:
if !ok {
ctx := context.Background()
Expand Down
70 changes: 48 additions & 22 deletions batcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestChanBatcher(t *testing.T) {
return batchRate, 2
}, func(tasks []*ChanTask[time.Duration]) { batchFn(tasks) })
var cnt atomic.Uint32
start := time.Now()
var start time.Time
batchFn = func(tasks []*ChanTask[time.Duration]) {
cnt.Add(1)
for _, task := range tasks {
Expand All @@ -31,29 +31,54 @@ func TestChanBatcher(t *testing.T) {
task0 := NewChanTask[time.Duration](ctx)
task1 := NewChanTask[time.Duration](ctx)
task2 := NewChanTask[time.Duration](ctx)
task3 := NewChanTask[time.Duration](ctx)

t.Run("happy", func(t *testing.T) {
batcher.Batch(task0)
batcher.Batch(task1)
_, _ = task0.Result()
assert.EqualValues(t, 1, cnt.Load())
assert.NoError(t, task0.Err)
assert.Less(t, task0.Ret, batchRate)
ret, err := task1.Result()
assert.NoError(t, err)
assert.Less(t, ret, batchRate)
time.Sleep(batchRate * 11 / 10)
runtime.Gosched()
t.Run("trigger max", func(t *testing.T) {
start = time.Now()
batcher.Batch(task0)
batcher.Batch(task1)
_, _ = task0.Result()
assert.EqualValues(t, 1, cnt.Load())
assert.NoError(t, task0.Err)
assert.Less(t, task0.Ret, batchRate)
ret, err := task1.Result()
assert.NoError(t, err)
assert.Less(t, ret, batchRate)
time.Sleep(batchRate * 11 / 10)
runtime.Gosched()
})

batcher.Batch(task2)
assert.False(t, task2.IsDone())
ret, err = task2.Result()
assert.True(t, task2.IsDone())
assert.EqualValues(t, 2, cnt.Load())
assert.Equal(t, task2.Err, err)
assert.NoError(t, task2.Err)
assert.Equal(t, task2.Ret, ret)
assert.Greater(t, ret, batchRate)
t.Run("trigger timer after blocked by .Result()", func(t *testing.T) {
start = time.Now()
batcher.Batch(task2)
assert.False(t, task2.IsDone())
ret, err := task2.Result()
assert.True(t, task2.IsDone())
assert.EqualValues(t, 2, cnt.Load())
assert.Equal(t, task2.Err, err)
assert.NoError(t, task2.Err)
assert.Equal(t, task2.Ret, ret)
assert.Greater(t, ret, batchRate)
})

t.Run("trigger flush", func(t *testing.T) {
start = time.Now()
batcher.Batch(task3)
assert.False(t, task3.IsDone())
batcher.Flush()
batcher.Flush()
ret, err := task3.Result()
assert.True(t, task3.IsDone())
assert.EqualValues(t, 3, cnt.Load())
assert.Equal(t, task3.Err, err)
assert.NoError(t, task3.Err)
assert.Equal(t, task3.Ret, ret)
assert.Less(t, ret, batchRate)
batcher.Flush()
batcher.Flush()
assert.EqualValues(t, 3, cnt.Load())
})
})

t.Run("spam", func(t *testing.T) {
Expand Down Expand Up @@ -111,13 +136,14 @@ func TestChanBatcher(t *testing.T) {
assert.ErrorIs(t, task0.Err, panicErr)
assert.ErrorIs(t, task1.Err, panicErr)

start = time.Now()
batchFn = oldBatchFn
task2 = NewChanTask[time.Duration](nil) // nolint:staticcheck
batcher.Batch(task2)
batcher.Batch(task2)
ret, err := task2.Result()
assert.NoError(t, err)
assert.Greater(t, ret, batchRate)
assert.Less(t, ret, batchRate)
})

t.Run("cancelled task", func(t *testing.T) {
Expand Down

0 comments on commit 7f09211

Please sign in to comment.