diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index b18db7c0798..24a75012331 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -549,9 +549,9 @@ func (c *Cluster) StartBackgroundJobs() { go c.runUpdateStoreStats() go c.runCoordinator() go c.runMetricsCollectionJob() - c.heartbeatRunner.Start() - c.miscRunner.Start() - c.logRunner.Start() + c.heartbeatRunner.Start(c.ctx) + c.miscRunner.Start(c.ctx) + c.logRunner.Start(c.ctx) c.running.Store(true) } diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go index 2d88e36106e..57a19e4e682 100644 --- a/pkg/ratelimit/runner.go +++ b/pkg/ratelimit/runner.go @@ -43,7 +43,7 @@ const ( // Runner is the interface for running tasks. type Runner interface { RunTask(id uint64, name string, f func(), opts ...TaskOption) error - Start() + Start(ctx context.Context) Stop() } @@ -66,12 +66,13 @@ type taskID struct { } type ConcurrentRunner struct { + ctx context.Context + cancel context.CancelFunc name string limiter *ConcurrencyLimiter maxPendingDuration time.Duration taskChan chan *Task pendingMu sync.Mutex - stopChan chan struct{} wg sync.WaitGroup pendingTaskCount map[string]int pendingTasks []*Task @@ -103,8 +104,8 @@ func WithRetained(retained bool) TaskOption { } // Start starts the runner. -func (cr *ConcurrentRunner) Start() { - cr.stopChan = make(chan struct{}) +func (cr *ConcurrentRunner) Start(ctx context.Context) { + cr.ctx, cr.cancel = context.WithCancel(ctx) cr.wg.Add(1) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -118,11 +119,11 @@ func (cr *ConcurrentRunner) Start() { if err != nil { continue } - go cr.run(task, token) + go cr.run(cr.ctx, task, token) } else { - go cr.run(task, nil) + go cr.run(cr.ctx, task, nil) } - case <-cr.stopChan: + case <-cr.ctx.Done(): cr.pendingMu.Lock() cr.pendingTasks = make([]*Task, 0, initialCapacity) cr.pendingMu.Unlock() @@ -144,8 +145,13 @@ func (cr *ConcurrentRunner) Start() { }() } -func (cr *ConcurrentRunner) run(task *Task, token *TaskToken) { +func (cr *ConcurrentRunner) run(ctx context.Context, task *Task, token *TaskToken) { start := time.Now() + select { + case <-ctx.Done(): + return + default: + } task.f() if token != nil { cr.limiter.ReleaseToken(token) @@ -173,7 +179,7 @@ func (cr *ConcurrentRunner) processPendingTasks() { // Stop stops the runner. func (cr *ConcurrentRunner) Stop() { - close(cr.stopChan) + cr.cancel() cr.wg.Wait() } @@ -238,7 +244,7 @@ func (*SyncRunner) RunTask(_ uint64, _ string, f func(), _ ...TaskOption) error } // Start starts the runner. -func (*SyncRunner) Start() {} +func (*SyncRunner) Start(context.Context) {} // Stop stops the runner. func (*SyncRunner) Stop() {} diff --git a/pkg/ratelimit/runner_test.go b/pkg/ratelimit/runner_test.go index 0335a78bcbe..d4aa0825e83 100644 --- a/pkg/ratelimit/runner_test.go +++ b/pkg/ratelimit/runner_test.go @@ -15,6 +15,7 @@ package ratelimit import ( + "context" "sync" "testing" "time" @@ -25,7 +26,7 @@ import ( func TestConcurrentRunner(t *testing.T) { t.Run("RunTask", func(t *testing.T) { runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), time.Second) - runner.Start() + runner.Start(context.TODO()) defer runner.Stop() var wg sync.WaitGroup @@ -47,7 +48,7 @@ func TestConcurrentRunner(t *testing.T) { t.Run("MaxPendingDuration", func(t *testing.T) { runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), 2*time.Millisecond) - runner.Start() + runner.Start(context.TODO()) defer runner.Stop() var wg sync.WaitGroup for i := 0; i < 10; i++ { @@ -76,7 +77,7 @@ func TestConcurrentRunner(t *testing.T) { t.Run("DuplicatedTask", func(t *testing.T) { runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), time.Minute) - runner.Start() + runner.Start(context.TODO()) defer runner.Stop() for i := 1; i < 11; i++ { regionID := uint64(i) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 93be9d1c076..ed1080f617a 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -364,9 +364,9 @@ func (c *RaftCluster) Start(s Server) error { go c.startGCTuner() c.running = true - c.heartbeatRunner.Start() - c.miscRunner.Start() - c.logRunner.Start() + c.heartbeatRunner.Start(c.ctx) + c.miscRunner.Start(c.ctx) + c.logRunner.Start(c.ctx) return nil }