From ae6efcde5015992df5af39e7be68ea8a1c2b58aa Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 23 Jul 2024 11:15:00 +0800 Subject: [PATCH] pass context to task Signed-off-by: Ryan Leung --- pkg/cluster/cluster.go | 16 ++++++++++++++-- pkg/core/region.go | 5 +++-- pkg/mcs/scheduling/server/cluster.go | 18 ++++++++---------- pkg/ratelimit/runner.go | 12 ++++++------ pkg/ratelimit/runner_test.go | 6 +++--- server/cluster/cluster.go | 20 +++++++++----------- 6 files changed, 43 insertions(+), 34 deletions(-) diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 2cf5787646a..ddba8f89fb6 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -15,6 +15,8 @@ package cluster import ( + "context" + "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/schedule" "github.com/tikv/pd/pkg/schedule/placement" @@ -56,8 +58,13 @@ func HandleStatsAsync(c Cluster, region *core.RegionInfo) { } // HandleOverlaps handles the overlap regions. -func HandleOverlaps(c Cluster, overlaps []*core.RegionInfo) { +func HandleOverlaps(ctx context.Context, c Cluster, overlaps []*core.RegionInfo) { for _, item := range overlaps { + select { + case <-ctx.Done(): + return + default: + } if c.GetRegionStats() != nil { c.GetRegionStats().ClearDefunctRegion(item.GetID()) } @@ -67,7 +74,7 @@ func HandleOverlaps(c Cluster, overlaps []*core.RegionInfo) { } // Collect collects the cluster information. -func Collect(c Cluster, region *core.RegionInfo, hasRegionStats bool) { +func Collect(ctx context.Context, c Cluster, region *core.RegionInfo, hasRegionStats bool) { if hasRegionStats { // get region again from root tree. make sure the observed region is the latest. bc := c.GetBasicCluster() @@ -78,6 +85,11 @@ func Collect(c Cluster, region *core.RegionInfo, hasRegionStats bool) { if region == nil { return } + select { + case <-ctx.Done(): + return + default: + } c.GetRegionStats().Observe(region, c.GetBasicCluster().GetRegionStores(region)) } } diff --git a/pkg/core/region.go b/pkg/core/region.go index eb8b89aecff..4f7af8cc333 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -16,6 +16,7 @@ package core import ( "bytes" + "context" "encoding/hex" "fmt" "math" @@ -750,7 +751,7 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { logRunner.RunTask( regionID, "DebugLog", - func() { + func(context.Context) { d(msg, fields...) }, ) @@ -759,7 +760,7 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { logRunner.RunTask( regionID, "InfoLog", - func() { + func(context.Context) { i(msg, fields...) }, ) diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index 24a75012331..c86c739f724 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -627,10 +627,8 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c ctx.TaskRunner.RunTask( regionID, ratelimit.ObserveRegionStatsAsync, - func() { - if c.regionStats.RegionStatsNeedUpdate(region) { - cluster.Collect(c, region, hasRegionStats) - } + func(ctx context.Context) { + cluster.Collect(ctx, c, region, hasRegionStats) }, ) } @@ -639,7 +637,7 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c ctx.TaskRunner.RunTask( regionID, ratelimit.UpdateSubTree, - func() { + func(context.Context) { c.CheckAndPutSubTree(region) }, ratelimit.WithRetained(true), @@ -663,7 +661,7 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c ctx.TaskRunner.RunTask( regionID, ratelimit.UpdateSubTree, - func() { + func(context.Context) { c.CheckAndPutSubTree(region) }, ratelimit.WithRetained(retained), @@ -672,8 +670,8 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c ctx.TaskRunner.RunTask( regionID, ratelimit.HandleOverlaps, - func() { - cluster.HandleOverlaps(c, overlaps) + func(ctx context.Context) { + cluster.HandleOverlaps(ctx, c, overlaps) }, ) } @@ -682,8 +680,8 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c ctx.TaskRunner.RunTask( regionID, ratelimit.CollectRegionStatsAsync, - func() { - cluster.Collect(c, region, hasRegionStats) + func(ctx context.Context) { + cluster.Collect(ctx, c, region, hasRegionStats) }, ) tracer.OnCollectRegionStatsFinished() diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go index 57a19e4e682..a230177ac73 100644 --- a/pkg/ratelimit/runner.go +++ b/pkg/ratelimit/runner.go @@ -42,7 +42,7 @@ const ( // Runner is the interface for running tasks. type Runner interface { - RunTask(id uint64, name string, f func(), opts ...TaskOption) error + RunTask(id uint64, name string, f func(context.Context), opts ...TaskOption) error Start(ctx context.Context) Stop() } @@ -51,7 +51,7 @@ type Runner interface { type Task struct { id uint64 submittedAt time.Time - f func() + f func(context.Context) name string // retained indicates whether the task should be dropped if the task queue exceeds maxPendingDuration. retained bool @@ -152,7 +152,7 @@ func (cr *ConcurrentRunner) run(ctx context.Context, task *Task, token *TaskToke return default: } - task.f() + task.f(ctx) if token != nil { cr.limiter.ReleaseToken(token) cr.processPendingTasks() @@ -184,7 +184,7 @@ func (cr *ConcurrentRunner) Stop() { } // RunTask runs the task asynchronously. -func (cr *ConcurrentRunner) RunTask(id uint64, name string, f func(), opts ...TaskOption) error { +func (cr *ConcurrentRunner) RunTask(id uint64, name string, f func(context.Context), opts ...TaskOption) error { task := &Task{ id: id, name: name, @@ -238,8 +238,8 @@ func NewSyncRunner() *SyncRunner { } // RunTask runs the task synchronously. -func (*SyncRunner) RunTask(_ uint64, _ string, f func(), _ ...TaskOption) error { - f() +func (*SyncRunner) RunTask(_ uint64, _ string, f func(context.Context), _ ...TaskOption) error { + f(context.Background()) return nil } diff --git a/pkg/ratelimit/runner_test.go b/pkg/ratelimit/runner_test.go index d4aa0825e83..a9090804a08 100644 --- a/pkg/ratelimit/runner_test.go +++ b/pkg/ratelimit/runner_test.go @@ -36,7 +36,7 @@ func TestConcurrentRunner(t *testing.T) { err := runner.RunTask( uint64(i), "test1", - func() { + func(context.Context) { defer wg.Done() time.Sleep(100 * time.Millisecond) }, @@ -56,7 +56,7 @@ func TestConcurrentRunner(t *testing.T) { err := runner.RunTask( uint64(i), "test2", - func() { + func(context.Context) { defer wg.Done() time.Sleep(100 * time.Millisecond) }, @@ -87,7 +87,7 @@ func TestConcurrentRunner(t *testing.T) { err := runner.RunTask( regionID, "test3", - func() { + func(context.Context) { time.Sleep(time.Second) }, ) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index ed1080f617a..d1f89ca2128 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -1061,10 +1061,8 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio ctx.MiscRunner.RunTask( regionID, ratelimit.ObserveRegionStatsAsync, - func() { - if c.regionStats.RegionStatsNeedUpdate(region) { - cluster.Collect(c, region, hasRegionStats) - } + func(ctx context.Context) { + cluster.Collect(ctx, c, region, hasRegionStats) }, ) } @@ -1073,7 +1071,7 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio ctx.TaskRunner.RunTask( regionID, ratelimit.UpdateSubTree, - func() { + func(context.Context) { c.CheckAndPutSubTree(region) }, ratelimit.WithRetained(true), @@ -1101,7 +1099,7 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio ctx.TaskRunner.RunTask( regionID, ratelimit.UpdateSubTree, - func() { + func(context.Context) { c.CheckAndPutSubTree(region) }, ratelimit.WithRetained(retained), @@ -1112,8 +1110,8 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio ctx.MiscRunner.RunTask( regionID, ratelimit.HandleOverlaps, - func() { - cluster.HandleOverlaps(c, overlaps) + func(ctx context.Context) { + cluster.HandleOverlaps(ctx, c, overlaps) }, ) } @@ -1125,11 +1123,11 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio ctx.MiscRunner.RunTask( regionID, ratelimit.CollectRegionStatsAsync, - func() { + func(ctx context.Context) { // TODO: Due to the accuracy requirements of the API "/regions/check/xxx", // region stats needs to be collected in API mode. // We need to think of a better way to reduce this part of the cost in the future. - cluster.Collect(c, region, hasRegionStats) + cluster.Collect(ctx, c, region, hasRegionStats) }, ) @@ -1139,7 +1137,7 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio ctx.MiscRunner.RunTask( regionID, ratelimit.SaveRegionToKV, - func() { + func(context.Context) { // If there are concurrent heartbeats from the same region, the last write will win even if // writes to storage in the critical area. So don't use mutex to protect it. // Not successfully saved to storage is not fatal, it only leads to longer warm-up