diff --git a/go.mod b/go.mod index 29ee237..d9a61c4 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,7 @@ require ( github.com/sirupsen/logrus v1.4.2 // indirect github.com/smartystreets/goconvey v0.0.0-20190222223459-a17d461953aa // indirect github.com/soheilhy/cmux v0.1.4 // indirect + github.com/stretchr/testify v1.2.2 github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8 // indirect github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect go.etcd.io/bbolt v1.3.2 // indirect diff --git a/util/time_wheel.go b/util/time_wheel.go index 4baabbc..24aea7b 100644 --- a/util/time_wheel.go +++ b/util/time_wheel.go @@ -45,8 +45,11 @@ type TimeWheel struct { // NewTimeWheel create new time wheel func NewTimeWheel(tick time.Duration, bucketsNum int) (*TimeWheel, error) { - if tick <= 0 || bucketsNum <= 0 { - return nil, errors.New("invalid params") + if bucketsNum <= 0 { + return nil, errors.New("bucket number must be greater than 0") + } + if int(tick.Seconds()) < 1 { + return nil, errors.New("tick cannot be less than 1s") } tw := &TimeWheel{ @@ -125,6 +128,9 @@ func (tw *TimeWheel) add(task *Task) { round := tw.calculateRound(task.delay) index := tw.calculateIndex(task.delay) task.round = round + if originIndex, ok := tw.bucketIndexes[task.key]; ok { + delete(tw.buckets[originIndex], task.key) + } tw.bucketIndexes[task.key] = index tw.buckets[index][task.key] = task } diff --git a/util/time_wheel_test.go b/util/time_wheel_test.go index 7b8ef13..b59a73d 100644 --- a/util/time_wheel_test.go +++ b/util/time_wheel_test.go @@ -15,19 +15,26 @@ package util import ( - "fmt" "strconv" + "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/assert" ) type A struct { - a int - b string + a int + b string + isCallbacked int32 +} + +func (a *A) callback() { + atomic.StoreInt32(&a.isCallbacked, 1) } -func callback() { - fmt.Println("timeout") +func (a *A) getCallbackValue() int32 { + return atomic.LoadInt32(&a.isCallbacked) } func newTimeWheel() *TimeWheel { @@ -39,33 +46,75 @@ func newTimeWheel() *TimeWheel { return tw } +func TestNewTimeWheel(t *testing.T) { + tests := []struct { + name string + tick time.Duration + bucketNum int + hasErr bool + }{ + {tick: time.Second, bucketNum: 0, hasErr: true}, + {tick: time.Millisecond, bucketNum: 1, hasErr: true}, + {tick: time.Second, bucketNum: 1, hasErr: false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := NewTimeWheel(test.tick, test.bucketNum) + assert.Equal(t, test.hasErr, err != nil) + }) + } +} + func TestAdd(t *testing.T) { tw := newTimeWheel() - err := tw.Add(time.Second*1, "test", callback) - if err != nil { - t.Fatalf("test add failed, %v", err) + a := &A{} + err := tw.Add(time.Second*1, "test", a.callback) + assert.NoError(t, err) + + time.Sleep(time.Millisecond * 500) + assert.Equal(t, int32(0), a.getCallbackValue()) + time.Sleep(time.Second * 2) + assert.Equal(t, int32(1), a.getCallbackValue()) + tw.Stop() +} + +func TestAddMultipleTimes(t *testing.T) { + a := &A{} + tw := newTimeWheel() + for i := 0; i < 4; i++ { + err := tw.Add(time.Second, "test", a.callback) + assert.NoError(t, err) + time.Sleep(time.Millisecond * 500) + t.Logf("current: %d", i) + assert.Equal(t, int32(0), a.getCallbackValue()) } - time.Sleep(time.Second * 5) + + time.Sleep(time.Second * 2) + assert.Equal(t, int32(1), a.getCallbackValue()) tw.Stop() } func TestRemove(t *testing.T) { a := &A{a: 10, b: "test"} tw := newTimeWheel() - err := tw.Add(time.Second*1, a, callback) - if err != nil { - t.Fatalf("test add failed, %v", err) - } - tw.Remove(a) - time.Sleep(time.Second * 5) + err := tw.Add(time.Second*1, a, a.callback) + assert.NoError(t, err) + + time.Sleep(time.Millisecond * 500) + assert.Equal(t, int32(0), a.getCallbackValue()) + err = tw.Remove(a) + assert.NoError(t, err) + time.Sleep(time.Second * 2) + assert.Equal(t, int32(0), a.getCallbackValue()) tw.Stop() } func BenchmarkAdd(b *testing.B) { + a := &A{} tw := newTimeWheel() for i := 0; i < b.N; i++ { key := "test" + strconv.Itoa(i) - err := tw.Add(time.Second, key, callback) + err := tw.Add(time.Second, key, a.callback) if err != nil { b.Fatalf("benchmark Add failed, %v", err) }