diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 916200bfa3ec..8b4e67491397 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -28,6 +28,7 @@ type Cluster interface { GetLabelStats() *statistics.LabelStatistics GetCoordinator() *schedule.Coordinator GetRuleManager() *placement.RuleManager + GetBasicCluster() *core.BasicCluster } // HandleStatsAsync handles the flow asynchronously. @@ -43,6 +44,10 @@ func HandleStatsAsync(c Cluster, region *core.RegionInfo) { c.GetCoordinator().GetSchedulersController().CheckTransferWitnessLeader(region) } +type Item interface { + GetID() uint64 +} + // HandleOverlaps handles the overlap regions. func HandleOverlaps(c Cluster, overlaps []*core.RegionInfo) { for _, item := range overlaps { @@ -54,9 +59,23 @@ func HandleOverlaps(c Cluster, overlaps []*core.RegionInfo) { } } +// HandleOverlapsByCheck handles the overlap regions asynchronously. +func HandleOverlapsByCheck(c Cluster, region *core.RegionInfo) { + // check region again + _, overlaps, err := c.GetBasicCluster().PreCheckPutRegion(region) + if err != nil { + return + } + if len(overlaps) > 0 { + HandleOverlaps(c, overlaps) + } +} + // Collect collects the cluster information. -func Collect(c Cluster, region *core.RegionInfo, stores []*core.StoreInfo, hasRegionStats bool) { +func Collect(c Cluster, region *core.RegionInfo, hasRegionStats bool) { if hasRegionStats { - c.GetRegionStats().Observe(region, stores) + // get region again from root tree.
+ region = c.GetBasicCluster().GetRegion(region.GetID()) + c.GetRegionStats().Observe(region, c.GetBasicCluster().GetRegionStores(region)) } } diff --git a/pkg/core/region.go b/pkg/core/region.go index f7a4ef5f0fd5..ee4ecf5aae03 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -16,6 +16,7 @@ package core import ( "bytes" + "context" "encoding/hex" "fmt" "math" @@ -35,6 +36,7 @@ import ( "github.com/pingcap/kvproto/pkg/replication_modepb" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/utils/logutil" "github.com/tikv/pd/pkg/utils/syncutil" "github.com/tikv/pd/pkg/utils/typeutil" @@ -711,20 +713,51 @@ func (r *RegionInfo) isRegionRecreated() bool { // RegionGuideFunc is a function that determines which follow-up operations need to be performed based on the origin // and new region information. -type RegionGuideFunc func(region, origin *RegionInfo) (saveKV, saveCache, needSync bool) +type RegionGuideFunc func(ctx context.Context, region, origin *RegionInfo) (saveKV, saveCache, needSync bool) // GenerateRegionGuideFunc is used to generate a RegionGuideFunc. Control the log output by specifying the log function. // nil means do not print the log. func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { noLog := func(msg string, fields ...zap.Field) {} - debug, info := noLog, noLog + d, i := noLog, noLog + debug, info := d, i if enableLog { - debug = log.Debug - info = log.Info + d = log.Debug + i = log.Info + debug, info = d, i } // Save to storage if meta is updated. // Save to cache if meta or leader is updated, or contains any down/pending peer. 
- return func(region, origin *RegionInfo) (saveKV, saveCache, needSync bool) { + return func(ctx context.Context, region, origin *RegionInfo) (saveKV, saveCache, needSync bool) { + asyncRunner, ok := ctx.Value("asyncRunner").(*ratelimit.Runner) + limiter, _ := ctx.Value("limiter").(*ratelimit.ConcurrencyLimiter) + // print log asynchronously + if ok { + debug = func(msg string, fields ...zap.Field) { + asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "Log", + Limit: limiter, + }, + func(ctx context.Context) { + d(msg, fields...) + }, + ) + } + info = func(msg string, fields ...zap.Field) { + asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "Log", + Limit: limiter, + }, + func(ctx context.Context) { + i(msg, fields...) + }, + ) + } + } if origin == nil { if log.GetLevel() <= zap.DebugLevel { debug("insert new region", @@ -789,7 +822,7 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { } if !SortedPeersStatsEqual(region.GetDownPeers(), origin.GetDownPeers()) { if log.GetLevel() <= zap.DebugLevel { - debug("down-peers changed", zap.Uint64("region-id", region.GetID())) + debug("down peers changed", zap.Uint64("region-id", region.GetID()), zap.Reflect("before", origin.GetDownPeers()), zap.Reflect("after", region.GetDownPeers())) } saveCache, needSync = true, true return @@ -912,7 +945,7 @@ func (r *RegionsInfo) CheckAndPutRegion(region *RegionInfo) []*RegionInfo { if origin == nil || !bytes.Equal(origin.GetStartKey(), region.GetStartKey()) || !bytes.Equal(origin.GetEndKey(), region.GetEndKey()) { ols = r.tree.overlaps(®ionItem{RegionInfo: region}) } - err := check(region, origin, ols) + err := check(region, origin, convertItemsToRegions(ols)) if err != nil { log.Debug("region is stale", zap.Stringer("origin", origin.GetMeta()), errs.ZapError(err)) // return the state region to delete. 
@@ -933,48 +966,124 @@ func (r *RegionsInfo) PutRegion(region *RegionInfo) []*RegionInfo { } // PreCheckPutRegion checks if the region is valid to put. -func (r *RegionsInfo) PreCheckPutRegion(region *RegionInfo, trace RegionHeartbeatProcessTracer) (*RegionInfo, []*regionItem, error) { - origin, overlaps := r.GetRelevantRegions(region, trace) +func (r *RegionsInfo) PreCheckPutRegion(region *RegionInfo) (*RegionInfo, []*RegionInfo, error) { + origin, overlaps := r.GetRelevantRegions(region) err := check(region, origin, overlaps) - return origin, overlaps, err + ov := make([]*RegionInfo, 0, len(overlaps)) + return origin, ov, err +} + +func convertItemsToRegions(items []*regionItem) []*RegionInfo { + regions := make([]*RegionInfo, 0, len(items)) + for _, item := range items { + regions = append(regions, item.RegionInfo) + } + return regions } // AtomicCheckAndPutRegion checks if the region is valid to put, if valid then put. -func (r *RegionsInfo) AtomicCheckAndPutRegion(region *RegionInfo, trace RegionHeartbeatProcessTracer) ([]*RegionInfo, error) { +func (r *RegionsInfo) AtomicCheckAndPutRegion(ctx context.Context, region *RegionInfo) ([]*RegionInfo, error) { + tracer, ok := ctx.Value("tracer").(RegionHeartbeatProcessTracer) + if !ok { + tracer = NewNoopHeartbeatProcessTracer() + } r.t.Lock() var ols []*regionItem origin := r.getRegionLocked(region.GetID()) if origin == nil || !bytes.Equal(origin.GetStartKey(), region.GetStartKey()) || !bytes.Equal(origin.GetEndKey(), region.GetEndKey()) { ols = r.tree.overlaps(®ionItem{RegionInfo: region}) } - trace.OnCheckOverlapsFinished() - err := check(region, origin, ols) + tracer.OnCheckOverlapsFinished() + err := check(region, origin, convertItemsToRegions(ols)) if err != nil { r.t.Unlock() - trace.OnValidateRegionFinished() + tracer.OnValidateRegionFinished() return nil, err } - trace.OnValidateRegionFinished() + tracer.OnValidateRegionFinished() origin, overlaps, rangeChanged := r.setRegionLocked(region, true, ols...) 
r.t.Unlock() - trace.OnSetRegionFinished() + tracer.OnSetRegionFinished() r.UpdateSubTree(region, origin, overlaps, rangeChanged) - trace.OnUpdateSubTreeFinished() + tracer.OnUpdateSubTreeFinished() + return overlaps, nil +} + +// CheckAndPutRootTree checks if the region is valid to put into the root tree; if it is not, an error is returned. +// Usually used together with CheckAndPutSubTree. +func (r *RegionsInfo) CheckAndPutRootTree(ctx context.Context, region *RegionInfo) ([]*RegionInfo, error) { + tracer, ok := ctx.Value("tracer").(RegionHeartbeatProcessTracer) + if !ok { + tracer = NewNoopHeartbeatProcessTracer() + } + r.t.Lock() + var ols []*regionItem + origin := r.getRegionLocked(region.GetID()) + if origin == nil || !bytes.Equal(origin.GetStartKey(), region.GetStartKey()) || !bytes.Equal(origin.GetEndKey(), region.GetEndKey()) { + ols = r.tree.overlaps(&regionItem{RegionInfo: region}) + } + tracer.OnCheckOverlapsFinished() + err := check(region, origin, convertItemsToRegions(ols)) + if err != nil { + r.t.Unlock() + tracer.OnValidateRegionFinished() + return nil, err + } + tracer.OnValidateRegionFinished() + _, overlaps, _ := r.setRegionLocked(region, true, ols...) + r.t.Unlock() + tracer.OnSetRegionFinished() return overlaps, nil } +// CheckAndPutSubTree checks if the region is valid to put into the sub tree; if it is not, an error is returned. +// Usually used together with CheckAndPutRootTree.
+func (r *RegionsInfo) CheckAndPutSubTree(ctx context.Context, region *RegionInfo) error { + var origin *RegionInfo + r.st.RLock() + // get origin from sub tree + originItem, ok := r.subRegions[region.GetID()] + if ok { + origin = originItem.RegionInfo + } + r.st.RUnlock() + // new region get from root tree again + newRegion, overlaps := r.GetRelevantRegions(origin) + rangeChanged := true + rangeChanged = !origin.rangeEqualsTo(region) + asyncRunner, ok := ctx.Value("asyncRunner").(*ratelimit.Runner) + limiter, _ := ctx.Value("limiter").(*ratelimit.ConcurrencyLimiter) + if ok { + asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "UpdateSubTree", + Limit: limiter, + }, + func(ctx context.Context) { + r.UpdateSubTree(newRegion, origin, overlaps, rangeChanged) + }, + ) + } else { + r.UpdateSubTree(region, origin, overlaps, rangeChanged) + } + return nil +} + // GetRelevantRegions returns the relevant regions for a given region. -func (r *RegionsInfo) GetRelevantRegions(region *RegionInfo, trace RegionHeartbeatProcessTracer) (origin *RegionInfo, overlaps []*regionItem) { +func (r *RegionsInfo) GetRelevantRegions(region *RegionInfo) (origin *RegionInfo, overlaps []*RegionInfo) { r.t.RLock() defer r.t.RUnlock() origin = r.getRegionLocked(region.GetID()) if origin == nil || !bytes.Equal(origin.GetStartKey(), region.GetStartKey()) || !bytes.Equal(origin.GetEndKey(), region.GetEndKey()) { - overlaps = r.tree.overlaps(®ionItem{RegionInfo: region}) + for _, item := range r.tree.overlaps(®ionItem{RegionInfo: region}) { + overlaps = append(overlaps, item.RegionInfo) + } } return } -func check(region, origin *RegionInfo, overlaps []*regionItem) error { +func check(region, origin *RegionInfo, overlaps []*RegionInfo) error { for _, item := range overlaps { // PD ignores stale regions' heartbeats, unless it is recreated recently by unsafe recover operation. 
if region.GetRegionEpoch().GetVersion() < item.GetRegionEpoch().GetVersion() && !region.isRegionRecreated() { @@ -1065,6 +1174,10 @@ func (r *RegionsInfo) UpdateSubTree(region, origin *RegionInfo, overlaps []*Regi r.st.Lock() defer r.st.Unlock() if origin != nil { + // TODO: Check if the pointer address is consistent? + // eg &origin == &subtree[origin.GetID]. + // compare the pointer of subtree with the pointer of origin to ensure no changes have occurred before update subtree. + if rangeChanged || !origin.peersEqualTo(region) { // If the range or peers have changed, the sub regionTree needs to be cleaned up. // TODO: Improve performance by deleting only the different peers. diff --git a/pkg/core/region_test.go b/pkg/core/region_test.go index 3c6536a6a773..e51c24142ca1 100644 --- a/pkg/core/region_test.go +++ b/pkg/core/region_test.go @@ -15,6 +15,7 @@ package core import ( + "context" "crypto/rand" "fmt" "math" @@ -459,9 +460,9 @@ func TestSetRegionConcurrence(t *testing.T) { regions := NewRegionsInfo() region := NewTestRegionInfo(1, 1, []byte("a"), []byte("b")) go func() { - regions.AtomicCheckAndPutRegion(region, NewNoopHeartbeatProcessTracer()) + regions.AtomicCheckAndPutRegion(context.TODO(), region) }() - regions.AtomicCheckAndPutRegion(region, NewNoopHeartbeatProcessTracer()) + regions.AtomicCheckAndPutRegion(context.TODO(), region) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/core/UpdateSubTree")) } diff --git a/pkg/core/region_tree.go b/pkg/core/region_tree.go index 333e1730ec8a..5a633a2639c1 100644 --- a/pkg/core/region_tree.go +++ b/pkg/core/region_tree.go @@ -35,6 +35,10 @@ func (r *regionItem) GetStartKey() []byte { return r.meta.StartKey } +func (r *regionItem) GetID() uint64 { + return r.meta.GetId() +} + // GetEndKey returns the end key of the region. 
func (r *regionItem) GetEndKey() []byte { return r.meta.EndKey diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index 1b915b6874d2..a6ef5e85aa21 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -2,6 +2,7 @@ package server import ( "context" + "runtime" "sync" "sync/atomic" "time" @@ -15,6 +16,7 @@ import ( "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/mcs/scheduling/server/config" + "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/schedule" sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/schedule/hbstream" @@ -51,6 +53,9 @@ type Cluster struct { apiServerLeader atomic.Value clusterID uint64 running atomic.Bool + + asyncRunner *ratelimit.Runner + hbConcurrencyLimiter *ratelimit.ConcurrencyLimiter } const ( @@ -81,6 +86,9 @@ func NewCluster(parentCtx context.Context, persistConfig *config.PersistConfig, storage: storage, clusterID: clusterID, checkMembershipCh: checkMembershipCh, + + asyncRunner: ratelimit.NewRunner("heartbeat-async-task-runner", 1000000), + hbConcurrencyLimiter: ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU() * 2)), } c.coordinator = schedule.NewCoordinator(ctx, c, hbStreams) err = c.ruleManager.Initialize(persistConfig.GetMaxReplicas(), persistConfig.GetLocationLabels(), persistConfig.GetIsolationLevel()) @@ -543,7 +551,11 @@ func (c *Cluster) HandleRegionHeartbeat(region *core.RegionInfo) error { tracer = core.NewHeartbeatProcessTracer() } tracer.Begin() - if err := c.processRegionHeartbeat(region, tracer); err != nil { + ctx := context.WithValue(c.ctx, "tracer", tracer) + ctx = context.WithValue(ctx, "limiter", c.hbConcurrencyLimiter) + ctx = context.WithValue(ctx, "asyncRunner", c.asyncRunner) + + if err := c.processRegionHeartbeat(ctx, region); err != nil { tracer.OnAllStageFinished() return err } @@ -553,26 +565,47 @@ func (c *Cluster) HandleRegionHeartbeat(region *core.RegionInfo) 
error { } // processRegionHeartbeat updates the region information. -func (c *Cluster) processRegionHeartbeat(region *core.RegionInfo, tracer core.RegionHeartbeatProcessTracer) error { - origin, _, err := c.PreCheckPutRegion(region, tracer) +func (c *Cluster) processRegionHeartbeat(ctx context.Context, region *core.RegionInfo) error { + tracer, _ := ctx.Value("tracer").(core.RegionHeartbeatProcessTracer) + origin, _, err := c.PreCheckPutRegion(region) tracer.OnPreCheckFinished() if err != nil { return err } region.Inherit(origin, c.GetStoreConfig().IsEnableRegionBucket()) - cluster.HandleStatsAsync(c, region) + c.asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "HandleStatsAsync", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + cluster.HandleStatsAsync(c, region) + }, + ) tracer.OnAsyncHotStatsFinished() hasRegionStats := c.regionStats != nil // Save to storage if meta is updated, except for flashback. // Save to cache if meta or leader is updated, or contains any down/pending peer. - _, saveCache, _ := core.GenerateRegionGuideFunc(true)(region, origin) + _, saveCache, _ := core.GenerateRegionGuideFunc(true)(ctx, region, origin) if !saveCache { // Due to some config changes need to update the region stats as well, // so we do some extra checks here. if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { - c.regionStats.Observe(region, c.GetRegionStores(region)) + c.asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "ObserveRegionStatsAsync", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + if c.regionStats.RegionStatsNeedUpdate(region) { + c.regionStats.Observe(region, c.GetRegionStores(region)) + } + }, + ) } return nil } @@ -583,15 +616,26 @@ func (c *Cluster) processRegionHeartbeat(region *core.RegionInfo, tracer core.Re // check its validation again here. // // However, it can't solve the race condition of concurrent heartbeats from the same region. 
- if overlaps, err = c.AtomicCheckAndPutRegion(region, tracer); err != nil { + if overlaps, err = c.CheckAndPutRootTree(ctx, region); err != nil { tracer.OnSaveCacheFinished() return err } + c.asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "UpdateSubTree", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + c.CheckAndPutSubTree(ctx, region) + }, + ) + tracer.OnUpdateSubTreeFinished() cluster.HandleOverlaps(c, overlaps) } tracer.OnSaveCacheFinished() - cluster.Collect(c, region, c.GetRegionStores(region), hasRegionStats) + cluster.Collect(c, region, hasRegionStats) tracer.OnCollectRegionStatsFinished() return nil } diff --git a/pkg/ratelimit/concurrency_limiter.go b/pkg/ratelimit/concurrency_limiter.go index b1eef3c8101b..2f5b38992148 100644 --- a/pkg/ratelimit/concurrency_limiter.go +++ b/pkg/ratelimit/concurrency_limiter.go @@ -14,24 +14,31 @@ package ratelimit -import "github.com/tikv/pd/pkg/utils/syncutil" +import ( + "context" -type concurrencyLimiter struct { - mu syncutil.RWMutex + "github.com/tikv/pd/pkg/utils/syncutil" +) + +type ConcurrencyLimiter struct { + mu syncutil.Mutex current uint64 limit uint64 + waiting uint64 // statistic maxLimit uint64 + queue chan *TaskToken } -func newConcurrencyLimiter(limit uint64) *concurrencyLimiter { - return &concurrencyLimiter{limit: limit} +// NewConcurrencyLimiter creates a new ConcurrencyLimiter. 
+func NewConcurrencyLimiter(limit uint64) *ConcurrencyLimiter { + return &ConcurrencyLimiter{limit: limit, queue: make(chan *TaskToken, limit)} } const unlimit = uint64(0) -func (l *concurrencyLimiter) allow() bool { +func (l *ConcurrencyLimiter) allow() bool { l.mu.Lock() defer l.mu.Unlock() @@ -45,7 +52,7 @@ func (l *concurrencyLimiter) allow() bool { return false } -func (l *concurrencyLimiter) release() { +func (l *ConcurrencyLimiter) release() { l.mu.Lock() defer l.mu.Unlock() @@ -54,28 +61,28 @@ func (l *concurrencyLimiter) release() { } } -func (l *concurrencyLimiter) getLimit() uint64 { - l.mu.RLock() - defer l.mu.RUnlock() +func (l *ConcurrencyLimiter) getLimit() uint64 { + l.mu.Lock() + defer l.mu.Unlock() return l.limit } -func (l *concurrencyLimiter) setLimit(limit uint64) { +func (l *ConcurrencyLimiter) setLimit(limit uint64) { l.mu.Lock() defer l.mu.Unlock() l.limit = limit } -func (l *concurrencyLimiter) getCurrent() uint64 { - l.mu.RLock() - defer l.mu.RUnlock() +func (l *ConcurrencyLimiter) getCurrent() uint64 { + l.mu.Lock() + defer l.mu.Unlock() return l.current } -func (l *concurrencyLimiter) getMaxConcurrency() uint64 { +func (l *ConcurrencyLimiter) getMaxConcurrency() uint64 { l.mu.Lock() defer func() { l.maxLimit = l.current @@ -84,3 +91,65 @@ func (l *concurrencyLimiter) getMaxConcurrency() uint64 { return l.maxLimit } + +// GetRunningTasksNum returns the number of running tasks. +func (l *ConcurrencyLimiter) GetRunningTasksNum() uint64 { + return l.getCurrent() } + +// GetWaitingTasksNum returns the number of waiting tasks. +func (l *ConcurrencyLimiter) GetWaitingTasksNum() uint64 { + l.mu.Lock() + defer l.mu.Unlock() + return l.waiting +} + +// Acquire acquires a token from the limiter, blocking until a token is available or ctx is done (e.g. on timeout).
+func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) { + l.mu.Lock() + if l.current >= l.limit { + l.waiting++ + l.mu.Unlock() + select { + case <-ctx.Done(): + l.mu.Lock() + l.waiting-- + l.mu.Unlock() + return nil, ctx.Err() + case token := <-l.queue: + l.mu.Lock() + token.released = false + l.current++ + l.waiting-- + l.mu.Unlock() + return token, nil + } + } + l.current++ + token := &TaskToken{limiter: l} + l.mu.Unlock() + return token, nil +} + +// TaskToken is a token that must be released after the task is done. +type TaskToken struct { + released bool + limiter *ConcurrencyLimiter +} + +// Release releases the token. +func (tt *TaskToken) Release() { + tt.limiter.mu.Lock() + defer tt.limiter.mu.Unlock() + if tt.released { + return + } + if tt.limiter.current == 0 { + panic("release token more than acquire") + } + tt.released = true + tt.limiter.current-- + if len(tt.limiter.queue) < int(tt.limiter.limit) { + tt.limiter.queue <- tt + } +} diff --git a/pkg/ratelimit/concurrency_limiter_test.go b/pkg/ratelimit/concurrency_limiter_test.go index 5fe03740394a..216da1ac8a02 100644 --- a/pkg/ratelimit/concurrency_limiter_test.go +++ b/pkg/ratelimit/concurrency_limiter_test.go @@ -15,7 +15,12 @@ package ratelimit import ( + "context" + "fmt" + "sync" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -23,7 +28,7 @@ import ( func TestConcurrencyLimiter(t *testing.T) { t.Parallel() re := require.New(t) - cl := newConcurrencyLimiter(10) + cl := NewConcurrencyLimiter(10) for i := 0; i < 10; i++ { re.True(cl.allow()) } @@ -52,3 +57,72 @@ func TestConcurrencyLimiter(t *testing.T) { re.Equal(uint64(5), cl.getMaxConcurrency()) re.Equal(uint64(0), cl.getMaxConcurrency()) } + +func TestConcurrencyLimiter2(t *testing.T) { + limit := uint64(2) + limiter := NewConcurrencyLimiter(limit) + + require.Equal(t, uint64(0), limiter.GetRunningTasksNum(), "Expected running tasks to be 0") + require.Equal(t, uint64(0), 
limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Acquire two tokens + token1, err := limiter.Acquire(ctx) + require.NoError(t, err, "Failed to acquire token") + + token2, err := limiter.Acquire(ctx) + require.NoError(t, err, "Failed to acquire token") + + require.Equal(t, limit, limiter.GetRunningTasksNum(), "Expected running tasks to be 2") + + // Try to acquire third token, it should not be able to acquire immediately due to limit + go func() { + _, err := limiter.Acquire(ctx) + require.NoError(t, err, "Failed to acquire token") + }() + + time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run + require.Equal(t, uint64(1), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 1") + + // Release a token + token1.Release() + time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run + require.Equal(t, uint64(2), limiter.GetRunningTasksNum(), "Expected running tasks to be 2") + require.Equal(t, uint64(0), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0") + + // Release the second token + token2.Release() + time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run + require.Equal(t, uint64(1), limiter.GetRunningTasksNum(), "Expected running tasks to be 1") +} + +func TestConcurrencyLimiterAcquire(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + limiter := NewConcurrencyLimiter(20) + sum := int64(0) + start := time.Now() + wg := &sync.WaitGroup{} + wg.Add(100) + for i := 0; i < 100; i++ { + go func(i int) { + defer wg.Done() + token, err := limiter.Acquire(ctx) + if err != nil { + fmt.Printf("Task %d failed to acquire: %v\n", i, err) + return + } + defer token.Release() + // simulate takes some time + time.Sleep(10 * time.Millisecond) + atomic.AddInt64(&sum, 1) + }(i) + } + wg.Wait() + // We should have 20 tasks running 
concurrently, so it should take at least 50ms to complete + require.Greater(t, time.Since(start).Milliseconds(), int64(50)) + require.Equal(t, int64(100), sum) +} diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index dc744d9ac1b0..7b3eba10325a 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -36,18 +36,18 @@ type DimensionConfig struct { type limiter struct { mu syncutil.RWMutex - concurrency *concurrencyLimiter + concurrency *ConcurrencyLimiter rate *RateLimiter } func newLimiter() *limiter { lim := &limiter{ - concurrency: newConcurrencyLimiter(0), + concurrency: NewConcurrencyLimiter(0), } return lim } -func (l *limiter) getConcurrencyLimiter() *concurrencyLimiter { +func (l *limiter) getConcurrencyLimiter() *ConcurrencyLimiter { l.mu.RLock() defer l.mu.RUnlock() return l.concurrency @@ -101,7 +101,7 @@ func (l *limiter) updateConcurrencyConfig(limit uint64) UpdateStatus { } l.concurrency.setLimit(limit) } else { - l.concurrency = newConcurrencyLimiter(limit) + l.concurrency = NewConcurrencyLimiter(limit) } return ConcurrencyChanged } diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go new file mode 100644 index 000000000000..510d020dee05 --- /dev/null +++ b/pkg/ratelimit/runner.go @@ -0,0 +1,96 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ratelimit + +import ( + "context" + "errors" + "sync/atomic" + + "github.com/pingcap/log" + "go.uber.org/zap" +) + +// ErrMaxWaitingTasksExceeded is returned when the number of waiting tasks exceeds the maximum. +var ErrMaxWaitingTasksExceeded = errors.New("max waiting tasks exceeded") + +// Runner is a simple task runner that limits the number of concurrent tasks. +type Runner struct { + numTasks int64 + maxPendingTasks uint64 + name string +} + +// NewRunner creates a new Runner. +func NewRunner(name string, maxPendingTasks uint64) *Runner { + return &Runner{name: name, maxPendingTasks: maxPendingTasks} +} + +// TaskOpts is the options for RunAsyncTask. +type TaskOpts struct { + // TaskName is a human-readable name for the operation. TODO: metrics by name. + TaskName string + Limit *ConcurrencyLimiter +} + +// RunAsyncTask except the callback f is run in a goroutine. +// The call doesn't block for the callback to finish execution. +func (s *Runner) RunAsyncTask(ctx context.Context, opt TaskOpts, f func(context.Context)) error { + if opt.Limit != nil && atomic.LoadInt64(&s.numTasks) >= int64(s.maxPendingTasks) { + return ErrMaxWaitingTasksExceeded + } + + s.addTask(1) + go func(ctx context.Context, opt TaskOpts) { + var token *TaskToken + if opt.Limit != nil { + // Wait for permission to run from the semaphore. + var err error + if opt.Limit != nil { + token, err = opt.Limit.Acquire(ctx) + } + if err != nil { + log.Error("failed to acquire semaphore", zap.String("task-name", opt.TaskName), zap.Error(err)) + return + } + + // Check for canceled context: it's possible to get the semaphore even + // if the context is canceled. 
+ if ctx.Err() != nil { + log.Error("context is canceled", zap.String("task-name", opt.TaskName)) + return + } + } + defer s.addTask(-1) + defer s.recover() + if token != nil { + defer token.Release() + } + f(ctx) + }(ctx, opt) + + return nil +} + +func (s *Runner) recover() { + if r := recover(); r != nil { + log.Error("panic in runner", zap.Any("error", r)) + return + } +} + +func (s *Runner) addTask(delta int64) (updated int64) { + return atomic.AddInt64(&s.numTasks, delta) +} diff --git a/pkg/statistics/region_collection.go b/pkg/statistics/region_collection.go index 52655b093a21..171e3710c1cf 100644 --- a/pkg/statistics/region_collection.go +++ b/pkg/statistics/region_collection.go @@ -22,6 +22,7 @@ import ( sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/pkg/utils/syncutil" + "go.uber.org/zap" ) // RegionInfoProvider is an interface to provide the region information. @@ -234,6 +235,7 @@ func (r *RegionStatistics) Observe(region *core.RegionInfo, stores []*core.Store info = &RegionInfoWithTS{id: regionID} } if typ == DownPeer { + log.Info("region has down peer", zap.Uint64("id", region.GetID()), zap.Reflect("meta", region.GetMeta()), zap.Reflect("region", region), zap.Reflect("downPeers", region.GetDownPeers()), zap.Int("len-down", len(region.GetDownPeers()))) if info.startDownPeerTS != 0 { regionDownPeerDuration.Observe(float64(time.Now().Unix() - info.startDownPeerTS)) } else { @@ -254,6 +256,9 @@ func (r *RegionStatistics) Observe(region *core.RegionInfo, stores []*core.Store // Remove the info if any of the conditions are not met any more. 
if oldIndex, ok := r.index[regionID]; ok && oldIndex > emptyStatistic { deleteIndex := oldIndex &^ peerTypeIndex + if deleteIndex&DownPeer != 0 { + log.Info("region has no down peer, remove", zap.Uint64("id", region.GetID()), zap.Reflect("meta", region.GetMeta()), zap.Reflect("region", region)) + } r.deleteEntry(deleteIndex, regionID) } r.index[regionID] = peerTypeIndex diff --git a/pkg/syncer/client.go b/pkg/syncer/client.go index ffbd71d2f1ea..558423722ffe 100644 --- a/pkg/syncer/client.go +++ b/pkg/syncer/client.go @@ -200,13 +200,12 @@ func (s *RegionSyncer) StartSyncWithLeader(addr string) { region = core.NewRegionInfo(r, regionLeader, core.SetSource(core.Sync)) } - tracer := core.NewNoopHeartbeatProcessTracer() - origin, _, err := bc.PreCheckPutRegion(region, tracer) + origin, _, err := bc.PreCheckPutRegion(region) if err != nil { log.Debug("region is stale", zap.Stringer("origin", origin.GetMeta()), errs.ZapError(err)) continue } - saveKV, _, _ := regionGuide(region, origin) + saveKV, _, _ := regionGuide(ctx, region, origin) overlaps := bc.PutRegion(region) if hasBuckets { diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 354e12020e3b..e6569cf0bddf 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -21,6 +21,7 @@ import ( "io" "math" "net/http" + "runtime" "strconv" "strings" "sync" @@ -44,6 +45,7 @@ import ( mcsutils "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/memory" "github.com/tikv/pd/pkg/progress" + "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/replication" sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/schedule/hbstream" @@ -166,6 +168,9 @@ type RaftCluster struct { keyspaceGroupManager *keyspace.GroupManager independentServices sync.Map hbstreams *hbstream.HeartbeatStreams + + asyncRunner *ratelimit.Runner + hbConcurrencyLimiter *ratelimit.ConcurrencyLimiter } // Status saves some state information. 
@@ -182,13 +187,15 @@ type Status struct { func NewRaftCluster(ctx context.Context, clusterID uint64, basicCluster *core.BasicCluster, storage storage.Storage, regionSyncer *syncer.RegionSyncer, etcdClient *clientv3.Client, httpClient *http.Client) *RaftCluster { return &RaftCluster{ - serverCtx: ctx, - clusterID: clusterID, - regionSyncer: regionSyncer, - httpClient: httpClient, - etcdClient: etcdClient, - core: basicCluster, - storage: storage, + serverCtx: ctx, + clusterID: clusterID, + regionSyncer: regionSyncer, + httpClient: httpClient, + etcdClient: etcdClient, + core: basicCluster, + storage: storage, + asyncRunner: ratelimit.NewRunner("heartbeat-async-task-runner", 1000000), + hbConcurrencyLimiter: ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU() * 2)), } } @@ -990,8 +997,9 @@ func (c *RaftCluster) processReportBuckets(buckets *metapb.Buckets) error { var regionGuide = core.GenerateRegionGuideFunc(true) // processRegionHeartbeat updates the region information. -func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo, tracer core.RegionHeartbeatProcessTracer) error { - origin, _, err := c.core.PreCheckPutRegion(region, tracer) +func (c *RaftCluster) processRegionHeartbeat(ctx context.Context, region *core.RegionInfo) error { + tracer, _ := ctx.Value("tracer").(core.RegionHeartbeatProcessTracer) + origin, _, err := c.core.PreCheckPutRegion(region) tracer.OnPreCheckFinished() if err != nil { return err @@ -1000,13 +1008,22 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo, tracer cor region.Inherit(origin, c.GetStoreConfig().IsEnableRegionBucket()) if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { - cluster.HandleStatsAsync(c, region) + c.asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "HandleStatsAsync", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + cluster.HandleStatsAsync(c, region) + }, + ) } tracer.OnAsyncHotStatsFinished() hasRegionStats := c.regionStats != 
nil

	// Save to storage if meta is updated, except for flashback.
	// Save to cache if meta or leader is updated, or contains any down/pending peer.
-	saveKV, saveCache, needSync := regionGuide(region, origin)
+	saveKV, saveCache, needSync := regionGuide(ctx, region, origin)
	tracer.OnRegionGuideFinished()
	if !saveKV && !saveCache {
		// Due to some config changes need to update the region stats as well,
@@ -1015,7 +1032,19 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo, tracer cor
		// region stats needs to be collected in API mode.
		// We need to think of a better way to reduce this part of the cost in the future.
		if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) {
-			c.regionStats.Observe(region, c.getRegionStoresLocked(region))
+			c.asyncRunner.RunAsyncTask(
+				ctx,
+				ratelimit.TaskOpts{
+					TaskName: "ObserveRegionStatsAsync",
+					Limit:    c.hbConcurrencyLimiter,
+				},
+				func(ctx context.Context) {
+					// Get the region again from the root tree to observe the latest version;
+					// shadow `region` so we don't race with the caller goroutine's reads.
+					region := c.GetRegion(region.GetID())
+					c.regionStats.Observe(region, c.GetRegionStores(region))
+				},
+			)
		}
		return nil
	}
@@ -1032,43 +1061,86 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo, tracer cor
	// check its validation again here.
	//
	// However, it can't solve the race condition of concurrent heartbeats from the same region.
- if overlaps, err = c.core.AtomicCheckAndPutRegion(region, tracer); err != nil { + if overlaps, err = c.core.CheckAndPutRootTree(ctx, region); err != nil { tracer.OnSaveCacheFinished() return err } + c.asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "UpdateSubTree", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + c.core.CheckAndPutSubTree(ctx, region) + }, + ) + tracer.OnUpdateSubTreeFinished() if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { - cluster.HandleOverlaps(c, overlaps) + // async handle overlaps + c.asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "HandleOverlaps", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + cluster.HandleOverlapsByCheck(c, region) + }, + ) } regionUpdateCacheEventCounter.Inc() } tracer.OnSaveCacheFinished() - // TODO: Due to the accuracy requirements of the API "/regions/check/xxx", - // region stats needs to be collected in API mode. - // We need to think of a better way to reduce this part of the cost in the future. - cluster.Collect(c, region, c.GetRegionStores(region), hasRegionStats) + // async handle region stats + + c.asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "CollectRegionStatsAsync", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + // TODO: Due to the accuracy requirements of the API "/regions/check/xxx", + // region stats needs to be collected in API mode. + // We need to think of a better way to reduce this part of the cost in the future. + cluster.Collect(c, region, hasRegionStats) + }, + ) + tracer.OnCollectRegionStatsFinished() if c.storage != nil { - // If there are concurrent heartbeats from the same region, the last write will win even if - // writes to storage in the critical area. So don't use mutex to protect it. - // Not successfully saved to storage is not fatal, it only leads to longer warm-up - // after restart. Here we only log the error then go on updating cache. 
- for _, item := range overlaps { - if err := c.storage.DeleteRegion(item.GetMeta()); err != nil { - log.Error("failed to delete region from storage", - zap.Uint64("region-id", item.GetID()), - logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(item.GetMeta())), - errs.ZapError(err)) - } - } if saveKV { - if err := c.storage.SaveRegion(region.GetMeta()); err != nil { - log.Error("failed to save region to storage", - zap.Uint64("region-id", region.GetID()), - logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(region.GetMeta())), - errs.ZapError(err)) - } - regionUpdateKVEventCounter.Inc() + // async save KV + c.asyncRunner.RunAsyncTask( + ctx, + ratelimit.TaskOpts{ + TaskName: "SaveRegionToKV", + Limit: c.hbConcurrencyLimiter, + }, + func(ctx context.Context) { + // If there are concurrent heartbeats from the same region, the last write will win even if + // writes to storage in the critical area. So don't use mutex to protect it. + // Not successfully saved to storage is not fatal, it only leads to longer warm-up + // after restart. Here we only log the error then go on updating cache. 
+ for _, item := range overlaps { + if err := c.storage.DeleteRegion(item.GetMeta()); err != nil { + log.Error("failed to delete region from storage", + zap.Uint64("region-id", item.GetID()), + logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(item.GetMeta())), + errs.ZapError(err)) + } + } + if err := c.storage.SaveRegion(region.GetMeta()); err != nil { + log.Error("failed to save region to storage", + zap.Uint64("region-id", region.GetID()), + logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(region.GetMeta())), + errs.ZapError(err)) + } + regionUpdateKVEventCounter.Inc() + }, + ) } } @@ -2067,16 +2139,6 @@ func (c *RaftCluster) resetProgressIndicator() { storesETAGauge.Reset() } -func (c *RaftCluster) getRegionStoresLocked(region *core.RegionInfo) []*core.StoreInfo { - stores := make([]*core.StoreInfo, 0, len(region.GetPeers())) - for _, p := range region.GetPeers() { - if store := c.core.GetStore(p.StoreId); store != nil { - stores = append(stores, store) - } - } - return stores -} - // OnStoreVersionChange changes the version of the cluster when needed. func (c *RaftCluster) OnStoreVersionChange() { c.RLock() diff --git a/server/cluster/cluster_worker.go b/server/cluster/cluster_worker.go index 5ae8fdc0396f..8693abc5b67b 100644 --- a/server/cluster/cluster_worker.go +++ b/server/cluster/cluster_worker.go @@ -16,6 +16,7 @@ package cluster import ( "bytes" + "context" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" @@ -39,7 +40,11 @@ func (c *RaftCluster) HandleRegionHeartbeat(region *core.RegionInfo) error { tracer = core.NewHeartbeatProcessTracer() } tracer.Begin() - if err := c.processRegionHeartbeat(region, tracer); err != nil { + ctx := context.WithValue(c.ctx, "tracer", tracer) + ctx = context.WithValue(ctx, "limiter", c.hbConcurrencyLimiter) + ctx = context.WithValue(ctx, "asyncRunner", c.asyncRunner) + + if err := c.processRegionHeartbeat(ctx, region); err != nil { tracer.OnAllStageFinished() return err }