Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Leung <[email protected]>
  • Loading branch information
rleungx committed May 11, 2024
1 parent b9b2971 commit 85ba1d6
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 44 deletions.
2 changes: 1 addition & 1 deletion pkg/core/region.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func RegionFromHeartbeat(heartbeat RegionHeartbeatRequest, opts ...RegionCreateO
sort.Sort(peerSlice(region.pendingPeers))

classifyVoterAndLearner(&region)
return region
return region // nolint:govet
}

// Inherit inherits the buckets and region size from the parent region if bucket enabled.
Expand Down
43 changes: 14 additions & 29 deletions pkg/ratelimit/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,6 @@ type Runner interface {
Stop()
}

// TaskPool is a pool for tasks.
var TaskPool = &sync.Pool{
New: func() any {
return &Task{}
},
}

// Task is a task to run.
type Task struct {
Ctx context.Context
Expand All @@ -59,27 +52,20 @@ type Task struct {
submittedAt time.Time
}

func NewTask(ctx context.Context, f func(context.Context), opts ...TaskOption) *Task {
task := TaskPool.Get().(*Task)
task.Ctx = ctx
task.f = f
task.Opts = TaskOpts{}
func NewTask(ctx context.Context, f func(context.Context), opts ...TaskOption) Task {
task := Task{
Ctx: ctx,
f: f,
Opts: TaskOpts{},
submittedAt: time.Now(),
}
task.submittedAt = time.Now()
for _, opt := range opts {
opt(task.Opts)
}
return task
}

// ReleaseTask releases the task.
func ReleaseTask(task *Task) {
task.Ctx = nil
task.Opts = TaskOpts{}
task.f = nil
task.submittedAt = time.Time{}
TaskPool.Put(task)
}

// ErrMaxWaitingTasksExceeded is returned when the number of waiting tasks exceeds the maximum.
var ErrMaxWaitingTasksExceeded = errors.New("max waiting tasks exceeded")

Expand All @@ -88,8 +74,8 @@ type ConcurrentRunner struct {
name string
limiter *ConcurrencyLimiter
maxPendingDuration time.Duration
taskChan chan *Task
pendingTasks []*Task
taskChan chan Task
pendingTasks []Task
pendingMu sync.Mutex
stopChan chan struct{}
wg sync.WaitGroup
Expand All @@ -104,8 +90,8 @@ func NewConcurrentRunner(name string, limiter *ConcurrencyLimiter, maxPendingDur
name: name,
limiter: limiter,
maxPendingDuration: maxPendingDuration,
taskChan: make(chan *Task),
pendingTasks: make([]*Task, 0, initialCapacity),
taskChan: make(chan Task),
pendingTasks: make([]Task, 0, initialCapacity),
failedTaskCount: RunnerTaskFailedTasks.WithLabelValues(name),
pendingTaskCount: make(map[string]int64),
maxWaitingDuration: RunnerTaskMaxWaitingDuration.WithLabelValues(name),
Expand Down Expand Up @@ -148,7 +134,7 @@ func (cr *ConcurrentRunner) Start() {
}
case <-cr.stopChan:
cr.pendingMu.Lock()
cr.pendingTasks = make([]*Task, 0, initialCapacity)
cr.pendingTasks = make([]Task, 0, initialCapacity)
cr.pendingMu.Unlock()
log.Info("stopping async task runner", zap.String("name", cr.name))
return
Expand All @@ -168,19 +154,18 @@ func (cr *ConcurrentRunner) Start() {
}()
}

func (cr *ConcurrentRunner) run(task *Task, token *TaskToken) {
func (cr *ConcurrentRunner) run(task Task, token *TaskToken) {
task.f(task.Ctx)
if token != nil {
token.Release()
cr.processPendingTasks()
}
ReleaseTask(task)
}

func (cr *ConcurrentRunner) processPendingTasks() {
cr.pendingMu.Lock()
defer cr.pendingMu.Unlock()
for len(cr.pendingTasks) > 0 {
if len(cr.pendingTasks) > 0 {
task := cr.pendingTasks[0]
select {
case cr.taskChan <- task:
Expand Down
10 changes: 5 additions & 5 deletions pkg/statistics/region_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ type RegionStatistics struct {
syncutil.RWMutex
rip RegionInfoProvider
conf sc.CheckerConfigProvider
stats map[RegionStatisticType]map[uint64]*RegionInfoWithTS
stats map[RegionStatisticType]map[uint64]RegionInfoWithTS
index map[uint64]RegionStatisticType
ruleManager *placement.RuleManager
}
Expand All @@ -106,11 +106,11 @@ func NewRegionStatistics(
rip: rip,
conf: conf,
ruleManager: ruleManager,
stats: make(map[RegionStatisticType]map[uint64]*RegionInfoWithTS),
stats: make(map[RegionStatisticType]map[uint64]RegionInfoWithTS),
index: make(map[uint64]RegionStatisticType),
}
for _, typ := range regionStatisticTypes {
r.stats[typ] = make(map[uint64]*RegionInfoWithTS)
r.stats[typ] = make(map[uint64]RegionInfoWithTS)
}
return r
}
Expand Down Expand Up @@ -250,8 +250,8 @@ func (r *RegionStatistics) Observe(region *core.RegionInfo, stores []*core.Store
for typ, c := range conditions {
if c {
info := r.stats[typ][regionID]
if info == nil {
info = &RegionInfoWithTS{id: regionID}
if info == (RegionInfoWithTS{}) {
info = RegionInfoWithTS{id: regionID}
}
if typ == DownPeer {
if info.startDownPeerTS != 0 {
Expand Down
18 changes: 9 additions & 9 deletions tests/server/cluster/cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1336,37 +1336,37 @@ func TestStaleTermHeartbeat(t *testing.T) {
}

region := core.RegionFromHeartbeat(regionReq)
err = rc.HandleRegionHeartbeat(region)
err = rc.HandleRegionHeartbeat(&region)
re.NoError(err)

// Transfer leader
regionReq.Term = 6
regionReq.Leader = peers[1]
region = core.RegionFromHeartbeat(regionReq)
err = rc.HandleRegionHeartbeat(region)
region1 := core.RegionFromHeartbeat(regionReq)
err = rc.HandleRegionHeartbeat(&region1)
re.NoError(err)

// issue #3379
regionReq.KeysWritten = uint64(18446744073709551615) // -1
regionReq.BytesWritten = uint64(18446744073709550602) // -1024
region = core.RegionFromHeartbeat(regionReq)
region2 := core.RegionFromHeartbeat(regionReq)
re.Equal(uint64(0), region.GetKeysWritten())
re.Equal(uint64(0), region.GetBytesWritten())
err = rc.HandleRegionHeartbeat(region)
err = rc.HandleRegionHeartbeat(&region2)
re.NoError(err)

// Stale heartbeat, update check should fail
regionReq.Term = 5
regionReq.Leader = peers[0]
region = core.RegionFromHeartbeat(regionReq)
err = rc.HandleRegionHeartbeat(region)
region3 := core.RegionFromHeartbeat(regionReq)
err = rc.HandleRegionHeartbeat(&region3)
re.Error(err)

// Allow regions that are created by unsafe recover to send a heartbeat, even though they
// are considered "stale" because their conf ver and version are both equal to 1.
regionReq.Region.RegionEpoch.ConfVer = 1
region = core.RegionFromHeartbeat(regionReq)
err = rc.HandleRegionHeartbeat(region)
region4 := core.RegionFromHeartbeat(regionReq)
err = rc.HandleRegionHeartbeat(&region4)
re.NoError(err)
}

Expand Down

0 comments on commit 85ba1d6

Please sign in to comment.