diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 916200bfa3ec..cf11c866a42e 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -28,6 +28,7 @@ type Cluster interface { GetLabelStats() *statistics.LabelStatistics GetCoordinator() *schedule.Coordinator GetRuleManager() *placement.RuleManager + GetBasicCluster() *core.BasicCluster } // HandleStatsAsync handles the flow asynchronously. @@ -55,8 +56,14 @@ func HandleOverlaps(c Cluster, overlaps []*core.RegionInfo) { } // Collect collects the cluster information. -func Collect(c Cluster, region *core.RegionInfo, stores []*core.StoreInfo, hasRegionStats bool) { +func Collect(c Cluster, region *core.RegionInfo, hasRegionStats bool) { if hasRegionStats { + // Get the region again from the root tree to make sure the observed region is the latest. + region = c.GetBasicCluster().GetRegion(region.GetID()) + // The region may have been removed from the root tree concurrently; skip observing in that case. + if region == nil { + return + } + c.GetRegionStats().Observe(region, c.GetBasicCluster().GetRegionStores(region)) } } diff --git a/pkg/core/region.go b/pkg/core/region.go index f7a4ef5f0fd5..2da138cd8262 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -16,6 +16,7 @@ package core import ( "bytes" + "context" "encoding/hex" "fmt" "math" @@ -35,6 +36,8 @@ import ( "github.com/pingcap/kvproto/pkg/replication_modepb" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/ratelimit" + "github.com/tikv/pd/pkg/utils/ctxutil" "github.com/tikv/pd/pkg/utils/logutil" "github.com/tikv/pd/pkg/utils/syncutil" "github.com/tikv/pd/pkg/utils/typeutil" @@ -711,20 +714,51 @@ func (r *RegionInfo) isRegionRecreated() bool { // RegionGuideFunc is a function that determines which follow-up operations need to be performed based on the origin // and new region information.
-type RegionGuideFunc func(region, origin *RegionInfo) (saveKV, saveCache, needSync bool) +type RegionGuideFunc func(ctx context.Context, region, origin *RegionInfo) (saveKV, saveCache, needSync bool) // GenerateRegionGuideFunc is used to generate a RegionGuideFunc. Control the log output by specifying the log function. // nil means do not print the log. func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { noLog := func(msg string, fields ...zap.Field) {} - debug, info := noLog, noLog + d, i := noLog, noLog + debug, info := d, i if enableLog { + d = log.Debug + i = log.Info + debug, info = d, i } // Save to storage if meta is updated. // Save to cache if meta or leader is updated, or contains any down/pending peer. - return func(region, origin *RegionInfo) (saveKV, saveCache, needSync bool) { + return func(ctx context.Context, region, origin *RegionInfo) (saveKV, saveCache, needSync bool) { + taskRunner, ok := ctx.Value(ctxutil.TaskRunnerKey).(ratelimit.Runner) + limiter, _ := ctx.Value(ctxutil.LimiterKey).(*ratelimit.ConcurrencyLimiter) + // Print logs asynchronously via the task runner when one is provided in the context. + if ok { + debug = func(msg string, fields ...zap.Field) { + taskRunner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "Log", + Limit: limiter, + }, + func(ctx context.Context) { + d(msg, fields...) + }, + ) + } + info = func(msg string, fields ...zap.Field) { + taskRunner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "Log", + Limit: limiter, + }, + func(ctx context.Context) { + i(msg, fields...)
+ }, + ) + } + } if origin == nil { if log.GetLevel() <= zap.DebugLevel { debug("insert new region", @@ -789,7 +823,7 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { } if !SortedPeersStatsEqual(region.GetDownPeers(), origin.GetDownPeers()) { if log.GetLevel() <= zap.DebugLevel { - debug("down-peers changed", zap.Uint64("region-id", region.GetID())) + debug("down-peers changed", zap.Uint64("region-id", region.GetID()), zap.Reflect("before", origin.GetDownPeers()), zap.Reflect("after", region.GetDownPeers())) } saveCache, needSync = true, true return @@ -912,7 +946,7 @@ func (r *RegionsInfo) CheckAndPutRegion(region *RegionInfo) []*RegionInfo { if origin == nil || !bytes.Equal(origin.GetStartKey(), region.GetStartKey()) || !bytes.Equal(origin.GetEndKey(), region.GetEndKey()) { ols = r.tree.overlaps(®ionItem{RegionInfo: region}) } - err := check(region, origin, ols) + err := check(region, origin, convertItemsToRegions(ols)) if err != nil { log.Debug("region is stale", zap.Stringer("origin", origin.GetMeta()), errs.ZapError(err)) // return the state region to delete. @@ -933,48 +967,116 @@ func (r *RegionsInfo) PutRegion(region *RegionInfo) []*RegionInfo { } // PreCheckPutRegion checks if the region is valid to put. -func (r *RegionsInfo) PreCheckPutRegion(region *RegionInfo, trace RegionHeartbeatProcessTracer) (*RegionInfo, []*regionItem, error) { - origin, overlaps := r.GetRelevantRegions(region, trace) +func (r *RegionsInfo) PreCheckPutRegion(region *RegionInfo) (*RegionInfo, []*RegionInfo, error) { + origin, overlaps := r.GetRelevantRegions(region) err := check(region, origin, overlaps) return origin, overlaps, err } +func convertItemsToRegions(items []*regionItem) []*RegionInfo { + regions := make([]*RegionInfo, 0, len(items)) + for _, item := range items { + regions = append(regions, item.RegionInfo) + } + return regions +} + // AtomicCheckAndPutRegion checks if the region is valid to put, if valid then put. 
-func (r *RegionsInfo) AtomicCheckAndPutRegion(region *RegionInfo, trace RegionHeartbeatProcessTracer) ([]*RegionInfo, error) { +func (r *RegionsInfo) AtomicCheckAndPutRegion(ctx context.Context, region *RegionInfo) ([]*RegionInfo, error) { + tracer, ok := ctx.Value(ctxutil.HeartbeatTracerKey).(RegionHeartbeatProcessTracer) + if !ok { + tracer = NewNoopHeartbeatProcessTracer() + } r.t.Lock() var ols []*regionItem origin := r.getRegionLocked(region.GetID()) if origin == nil || !bytes.Equal(origin.GetStartKey(), region.GetStartKey()) || !bytes.Equal(origin.GetEndKey(), region.GetEndKey()) { ols = r.tree.overlaps(&regionItem{RegionInfo: region}) } - trace.OnCheckOverlapsFinished() - err := check(region, origin, ols) + tracer.OnCheckOverlapsFinished() + err := check(region, origin, convertItemsToRegions(ols)) if err != nil { r.t.Unlock() - trace.OnValidateRegionFinished() + tracer.OnValidateRegionFinished() return nil, err } - trace.OnValidateRegionFinished() + tracer.OnValidateRegionFinished() origin, overlaps, rangeChanged := r.setRegionLocked(region, true, ols...) r.t.Unlock() - trace.OnSetRegionFinished() + tracer.OnSetRegionFinished() r.UpdateSubTree(region, origin, overlaps, rangeChanged) - trace.OnUpdateSubTreeFinished() + tracer.OnUpdateSubTreeFinished() + return overlaps, nil +} + +// CheckAndPutRootTree checks whether the region is valid to put into the root tree and puts it if so; otherwise an error is returned. +// Usually used together with CheckAndPutSubTree.
+func (r *RegionsInfo) CheckAndPutRootTree(ctx context.Context, region *RegionInfo) ([]*RegionInfo, error) { + tracer, ok := ctx.Value(ctxutil.HeartbeatTracerKey).(RegionHeartbeatProcessTracer) + if !ok { + tracer = NewNoopHeartbeatProcessTracer() + } + r.t.Lock() + var ols []*regionItem + origin := r.getRegionLocked(region.GetID()) + if origin == nil || !bytes.Equal(origin.GetStartKey(), region.GetStartKey()) || !bytes.Equal(origin.GetEndKey(), region.GetEndKey()) { + ols = r.tree.overlaps(®ionItem{RegionInfo: region}) + } + tracer.OnCheckOverlapsFinished() + err := check(region, origin, convertItemsToRegions(ols)) + if err != nil { + r.t.Unlock() + tracer.OnValidateRegionFinished() + return nil, err + } + tracer.OnValidateRegionFinished() + _, overlaps, _ := r.setRegionLocked(region, true, ols...) + r.t.Unlock() + tracer.OnSetRegionFinished() return overlaps, nil } +// CheckAndPutSubTree checks if the region is valid to put to the sub tree, if valid then return error. +// Usually used with AtomicCheckAndPutRootTree together. +func (r *RegionsInfo) CheckAndPutSubTree(ctx context.Context, region *RegionInfo) error { + var origin *RegionInfo + r.st.RLock() + // get origin from sub tree. + originItem, ok := r.subRegions[region.GetID()] + if ok { + origin = originItem.RegionInfo + } + r.st.RUnlock() + // new region get from root tree again + var ( + overlaps []*RegionInfo + newRegion *RegionInfo + ) + rangeChanged := true + if origin != nil { + newRegion, overlaps = r.GetRelevantRegions(origin) + rangeChanged = !origin.rangeEqualsTo(region) + } else { + newRegion = r.GetRegion(region.GetID()) + } + r.UpdateSubTree(newRegion, origin, overlaps, rangeChanged) + return nil +} + // GetRelevantRegions returns the relevant regions for a given region. 
-func (r *RegionsInfo) GetRelevantRegions(region *RegionInfo, trace RegionHeartbeatProcessTracer) (origin *RegionInfo, overlaps []*regionItem) { +func (r *RegionsInfo) GetRelevantRegions(region *RegionInfo) (origin *RegionInfo, overlaps []*RegionInfo) { r.t.RLock() defer r.t.RUnlock() origin = r.getRegionLocked(region.GetID()) if origin == nil || !bytes.Equal(origin.GetStartKey(), region.GetStartKey()) || !bytes.Equal(origin.GetEndKey(), region.GetEndKey()) { - overlaps = r.tree.overlaps(®ionItem{RegionInfo: region}) + for _, item := range r.tree.overlaps(®ionItem{RegionInfo: region}) { + overlaps = append(overlaps, item.RegionInfo) + } } return } -func check(region, origin *RegionInfo, overlaps []*regionItem) error { +func check(region, origin *RegionInfo, overlaps []*RegionInfo) error { for _, item := range overlaps { // PD ignores stale regions' heartbeats, unless it is recreated recently by unsafe recover operation. if region.GetRegionEpoch().GetVersion() < item.GetRegionEpoch().GetVersion() && !region.isRegionRecreated() { @@ -1065,6 +1167,9 @@ func (r *RegionsInfo) UpdateSubTree(region, origin *RegionInfo, overlaps []*Regi r.st.Lock() defer r.st.Unlock() if origin != nil { + // TODO: Check if the pointer address is consistent? + // eg &origin == &subtree[origin.GetID]. + // compare the pointer of subtree with the pointer of origin to ensure no changes have occurred before update subtree. if rangeChanged || !origin.peersEqualTo(region) { // If the range or peers have changed, the sub regionTree needs to be cleaned up. // TODO: Improve performance by deleting only the different peers. diff --git a/pkg/core/region_test.go b/pkg/core/region_test.go index 3c6536a6a773..5bb09eb52b07 100644 --- a/pkg/core/region_test.go +++ b/pkg/core/region_test.go @@ -15,6 +15,7 @@ package core import ( + "context" "crypto/rand" "fmt" "math" @@ -363,7 +364,7 @@ func TestNeedSync(t *testing.T) { for _, testCase := range testCases { regionA := region.Clone(testCase.optionsA...) 
regionB := region.Clone(testCase.optionsB...) - _, _, needSync := RegionGuide(regionA, regionB) + _, _, needSync := RegionGuide(context.TODO(), regionA, regionB) re.Equal(testCase.needSync, needSync) } } @@ -459,9 +460,9 @@ func TestSetRegionConcurrence(t *testing.T) { regions := NewRegionsInfo() region := NewTestRegionInfo(1, 1, []byte("a"), []byte("b")) go func() { - regions.AtomicCheckAndPutRegion(region, NewNoopHeartbeatProcessTracer()) + regions.AtomicCheckAndPutRegion(context.TODO(), region) }() - regions.AtomicCheckAndPutRegion(region, NewNoopHeartbeatProcessTracer()) + regions.AtomicCheckAndPutRegion(context.TODO(), region) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/core/UpdateSubTree")) } diff --git a/pkg/core/region_tree.go b/pkg/core/region_tree.go index 333e1730ec8a..5a633a2639c1 100644 --- a/pkg/core/region_tree.go +++ b/pkg/core/region_tree.go @@ -35,6 +35,10 @@ func (r *regionItem) GetStartKey() []byte { return r.meta.StartKey } +func (r *regionItem) GetID() uint64 { + return r.meta.GetId() +} + // GetEndKey returns the end key of the region. 
func (r *regionItem) GetEndKey() []byte { return r.meta.EndKey diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index 1b915b6874d2..d25ad0da6cbc 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -2,6 +2,7 @@ package server import ( "context" + "runtime" "sync" "sync/atomic" "time" @@ -15,6 +16,7 @@ import ( "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/mcs/scheduling/server/config" + "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/schedule" sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/schedule/hbstream" @@ -29,6 +31,7 @@ import ( "github.com/tikv/pd/pkg/statistics/buckets" "github.com/tikv/pd/pkg/statistics/utils" "github.com/tikv/pd/pkg/storage" + "github.com/tikv/pd/pkg/utils/ctxutil" "github.com/tikv/pd/pkg/utils/logutil" "go.uber.org/zap" ) @@ -51,6 +54,9 @@ type Cluster struct { apiServerLeader atomic.Value clusterID uint64 running atomic.Bool + + taskRunner ratelimit.Runner + hbConcurrencyLimiter *ratelimit.ConcurrencyLimiter } const ( @@ -81,6 +87,9 @@ func NewCluster(parentCtx context.Context, persistConfig *config.PersistConfig, storage: storage, clusterID: clusterID, checkMembershipCh: checkMembershipCh, + + taskRunner: ratelimit.NewAsyncRunner("heartbeat-async-task-runner", 1000000), + hbConcurrencyLimiter: ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU() * 2)), } c.coordinator = schedule.NewCoordinator(ctx, c, hbStreams) err = c.ruleManager.Initialize(persistConfig.GetMaxReplicas(), persistConfig.GetLocationLabels(), persistConfig.GetIsolationLevel()) @@ -536,6 +545,8 @@ func (c *Cluster) IsBackgroundJobsRunning() bool { return c.running.Load() } +var syncRunner = ratelimit.NewSyncRunner() + // HandleRegionHeartbeat processes RegionInfo reports from client. 
func (c *Cluster) HandleRegionHeartbeat(region *core.RegionInfo) error { tracer := core.NewNoopHeartbeatProcessTracer() @@ -543,7 +554,13 @@ func (c *Cluster) HandleRegionHeartbeat(region *core.RegionInfo) error { tracer = core.NewHeartbeatProcessTracer() } tracer.Begin() - if err := c.processRegionHeartbeat(region, tracer); err != nil { + ctx := context.WithValue(c.ctx, ctxutil.HeartbeatTracerKey, tracer) + ctx = context.WithValue(ctx, ctxutil.LimiterKey, c.hbConcurrencyLimiter) + if c.persistConfig.GetScheduleConfig().EnableHeartbeatAsyncRunner { + ctx = context.WithValue(ctx, ctxutil.TaskRunnerKey, c.taskRunner) + } + + if err := c.processRegionHeartbeat(ctx, region); err != nil { tracer.OnAllStageFinished() return err } @@ -553,26 +570,55 @@ func (c *Cluster) HandleRegionHeartbeat(region *core.RegionInfo) error { } // processRegionHeartbeat updates the region information. -func (c *Cluster) processRegionHeartbeat(region *core.RegionInfo, tracer core.RegionHeartbeatProcessTracer) error { - origin, _, err := c.PreCheckPutRegion(region, tracer) +func (c *Cluster) processRegionHeartbeat(ctx context.Context, region *core.RegionInfo) error { + tracer, ok := ctx.Value(ctxutil.HeartbeatTracerKey).(core.RegionHeartbeatProcessTracer) + if !ok { + tracer = core.NewNoopHeartbeatProcessTracer() + } + runner, ok := ctx.Value(ctxutil.TaskRunnerKey).(ratelimit.Runner) + if !ok { + runner = syncRunner + } + limiter, _ := ctx.Value(ctxutil.LimiterKey).(*ratelimit.ConcurrencyLimiter) + origin, _, err := c.PreCheckPutRegion(region) tracer.OnPreCheckFinished() if err != nil { return err } region.Inherit(origin, c.GetStoreConfig().IsEnableRegionBucket()) - cluster.HandleStatsAsync(c, region) + runner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "HandleStatsAsync", + Limit: limiter, + }, + func(ctx context.Context) { + cluster.HandleStatsAsync(c, region) + }, + ) tracer.OnAsyncHotStatsFinished() hasRegionStats := c.regionStats != nil // Save to storage if meta is updated, except 
for flashback. // Save to cache if meta or leader is updated, or contains any down/pending peer. - _, saveCache, _ := core.GenerateRegionGuideFunc(true)(region, origin) + _, saveCache, _ := core.GenerateRegionGuideFunc(true)(ctx, region, origin) if !saveCache { // Due to some config changes need to update the region stats as well, // so we do some extra checks here. if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { - c.regionStats.Observe(region, c.GetRegionStores(region)) + runner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "ObserveRegionStatsAsync", + Limit: limiter, + }, + func(ctx context.Context) { + if c.regionStats.RegionStatsNeedUpdate(region) { + c.regionStats.Observe(region, c.GetRegionStores(region)) + } + }, + ) } return nil } @@ -583,15 +629,25 @@ func (c *Cluster) processRegionHeartbeat(region *core.RegionInfo, tracer core.Re // check its validation again here. // // However, it can't solve the race condition of concurrent heartbeats from the same region. 
- if overlaps, err = c.AtomicCheckAndPutRegion(region, tracer); err != nil { + if overlaps, err = c.CheckAndPutRootTree(ctx, region); err != nil { tracer.OnSaveCacheFinished() return err } - + runner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "UpdateSubTree", + Limit: limiter, + }, + func(ctx context.Context) { + c.CheckAndPutSubTree(ctx, region) + }, + ) + tracer.OnUpdateSubTreeFinished() cluster.HandleOverlaps(c, overlaps) } tracer.OnSaveCacheFinished() - cluster.Collect(c, region, c.GetRegionStores(region), hasRegionStats) + cluster.Collect(c, region, hasRegionStats) tracer.OnCollectRegionStatsFinished() return nil } diff --git a/pkg/ratelimit/concurrency_limiter.go b/pkg/ratelimit/concurrency_limiter.go index b1eef3c8101b..2f5b38992148 100644 --- a/pkg/ratelimit/concurrency_limiter.go +++ b/pkg/ratelimit/concurrency_limiter.go @@ -14,24 +14,31 @@ package ratelimit -import "github.com/tikv/pd/pkg/utils/syncutil" +import ( + "context" -type concurrencyLimiter struct { - mu syncutil.RWMutex + "github.com/tikv/pd/pkg/utils/syncutil" +) + +type ConcurrencyLimiter struct { + mu syncutil.Mutex current uint64 limit uint64 + waiting uint64 // statistic maxLimit uint64 + queue chan *TaskToken } -func newConcurrencyLimiter(limit uint64) *concurrencyLimiter { - return &concurrencyLimiter{limit: limit} +// NewConcurrencyLimiter creates a new ConcurrencyLimiter. 
+func NewConcurrencyLimiter(limit uint64) *ConcurrencyLimiter { + return &ConcurrencyLimiter{limit: limit, queue: make(chan *TaskToken, limit)} } const unlimit = uint64(0) -func (l *concurrencyLimiter) allow() bool { +func (l *ConcurrencyLimiter) allow() bool { l.mu.Lock() defer l.mu.Unlock() @@ -45,7 +52,7 @@ func (l *concurrencyLimiter) allow() bool { return false } -func (l *concurrencyLimiter) release() { +func (l *ConcurrencyLimiter) release() { l.mu.Lock() defer l.mu.Unlock() @@ -54,28 +61,28 @@ func (l *concurrencyLimiter) release() { } } -func (l *concurrencyLimiter) getLimit() uint64 { - l.mu.RLock() - defer l.mu.RUnlock() +func (l *ConcurrencyLimiter) getLimit() uint64 { + l.mu.Lock() + defer l.mu.Unlock() return l.limit } -func (l *concurrencyLimiter) setLimit(limit uint64) { +func (l *ConcurrencyLimiter) setLimit(limit uint64) { l.mu.Lock() defer l.mu.Unlock() l.limit = limit } -func (l *concurrencyLimiter) getCurrent() uint64 { - l.mu.RLock() - defer l.mu.RUnlock() +func (l *ConcurrencyLimiter) getCurrent() uint64 { + l.mu.Lock() + defer l.mu.Unlock() return l.current } -func (l *concurrencyLimiter) getMaxConcurrency() uint64 { +func (l *ConcurrencyLimiter) getMaxConcurrency() uint64 { l.mu.Lock() defer func() { l.maxLimit = l.current @@ -84,3 +91,65 @@ func (l *concurrencyLimiter) getMaxConcurrency() uint64 { return l.maxLimit } + +// GetRunningTasks returns the number of running tasks. +func (l *ConcurrencyLimiter) GetRunningTasksNum() uint64 { + return l.getCurrent() +} + +// GetWaitingTasks returns the number of waiting tasks. +func (l *ConcurrencyLimiter) GetWaitingTasksNum() uint64 { + l.mu.Lock() + defer l.mu.Unlock() + return l.waiting +} + +// Acquire acquires a token from the limiter. which will block until a token is available or ctx is done, like Timeout. 
+func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) { + l.mu.Lock() + if l.current >= l.limit { + l.waiting++ + l.mu.Unlock() + select { + case <-ctx.Done(): + l.mu.Lock() + l.waiting-- + l.mu.Unlock() + return nil, ctx.Err() + case token := <-l.queue: + l.mu.Lock() + token.released = false + l.current++ + l.waiting-- + l.mu.Unlock() + return token, nil + } + } + l.current++ + token := &TaskToken{limiter: l} + l.mu.Unlock() + return token, nil +} + +// TaskToken is a token that must be released after the task is done. +type TaskToken struct { + released bool + limiter *ConcurrencyLimiter +} + +// Release releases the token. +func (tt *TaskToken) Release() { + tt.limiter.mu.Lock() + defer tt.limiter.mu.Unlock() + if tt.released { + return + } + if tt.limiter.current == 0 { + panic("release token more than acquire") + } + tt.released = true + tt.limiter.current-- + if len(tt.limiter.queue) < int(tt.limiter.limit) { + tt.limiter.queue <- tt + } +} diff --git a/pkg/ratelimit/concurrency_limiter_test.go b/pkg/ratelimit/concurrency_limiter_test.go index 5fe03740394a..216da1ac8a02 100644 --- a/pkg/ratelimit/concurrency_limiter_test.go +++ b/pkg/ratelimit/concurrency_limiter_test.go @@ -15,7 +15,12 @@ package ratelimit import ( + "context" + "fmt" + "sync" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -23,7 +28,7 @@ import ( func TestConcurrencyLimiter(t *testing.T) { t.Parallel() re := require.New(t) - cl := newConcurrencyLimiter(10) + cl := NewConcurrencyLimiter(10) for i := 0; i < 10; i++ { re.True(cl.allow()) } @@ -52,3 +57,72 @@ func TestConcurrencyLimiter(t *testing.T) { re.Equal(uint64(5), cl.getMaxConcurrency()) re.Equal(uint64(0), cl.getMaxConcurrency()) } + +func TestConcurrencyLimiter2(t *testing.T) { + limit := uint64(2) + limiter := NewConcurrencyLimiter(limit) + + require.Equal(t, uint64(0), limiter.GetRunningTasksNum(), "Expected running tasks to be 0") + require.Equal(t, uint64(0), 
limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Acquire two tokens + token1, err := limiter.Acquire(ctx) + require.NoError(t, err, "Failed to acquire token") + + token2, err := limiter.Acquire(ctx) + require.NoError(t, err, "Failed to acquire token") + + require.Equal(t, limit, limiter.GetRunningTasksNum(), "Expected running tasks to be 2") + + // Try to acquire third token, it should not be able to acquire immediately due to limit + go func() { + _, err := limiter.Acquire(ctx) + require.NoError(t, err, "Failed to acquire token") + }() + + time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run + require.Equal(t, uint64(1), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 1") + + // Release a token + token1.Release() + time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run + require.Equal(t, uint64(2), limiter.GetRunningTasksNum(), "Expected running tasks to be 2") + require.Equal(t, uint64(0), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0") + + // Release the second token + token2.Release() + time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run + require.Equal(t, uint64(1), limiter.GetRunningTasksNum(), "Expected running tasks to be 1") +} + +func TestConcurrencyLimiterAcquire(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + limiter := NewConcurrencyLimiter(20) + sum := int64(0) + start := time.Now() + wg := &sync.WaitGroup{} + wg.Add(100) + for i := 0; i < 100; i++ { + go func(i int) { + defer wg.Done() + token, err := limiter.Acquire(ctx) + if err != nil { + fmt.Printf("Task %d failed to acquire: %v\n", i, err) + return + } + defer token.Release() + // simulate takes some time + time.Sleep(10 * time.Millisecond) + atomic.AddInt64(&sum, 1) + }(i) + } + wg.Wait() + // We should have 20 tasks running 
concurrently, so it should take at least 50ms to complete + require.Greater(t, time.Since(start).Milliseconds(), int64(50)) + require.Equal(t, int64(100), sum) +} diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index dc744d9ac1b0..7b3eba10325a 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -36,18 +36,18 @@ type DimensionConfig struct { type limiter struct { mu syncutil.RWMutex - concurrency *concurrencyLimiter + concurrency *ConcurrencyLimiter rate *RateLimiter } func newLimiter() *limiter { lim := &limiter{ - concurrency: newConcurrencyLimiter(0), + concurrency: NewConcurrencyLimiter(0), } return lim } -func (l *limiter) getConcurrencyLimiter() *concurrencyLimiter { +func (l *limiter) getConcurrencyLimiter() *ConcurrencyLimiter { l.mu.RLock() defer l.mu.RUnlock() return l.concurrency @@ -101,7 +101,7 @@ func (l *limiter) updateConcurrencyConfig(limit uint64) UpdateStatus { } l.concurrency.setLimit(limit) } else { - l.concurrency = newConcurrencyLimiter(limit) + l.concurrency = NewConcurrencyLimiter(limit) } return ConcurrencyChanged } diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go new file mode 100644 index 000000000000..a32a50960f0d --- /dev/null +++ b/pkg/ratelimit/runner.go @@ -0,0 +1,114 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ratelimit + +import ( + "context" + "errors" + "sync/atomic" + + "github.com/pingcap/log" + "go.uber.org/zap" +) + +// Runner is the interface for running tasks. +type Runner interface { + RunTask(ctx context.Context, opt TaskOpts, f func(context.Context)) error +} + +// ErrMaxWaitingTasksExceeded is returned when the number of waiting tasks exceeds the maximum. +var ErrMaxWaitingTasksExceeded = errors.New("max waiting tasks exceeded") + +// AsyncRunner is a simple task runner that limits the number of concurrent tasks. +type AsyncRunner struct { + numTasks int64 + maxPendingTasks uint64 + name string +} + +// NewAsyncRunner creates a new AsyncRunner. +func NewAsyncRunner(name string, maxPendingTasks uint64) *AsyncRunner { + return &AsyncRunner{name: name, maxPendingTasks: maxPendingTasks} +} + +// TaskOpts is the options for RunTask. +type TaskOpts struct { + // TaskName is a human-readable name for the operation. TODO: metrics by name. + TaskName string + Limit *ConcurrencyLimiter +} + +// RunTask runs the task asynchronously: the callback f is executed in a goroutine. +// The call doesn't block for the callback to finish execution. +func (s *AsyncRunner) RunTask(ctx context.Context, opt TaskOpts, f func(context.Context)) error { + if opt.Limit != nil && atomic.LoadInt64(&s.numTasks) >= int64(s.maxPendingTasks) { + return ErrMaxWaitingTasksExceeded + } + s.addTask(1) + go func(ctx context.Context, opt TaskOpts) { + // Register cleanup first so the pending-task counter is decremented + // and panics are recovered on every exit path, including the early + // returns below. + defer s.addTask(-1) + defer s.recover() + if opt.Limit != nil { + // Wait for permission to run from the semaphore. + token, err := opt.Limit.Acquire(ctx) + if err != nil { + log.Error("failed to acquire semaphore", zap.String("task-name", opt.TaskName), zap.Error(err)) + return + } + // Release the token on every exit path once it has been acquired. + defer token.Release() + + // Check for canceled context: it's possible to get the semaphore even + // if the context is canceled.
+ if ctx.Err() != nil { + log.Error("context is canceled", zap.String("task-name", opt.TaskName)) + return + } + } + f(ctx) + }(ctx, opt) + + return nil +} + +func (s *AsyncRunner) recover() { + if r := recover(); r != nil { + log.Error("panic in runner", zap.Any("error", r)) + return + } +} + +func (s *AsyncRunner) addTask(delta int64) (updated int64) { + return atomic.AddInt64(&s.numTasks, delta) +} + +// SyncRunner is a trivial Runner that executes tasks synchronously in the caller's goroutine. +type SyncRunner struct{} + +// NewSyncRunner creates a new SyncRunner. +func NewSyncRunner() *SyncRunner { + return &SyncRunner{} +} + +// RunTask runs the task synchronously. +func (s *SyncRunner) RunTask(ctx context.Context, opt TaskOpts, f func(context.Context)) error { + f(ctx) + return nil +} diff --git a/pkg/schedule/config/config.go b/pkg/schedule/config/config.go index 56038ddcb098..2a1b63bfe8f5 100644 --- a/pkg/schedule/config/config.go +++ b/pkg/schedule/config/config.go @@ -52,6 +52,7 @@ const ( defaultEnableJointConsensus = true defaultEnableTiKVSplitRegion = true defaultEnableHeartbeatBreakdownMetrics = true + defaultEnableHeartbeatAsyncRunner = true defaultEnableCrossTableMerge = true defaultEnableDiagnostic = true defaultStrictlyMatchLabel = false @@ -267,6 +268,9 @@ type ScheduleConfig struct { // EnableHeartbeatBreakdownMetrics is the option to enable heartbeat stats metrics. EnableHeartbeatBreakdownMetrics bool `toml:"enable-heartbeat-breakdown-metrics" json:"enable-heartbeat-breakdown-metrics,string"` + // EnableHeartbeatAsyncRunner is the option to enable heartbeat async runner.
+ EnableHeartbeatAsyncRunner bool `toml:"enable-heartbeat-async-runner" json:"enable-heartbeat-async-runner,string"` + // Schedulers support for loading customized schedulers Schedulers SchedulerConfigs `toml:"schedulers" json:"schedulers-v2"` // json v2 is for the sake of compatible upgrade @@ -382,6 +386,10 @@ func (c *ScheduleConfig) Adjust(meta *configutil.ConfigMetaData, reloading bool) c.EnableHeartbeatBreakdownMetrics = defaultEnableHeartbeatBreakdownMetrics } + if !meta.IsDefined("enable-heartbeat-async-runner") { + c.EnableHeartbeatAsyncRunner = defaultEnableHeartbeatAsyncRunner + } + if !meta.IsDefined("enable-cross-table-merge") { c.EnableCrossTableMerge = defaultEnableCrossTableMerge } diff --git a/pkg/syncer/client.go b/pkg/syncer/client.go index ffbd71d2f1ea..558423722ffe 100644 --- a/pkg/syncer/client.go +++ b/pkg/syncer/client.go @@ -200,13 +200,12 @@ func (s *RegionSyncer) StartSyncWithLeader(addr string) { region = core.NewRegionInfo(r, regionLeader, core.SetSource(core.Sync)) } - tracer := core.NewNoopHeartbeatProcessTracer() - origin, _, err := bc.PreCheckPutRegion(region, tracer) + origin, _, err := bc.PreCheckPutRegion(region) if err != nil { log.Debug("region is stale", zap.Stringer("origin", origin.GetMeta()), errs.ZapError(err)) continue } - saveKV, _, _ := regionGuide(region, origin) + saveKV, _, _ := regionGuide(ctx, region, origin) overlaps := bc.PutRegion(region) if hasBuckets { diff --git a/pkg/utils/ctxutil/context.go b/pkg/utils/ctxutil/context.go new file mode 100644 index 000000000000..8a4c73218287 --- /dev/null +++ b/pkg/utils/ctxutil/context.go @@ -0,0 +1,27 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ctxutil + +// CtxKey is a custom type used as a key for values stored in Context. +type CtxKey string + +const ( + // HeartbeatTracerKey is the key for the heartbeat tracer in the context. + HeartbeatTracerKey CtxKey = "h_tracer" + // TaskRunnerKey is the key for the task runner in the context. + TaskRunnerKey CtxKey = "task_runner" + // LimiterKey is the key for the concurrency limiter in the context. + LimiterKey CtxKey = "limiter" +) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 354e12020e3b..5f4d9d5b2431 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -21,6 +21,7 @@ import ( "io" "math" "net/http" + "runtime" "strconv" "strings" "sync" @@ -44,6 +45,7 @@ import ( mcsutils "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/memory" "github.com/tikv/pd/pkg/progress" + "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/replication" sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/schedule/hbstream" @@ -56,6 +58,7 @@ import ( "github.com/tikv/pd/pkg/storage/endpoint" "github.com/tikv/pd/pkg/syncer" "github.com/tikv/pd/pkg/unsaferecovery" + "github.com/tikv/pd/pkg/utils/ctxutil" "github.com/tikv/pd/pkg/utils/etcdutil" "github.com/tikv/pd/pkg/utils/logutil" "github.com/tikv/pd/pkg/utils/netutil" @@ -166,6 +169,9 @@ type RaftCluster struct { keyspaceGroupManager *keyspace.GroupManager independentServices sync.Map hbstreams *hbstream.HeartbeatStreams + + taskRunner ratelimit.Runner + hbConcurrencyLimiter *ratelimit.ConcurrencyLimiter } // Status saves some 
state information. @@ -182,13 +188,15 @@ type Status struct { func NewRaftCluster(ctx context.Context, clusterID uint64, basicCluster *core.BasicCluster, storage storage.Storage, regionSyncer *syncer.RegionSyncer, etcdClient *clientv3.Client, httpClient *http.Client) *RaftCluster { return &RaftCluster{ - serverCtx: ctx, - clusterID: clusterID, - regionSyncer: regionSyncer, - httpClient: httpClient, - etcdClient: etcdClient, - core: basicCluster, - storage: storage, + serverCtx: ctx, + clusterID: clusterID, + regionSyncer: regionSyncer, + httpClient: httpClient, + etcdClient: etcdClient, + core: basicCluster, + storage: storage, + taskRunner: ratelimit.NewAsyncRunner("heartbeat-async-task-runner", 1000000), + hbConcurrencyLimiter: ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU() * 2)), } } @@ -988,10 +996,21 @@ func (c *RaftCluster) processReportBuckets(buckets *metapb.Buckets) error { } var regionGuide = core.GenerateRegionGuideFunc(true) +var syncRunner = ratelimit.NewSyncRunner() // processRegionHeartbeat updates the region information. 
-func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo, tracer core.RegionHeartbeatProcessTracer) error { - origin, _, err := c.core.PreCheckPutRegion(region, tracer) +func (c *RaftCluster) processRegionHeartbeat(ctx context.Context, region *core.RegionInfo) error { + tracer, ok := ctx.Value(ctxutil.HeartbeatTracerKey).(core.RegionHeartbeatProcessTracer) + if !ok { + tracer = core.NewNoopHeartbeatProcessTracer() + } + runner, ok := ctx.Value(ctxutil.TaskRunnerKey).(ratelimit.Runner) + if !ok { + runner = syncRunner + } + limiter, _ := ctx.Value(ctxutil.LimiterKey).(*ratelimit.ConcurrencyLimiter) + + origin, _, err := c.core.PreCheckPutRegion(region) tracer.OnPreCheckFinished() if err != nil { return err @@ -1000,13 +1019,22 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo, tracer cor region.Inherit(origin, c.GetStoreConfig().IsEnableRegionBucket()) if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { - cluster.HandleStatsAsync(c, region) + runner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "HandleStatsAsync", + Limit: limiter, + }, + func(ctx context.Context) { + cluster.HandleStatsAsync(c, region) + }, + ) } tracer.OnAsyncHotStatsFinished() hasRegionStats := c.regionStats != nil // Save to storage if meta is updated, except for flashback. // Save to cache if meta or leader is updated, or contains any down/pending peer. - saveKV, saveCache, needSync := regionGuide(region, origin) + saveKV, saveCache, needSync := regionGuide(ctx, region, origin) tracer.OnRegionGuideFinished() if !saveKV && !saveCache { // Due to some config changes need to update the region stats as well, @@ -1015,7 +1043,19 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo, tracer cor // region stats needs to be collected in API mode. // We need to think of a better way to reduce this part of the cost in the future. 
if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { - c.regionStats.Observe(region, c.getRegionStoresLocked(region)) + runner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "ObserveRegionStatsAsync", + Limit: limiter, + }, + func(ctx context.Context) { + // get the region again to avoid observing a stale region + // concurrency cannot guarantee the order anyway + region = c.GetRegion(region.GetID()) + c.regionStats.Observe(region, c.GetRegionStores(region)) + }, + ) } return nil } @@ -1032,43 +1072,83 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo, tracer cor // check its validation again here. // // However, it can't solve the race condition of concurrent heartbeats from the same region. - if overlaps, err = c.core.AtomicCheckAndPutRegion(region, tracer); err != nil { + if overlaps, err = c.core.CheckAndPutRootTree(ctx, region); err != nil { tracer.OnSaveCacheFinished() return err } + runner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "UpdateSubTree", + Limit: limiter, + }, + func(ctx context.Context) { + c.core.CheckAndPutSubTree(ctx, region) + }, + ) + tracer.OnUpdateSubTreeFinished() if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { - cluster.HandleOverlaps(c, overlaps) + runner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "HandleOverlaps", + Limit: limiter, + }, + func(ctx context.Context) { + cluster.HandleOverlaps(c, overlaps) + }, + ) } regionUpdateCacheEventCounter.Inc() } tracer.OnSaveCacheFinished() - // TODO: Due to the accuracy requirements of the API "/regions/check/xxx", - // region stats needs to be collected in API mode. - // We need to think of a better way to reduce this part of the cost in the future.
- cluster.Collect(c, region, c.GetRegionStores(region), hasRegionStats) + // handle region stats + runner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "CollectRegionStatsAsync", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + // TODO: Due to the accuracy requirements of the API "/regions/check/xxx", + // region stats needs to be collected in API mode. + // We need to think of a better way to reduce this part of the cost in the future. + cluster.Collect(c, region, hasRegionStats) + }, + ) + tracer.OnCollectRegionStatsFinished() if c.storage != nil { - // If there are concurrent heartbeats from the same region, the last write will win even if - // writes to storage in the critical area. So don't use mutex to protect it. - // Not successfully saved to storage is not fatal, it only leads to longer warm-up - // after restart. Here we only log the error then go on updating cache. - for _, item := range overlaps { - if err := c.storage.DeleteRegion(item.GetMeta()); err != nil { - log.Error("failed to delete region from storage", - zap.Uint64("region-id", item.GetID()), - logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(item.GetMeta())), - errs.ZapError(err)) - } - } if saveKV { - if err := c.storage.SaveRegion(region.GetMeta()); err != nil { - log.Error("failed to save region to storage", - zap.Uint64("region-id", region.GetID()), - logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(region.GetMeta())), - errs.ZapError(err)) - } - regionUpdateKVEventCounter.Inc() + runner.RunTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "SaveRegionToKV", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + // If there are concurrent heartbeats from the same region, the last write will win even if + // writes to storage in the critical area. So don't use mutex to protect it. + // Not successfully saved to storage is not fatal, it only leads to longer warm-up + // after restart. 
Here we only log the error then go on updating cache. + for _, item := range overlaps { + if err := c.storage.DeleteRegion(item.GetMeta()); err != nil { + log.Error("failed to delete region from storage", + zap.Uint64("region-id", item.GetID()), + logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(item.GetMeta())), + errs.ZapError(err)) + } + } + if err := c.storage.SaveRegion(region.GetMeta()); err != nil { + log.Error("failed to save region to storage", + zap.Uint64("region-id", region.GetID()), + logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(region.GetMeta())), + errs.ZapError(err)) + } + regionUpdateKVEventCounter.Inc() + }, + ) } } @@ -2067,16 +2147,6 @@ func (c *RaftCluster) resetProgressIndicator() { storesETAGauge.Reset() } -func (c *RaftCluster) getRegionStoresLocked(region *core.RegionInfo) []*core.StoreInfo { - stores := make([]*core.StoreInfo, 0, len(region.GetPeers())) - for _, p := range region.GetPeers() { - if store := c.core.GetStore(p.StoreId); store != nil { - stores = append(stores, store) - } - } - return stores -} - // OnStoreVersionChange changes the version of the cluster when needed. 
func (c *RaftCluster) OnStoreVersionChange() { c.RLock() diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index dc0f79667614..a3a9e5fbb73e 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -54,6 +54,7 @@ import ( "github.com/tikv/pd/pkg/statistics" "github.com/tikv/pd/pkg/statistics/utils" "github.com/tikv/pd/pkg/storage" + "github.com/tikv/pd/pkg/utils/ctxutil" "github.com/tikv/pd/pkg/utils/operatorutil" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/pkg/utils/typeutil" @@ -631,7 +632,7 @@ func TestRegionHeartbeatHotStat(t *testing.T) { region := core.NewRegionInfo(regionMeta, leader, core.WithInterval(&pdpb.TimeInterval{StartTimestamp: 0, EndTimestamp: utils.RegionHeartBeatReportInterval}), core.SetWrittenBytes(30000*10), core.SetWrittenKeys(300000*10)) - err = cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer()) + err = cluster.processRegionHeartbeat(context.TODO(), region) re.NoError(err) // wait HotStat to update items time.Sleep(time.Second) @@ -644,7 +645,7 @@ func TestRegionHeartbeatHotStat(t *testing.T) { StoreId: 4, } region = region.Clone(core.WithRemoveStorePeer(2), core.WithAddPeer(newPeer)) - err = cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer()) + err = cluster.processRegionHeartbeat(context.TODO(), region) re.NoError(err) // wait HotStat to update items time.Sleep(time.Second) @@ -681,8 +682,8 @@ func TestBucketHeartbeat(t *testing.T) { re.NoError(cluster.putStoreLocked(store)) } - re.NoError(cluster.processRegionHeartbeat(regions[0], core.NewNoopHeartbeatProcessTracer())) - re.NoError(cluster.processRegionHeartbeat(regions[1], core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), regions[0])) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), regions[1])) re.Nil(cluster.GetRegion(uint64(1)).GetBuckets()) re.NoError(cluster.processReportBuckets(buckets)) re.Equal(buckets, 
cluster.GetRegion(uint64(1)).GetBuckets()) @@ -701,13 +702,13 @@ func TestBucketHeartbeat(t *testing.T) { // case5: region update should inherit buckets. newRegion := regions[1].Clone(core.WithIncConfVer(), core.SetBuckets(nil)) opt.SetRegionBucketEnabled(true) - re.NoError(cluster.processRegionHeartbeat(newRegion, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), newRegion)) re.Len(cluster.GetRegion(uint64(1)).GetBuckets().GetKeys(), 2) // case6: disable region bucket in opt.SetRegionBucketEnabled(false) newRegion2 := regions[1].Clone(core.WithIncConfVer(), core.SetBuckets(nil)) - re.NoError(cluster.processRegionHeartbeat(newRegion2, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), newRegion2)) re.Nil(cluster.GetRegion(uint64(1)).GetBuckets()) re.Empty(cluster.GetRegion(uint64(1)).GetBuckets().GetKeys()) } @@ -733,25 +734,25 @@ func TestRegionHeartbeat(t *testing.T) { for i, region := range regions { // region does not exist. - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) checkRegionsKV(re, cluster.storage, regions[:i+1]) // region is the same, not updated. - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) checkRegionsKV(re, cluster.storage, regions[:i+1]) origin := region // region is updated. region = origin.Clone(core.WithIncVersion()) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) checkRegionsKV(re, cluster.storage, regions[:i+1]) // region is stale (Version). 
stale := origin.Clone(core.WithIncConfVer()) - re.Error(cluster.processRegionHeartbeat(stale, core.NewNoopHeartbeatProcessTracer())) + re.Error(cluster.processRegionHeartbeat(context.TODO(), stale)) checkRegions(re, cluster.core, regions[:i+1]) checkRegionsKV(re, cluster.storage, regions[:i+1]) @@ -761,13 +762,13 @@ func TestRegionHeartbeat(t *testing.T) { core.WithIncConfVer(), ) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) checkRegionsKV(re, cluster.storage, regions[:i+1]) // region is stale (ConfVer). stale = origin.Clone(core.WithIncConfVer()) - re.Error(cluster.processRegionHeartbeat(stale, core.NewNoopHeartbeatProcessTracer())) + re.Error(cluster.processRegionHeartbeat(context.TODO(), stale)) checkRegions(re, cluster.core, regions[:i+1]) checkRegionsKV(re, cluster.storage, regions[:i+1]) @@ -779,38 +780,38 @@ func TestRegionHeartbeat(t *testing.T) { }, })) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Add a pending peer. region = region.Clone(core.WithPendingPeers([]*metapb.Peer{region.GetPeers()[rand.Intn(len(region.GetPeers()))]})) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Clear down peers. region = region.Clone(core.WithDownPeers(nil)) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Clear pending peers. 
region = region.Clone(core.WithPendingPeers(nil)) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Remove peers. origin = region region = origin.Clone(core.SetPeers(region.GetPeers()[:1])) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) checkRegionsKV(re, cluster.storage, regions[:i+1]) // Add peers. region = origin regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) checkRegionsKV(re, cluster.storage, regions[:i+1]) @@ -820,47 +821,47 @@ func TestRegionHeartbeat(t *testing.T) { core.WithIncConfVer(), ) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Change leader. region = region.Clone(core.WithLeader(region.GetPeers()[1])) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Change ApproximateSize. region = region.Clone(core.SetApproximateSize(144)) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Change ApproximateKeys. 
region = region.Clone(core.SetApproximateKeys(144000)) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Change bytes written. region = region.Clone(core.SetWrittenBytes(24000)) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Change bytes read. region = region.Clone(core.SetReadBytes(1080000)) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) // Flashback region = region.Clone(core.WithFlashback(true, 1)) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) region = region.Clone(core.WithFlashback(false, 0)) regions[i] = region - re.NoError(cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region)) checkRegions(re, cluster.core, regions[:i+1]) } @@ -916,8 +917,7 @@ func TestRegionHeartbeat(t *testing.T) { core.WithNewRegionID(10000), core.WithDecVersion(), ) - tracer := core.NewHeartbeatProcessTracer() - re.Error(cluster.processRegionHeartbeat(overlapRegion, tracer)) + re.Error(cluster.processRegionHeartbeat(context.TODO(), overlapRegion)) region := &metapb.Region{} ok, err := storage.LoadRegion(regions[n-1].GetID(), region) re.True(ok) @@ -941,9 +941,11 @@ func TestRegionHeartbeat(t *testing.T) { core.WithStartKey(regions[n-2].GetStartKey()), 
core.WithNewRegionID(regions[n-1].GetID()+1), ) - tracer = core.NewHeartbeatProcessTracer() + tracer := core.NewHeartbeatProcessTracer() tracer.Begin() - re.NoError(cluster.processRegionHeartbeat(overlapRegion, tracer)) + ctx := context.TODO() + ctx = context.WithValue(ctx, ctxutil.HeartbeatTracerKey, tracer) + re.NoError(cluster.processRegionHeartbeat(ctx, overlapRegion)) tracer.OnAllStageFinished() re.Condition(func() bool { fileds := tracer.LogFields() @@ -977,7 +979,7 @@ func TestRegionFlowChanged(t *testing.T) { regions := []*core.RegionInfo{core.NewTestRegionInfo(1, 1, []byte{}, []byte{})} processRegions := func(regions []*core.RegionInfo) { for _, r := range regions { - cluster.processRegionHeartbeat(r, core.NewNoopHeartbeatProcessTracer()) + cluster.processRegionHeartbeat(ctx, r) } } regions = core.SplitRegions(regions) @@ -1013,7 +1015,7 @@ func TestRegionSizeChanged(t *testing.T) { core.SetApproximateKeys(curMaxMergeKeys-1), core.SetSource(core.Heartbeat), ) - cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer()) + cluster.processRegionHeartbeat(context.TODO(), region) regionID := region.GetID() re.True(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion)) // Test ApproximateSize and ApproximateKeys change. @@ -1023,16 +1025,16 @@ func TestRegionSizeChanged(t *testing.T) { core.SetApproximateKeys(curMaxMergeKeys+1), core.SetSource(core.Heartbeat), ) - cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer()) + cluster.processRegionHeartbeat(context.TODO(), region) re.False(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion)) // Test MaxMergeRegionSize and MaxMergeRegionKeys change. 
cluster.opt.SetMaxMergeRegionSize(uint64(curMaxMergeSize + 2)) cluster.opt.SetMaxMergeRegionKeys(uint64(curMaxMergeKeys + 2)) - cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer()) + cluster.processRegionHeartbeat(context.TODO(), region) re.True(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion)) cluster.opt.SetMaxMergeRegionSize(uint64(curMaxMergeSize)) cluster.opt.SetMaxMergeRegionKeys(uint64(curMaxMergeKeys)) - cluster.processRegionHeartbeat(region, core.NewNoopHeartbeatProcessTracer()) + cluster.processRegionHeartbeat(context.TODO(), region) re.False(cluster.regionStats.IsRegionStatsType(regionID, statistics.UndersizedRegion)) } @@ -1095,11 +1097,11 @@ func TestConcurrentRegionHeartbeat(t *testing.T) { re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/concurrentRegionHeartbeat", "return(true)")) go func() { defer wg.Done() - cluster.processRegionHeartbeat(source, core.NewNoopHeartbeatProcessTracer()) + cluster.processRegionHeartbeat(context.TODO(), source) }() time.Sleep(100 * time.Millisecond) re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/concurrentRegionHeartbeat")) - re.NoError(cluster.processRegionHeartbeat(target, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), target)) wg.Wait() checkRegion(re, cluster.GetRegionByKey([]byte{}), target) } @@ -1161,7 +1163,7 @@ func TestRegionLabelIsolationLevel(t *testing.T) { func heartbeatRegions(re *require.Assertions, cluster *RaftCluster, regions []*core.RegionInfo) { // Heartbeat and check region one by one. 
for _, r := range regions { - re.NoError(cluster.processRegionHeartbeat(r, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), r)) checkRegion(re, cluster.GetRegion(r.GetID()), r) checkRegion(re, cluster.GetRegionByKey(r.GetStartKey()), r) @@ -1198,7 +1200,7 @@ func TestHeartbeatSplit(t *testing.T) { // 1: [nil, nil) region1 := core.NewRegionInfo(&metapb.Region{Id: 1, RegionEpoch: &metapb.RegionEpoch{Version: 1, ConfVer: 1}}, nil) - re.NoError(cluster.processRegionHeartbeat(region1, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region1)) checkRegion(re, cluster.GetRegionByKey([]byte("foo")), region1) // split 1 to 2: [nil, m) 1: [m, nil), sync 2 first. @@ -1207,12 +1209,12 @@ func TestHeartbeatSplit(t *testing.T) { core.WithIncVersion(), ) region2 := core.NewRegionInfo(&metapb.Region{Id: 2, EndKey: []byte("m"), RegionEpoch: &metapb.RegionEpoch{Version: 1, ConfVer: 1}}, nil) - re.NoError(cluster.processRegionHeartbeat(region2, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region2)) checkRegion(re, cluster.GetRegionByKey([]byte("a")), region2) // [m, nil) is missing before r1's heartbeat. re.Nil(cluster.GetRegionByKey([]byte("z"))) - re.NoError(cluster.processRegionHeartbeat(region1, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region1)) checkRegion(re, cluster.GetRegionByKey([]byte("z")), region1) // split 1 to 3: [m, q) 1: [q, nil), sync 1 first. 
@@ -1221,12 +1223,12 @@ func TestHeartbeatSplit(t *testing.T) { core.WithIncVersion(), ) region3 := core.NewRegionInfo(&metapb.Region{Id: 3, StartKey: []byte("m"), EndKey: []byte("q"), RegionEpoch: &metapb.RegionEpoch{Version: 1, ConfVer: 1}}, nil) - re.NoError(cluster.processRegionHeartbeat(region1, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region1)) checkRegion(re, cluster.GetRegionByKey([]byte("z")), region1) checkRegion(re, cluster.GetRegionByKey([]byte("a")), region2) // [m, q) is missing before r3's heartbeat. re.Nil(cluster.GetRegionByKey([]byte("n"))) - re.NoError(cluster.processRegionHeartbeat(region3, core.NewNoopHeartbeatProcessTracer())) + re.NoError(cluster.processRegionHeartbeat(context.TODO(), region3)) checkRegion(re, cluster.GetRegionByKey([]byte("n")), region3) } @@ -1522,11 +1524,11 @@ func TestUpdateStorePendingPeerCount(t *testing.T) { }, } origin := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers[:3]}, peers[0], core.WithPendingPeers(peers[1:3])) - re.NoError(tc.processRegionHeartbeat(origin, core.NewNoopHeartbeatProcessTracer())) + re.NoError(tc.processRegionHeartbeat(context.TODO(), origin)) time.Sleep(50 * time.Millisecond) checkPendingPeerCount([]int{0, 1, 1, 0}, tc.RaftCluster, re) newRegion := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers[1:]}, peers[1], core.WithPendingPeers(peers[3:4])) - re.NoError(tc.processRegionHeartbeat(newRegion, core.NewNoopHeartbeatProcessTracer())) + re.NoError(tc.processRegionHeartbeat(context.TODO(), newRegion)) time.Sleep(50 * time.Millisecond) checkPendingPeerCount([]int{0, 0, 0, 1}, tc.RaftCluster, re) } @@ -2137,6 +2139,7 @@ func newTestRaftCluster( opt *config.PersistOptions, s storage.Storage, ) *RaftCluster { + opt.GetScheduleConfig().EnableHeartbeatAsyncRunner = false rc := &RaftCluster{serverCtx: ctx, core: core.NewBasicCluster(), storage: s} rc.InitCluster(id, opt, nil, nil) rc.ruleManager = placement.NewRuleManager(ctx, 
storage.NewStorageWithMemoryBackend(), rc, opt) @@ -2959,12 +2962,12 @@ func TestShouldRun(t *testing.T) { for _, testCase := range testCases { r := tc.GetRegion(testCase.regionID) nr := r.Clone(core.WithLeader(r.GetPeers()[0]), core.SetSource(core.Heartbeat)) - re.NoError(tc.processRegionHeartbeat(nr, core.NewNoopHeartbeatProcessTracer())) + re.NoError(tc.processRegionHeartbeat(context.TODO(), nr)) re.Equal(testCase.ShouldRun, co.ShouldRun()) } nr := &metapb.Region{Id: 6, Peers: []*metapb.Peer{}} newRegion := core.NewRegionInfo(nr, nil, core.SetSource(core.Heartbeat)) - re.Error(tc.processRegionHeartbeat(newRegion, core.NewNoopHeartbeatProcessTracer())) + re.Error(tc.processRegionHeartbeat(context.TODO(), newRegion)) re.Equal(7, tc.core.GetClusterNotFromStorageRegionsCnt()) } @@ -3002,12 +3005,12 @@ func TestShouldRunWithNonLeaderRegions(t *testing.T) { for _, testCase := range testCases { r := tc.GetRegion(testCase.regionID) nr := r.Clone(core.WithLeader(r.GetPeers()[0]), core.SetSource(core.Heartbeat)) - re.NoError(tc.processRegionHeartbeat(nr, core.NewNoopHeartbeatProcessTracer())) + re.NoError(tc.processRegionHeartbeat(context.TODO(), nr)) re.Equal(testCase.ShouldRun, co.ShouldRun()) } nr := &metapb.Region{Id: 9, Peers: []*metapb.Peer{}} newRegion := core.NewRegionInfo(nr, nil, core.SetSource(core.Heartbeat)) - re.Error(tc.processRegionHeartbeat(newRegion, core.NewNoopHeartbeatProcessTracer())) + re.Error(tc.processRegionHeartbeat(context.TODO(), newRegion)) re.Equal(9, tc.core.GetClusterNotFromStorageRegionsCnt()) // Now, after server is prepared, there exist some regions with no leader. 
diff --git a/server/cluster/cluster_worker.go b/server/cluster/cluster_worker.go index 5ae8fdc0396f..5b8ea0a6ec1a 100644 --- a/server/cluster/cluster_worker.go +++ b/server/cluster/cluster_worker.go @@ -16,6 +16,7 @@ package cluster import ( "bytes" + "context" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" @@ -26,6 +27,7 @@ import ( mcsutils "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/schedule/operator" "github.com/tikv/pd/pkg/statistics/buckets" + "github.com/tikv/pd/pkg/utils/ctxutil" "github.com/tikv/pd/pkg/utils/logutil" "github.com/tikv/pd/pkg/utils/typeutil" "github.com/tikv/pd/pkg/versioninfo" @@ -39,7 +41,13 @@ func (c *RaftCluster) HandleRegionHeartbeat(region *core.RegionInfo) error { tracer = core.NewHeartbeatProcessTracer() } tracer.Begin() - if err := c.processRegionHeartbeat(region, tracer); err != nil { + ctx := context.WithValue(c.ctx, ctxutil.HeartbeatTracerKey, tracer) + ctx = context.WithValue(ctx, ctxutil.LimiterKey, c.hbConcurrencyLimiter) + if c.GetScheduleConfig().EnableHeartbeatAsyncRunner { + ctx = context.WithValue(ctx, ctxutil.TaskRunnerKey, c.taskRunner) + } + + if err := c.processRegionHeartbeat(ctx, region); err != nil { tracer.OnAllStageFinished() return err }