From 85ba1d601a5d90ad21058a4fdba3a759169bfa15 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Sat, 11 May 2024 16:44:43 +0800 Subject: [PATCH] fix test Signed-off-by: Ryan Leung --- pkg/core/region.go | 2 +- pkg/ratelimit/runner.go | 43 +++++++++------------------- pkg/statistics/region_collection.go | 10 +++---- tests/server/cluster/cluster_test.go | 18 ++++++------ 4 files changed, 29 insertions(+), 44 deletions(-) diff --git a/pkg/core/region.go b/pkg/core/region.go index b1a5c58f2036..e51f7312a690 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -262,7 +262,7 @@ func RegionFromHeartbeat(heartbeat RegionHeartbeatRequest, opts ...RegionCreateO sort.Sort(peerSlice(region.pendingPeers)) classifyVoterAndLearner(®ion) - return region + return region // nolint:govet } // Inherit inherits the buckets and region size from the parent region if bucket enabled. diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go index e4ac4e90b9de..98b56321e154 100644 --- a/pkg/ratelimit/runner.go +++ b/pkg/ratelimit/runner.go @@ -44,13 +44,6 @@ type Runner interface { Stop() } -// TaskPool is a pool for tasks. -var TaskPool = &sync.Pool{ - New: func() any { - return &Task{} - }, -} - // Task is a task to run. type Task struct { Ctx context.Context @@ -59,11 +52,13 @@ type Task struct { submittedAt time.Time } -func NewTask(ctx context.Context, f func(context.Context), opts ...TaskOption) *Task { - task := TaskPool.Get().(*Task) - task.Ctx = ctx - task.f = f - task.Opts = TaskOpts{} +func NewTask(ctx context.Context, f func(context.Context), opts ...TaskOption) Task { + task := Task{ + Ctx: ctx, + f: f, + Opts: TaskOpts{}, + submittedAt: time.Now(), + } task.submittedAt = time.Now() for _, opt := range opts { opt(task.Opts) @@ -71,15 +66,6 @@ func NewTask(ctx context.Context, f func(context.Context), opts ...TaskOption) * return task } -// ReleaseTask releases the task. -func ReleaseTask(task *Task) { - task.Ctx = nil - task.Opts = TaskOpts{} - task.f = nil - task.submittedAt = time.Time{} - TaskPool.Put(task) -} - // ErrMaxWaitingTasksExceeded is returned when the number of waiting tasks exceeds the maximum. var ErrMaxWaitingTasksExceeded = errors.New("max waiting tasks exceeded") @@ -88,8 +74,8 @@ type ConcurrentRunner struct { name string limiter *ConcurrencyLimiter maxPendingDuration time.Duration - taskChan chan *Task - pendingTasks []*Task + taskChan chan Task + pendingTasks []Task pendingMu sync.Mutex stopChan chan struct{} wg sync.WaitGroup @@ -104,8 +90,8 @@ func NewConcurrentRunner(name string, limiter *ConcurrencyLimiter, maxPendingDur name: name, limiter: limiter, maxPendingDuration: maxPendingDuration, - taskChan: make(chan *Task), - pendingTasks: make([]*Task, 0, initialCapacity), + taskChan: make(chan Task), + pendingTasks: make([]Task, 0, initialCapacity), failedTaskCount: RunnerTaskFailedTasks.WithLabelValues(name), pendingTaskCount: make(map[string]int64), maxWaitingDuration: RunnerTaskMaxWaitingDuration.WithLabelValues(name), @@ -148,7 +134,7 @@ func (cr *ConcurrentRunner) Start() { } case <-cr.stopChan: cr.pendingMu.Lock() - cr.pendingTasks = make([]*Task, 0, initialCapacity) + cr.pendingTasks = make([]Task, 0, initialCapacity) cr.pendingMu.Unlock() log.Info("stopping async task runner", zap.String("name", cr.name)) return @@ -168,19 +154,18 @@ func (cr *ConcurrentRunner) Start() { }() } -func (cr *ConcurrentRunner) run(task *Task, token *TaskToken) { +func (cr *ConcurrentRunner) run(task Task, token *TaskToken) { task.f(task.Ctx) if token != nil { token.Release() cr.processPendingTasks() } - ReleaseTask(task) } func (cr *ConcurrentRunner) processPendingTasks() { cr.pendingMu.Lock() defer cr.pendingMu.Unlock() - for len(cr.pendingTasks) > 0 { + if len(cr.pendingTasks) > 0 { task := cr.pendingTasks[0] select { case cr.taskChan <- task: diff --git a/pkg/statistics/region_collection.go b/pkg/statistics/region_collection.go index ee30f33389c7..3dbffffa70e4 100644 --- a/pkg/statistics/region_collection.go +++ b/pkg/statistics/region_collection.go @@ -91,7 +91,7 @@ type RegionStatistics struct { syncutil.RWMutex rip RegionInfoProvider conf sc.CheckerConfigProvider - stats map[RegionStatisticType]map[uint64]*RegionInfoWithTS + stats map[RegionStatisticType]map[uint64]RegionInfoWithTS index map[uint64]RegionStatisticType ruleManager *placement.RuleManager } @@ -106,11 +106,11 @@ func NewRegionStatistics( rip: rip, conf: conf, ruleManager: ruleManager, - stats: make(map[RegionStatisticType]map[uint64]*RegionInfoWithTS), + stats: make(map[RegionStatisticType]map[uint64]RegionInfoWithTS), index: make(map[uint64]RegionStatisticType), } for _, typ := range regionStatisticTypes { - r.stats[typ] = make(map[uint64]*RegionInfoWithTS) + r.stats[typ] = make(map[uint64]RegionInfoWithTS) } return r } @@ -250,8 +250,8 @@ func (r *RegionStatistics) Observe(region *core.RegionInfo, stores []*core.Store for typ, c := range conditions { if c { info := r.stats[typ][regionID] - if info == nil { - info = &RegionInfoWithTS{id: regionID} + if info == (RegionInfoWithTS{}) { + info = RegionInfoWithTS{id: regionID} } if typ == DownPeer { if info.startDownPeerTS != 0 { diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index aea5ff739683..08899a5c9018 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -1336,37 +1336,37 @@ func TestStaleTermHeartbeat(t *testing.T) { } region := core.RegionFromHeartbeat(regionReq) - err = rc.HandleRegionHeartbeat(region) + err = rc.HandleRegionHeartbeat(®ion) re.NoError(err) // Transfer leader regionReq.Term = 6 regionReq.Leader = peers[1] - region = core.RegionFromHeartbeat(regionReq) - err = rc.HandleRegionHeartbeat(region) + region1 := core.RegionFromHeartbeat(regionReq) + err = rc.HandleRegionHeartbeat(®ion1) re.NoError(err) // issue #3379 regionReq.KeysWritten = uint64(18446744073709551615) // -1 regionReq.BytesWritten = uint64(18446744073709550602) // -1024 - region = core.RegionFromHeartbeat(regionReq) + region2 := core.RegionFromHeartbeat(regionReq) re.Equal(uint64(0), region.GetKeysWritten()) re.Equal(uint64(0), region.GetBytesWritten()) - err = rc.HandleRegionHeartbeat(region) + err = rc.HandleRegionHeartbeat(®ion2) re.NoError(err) // Stale heartbeat, update check should fail regionReq.Term = 5 regionReq.Leader = peers[0] - region = core.RegionFromHeartbeat(regionReq) - err = rc.HandleRegionHeartbeat(region) + region3 := core.RegionFromHeartbeat(regionReq) + err = rc.HandleRegionHeartbeat(®ion3) re.Error(err) // Allow regions that are created by unsafe recover to send a heartbeat, even though they // are considered "stale" because their conf ver and version are both equal to 1. regionReq.Region.RegionEpoch.ConfVer = 1 - region = core.RegionFromHeartbeat(regionReq) - err = rc.HandleRegionHeartbeat(region) + region4 := core.RegionFromHeartbeat(regionReq) + err = rc.HandleRegionHeartbeat(®ion4) re.NoError(err) }