From 949d602b9cc334374fe979ac593b9a9553fc864e Mon Sep 17 00:00:00 2001 From: ShuNing Date: Mon, 8 Apr 2024 17:28:50 +0800 Subject: [PATCH] pkg/schedule: optimize the lock usage of operator controller (#8032) ref tikv/pd#7897 pkg/schedule: optimize the lock usage of the operator controller - use sync.Map for operators, which is friendly for check operators on the heartbeat hot path - reduce the lock hold time by splitting the locks. Signed-off-by: nolouch Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/schedule/operator/operator_controller.go | 255 +++++++++--------- .../operator/operator_controller_test.go | 35 ++- pkg/schedule/operator/operator_queue.go | 33 +++ pkg/schedule/operator/waiting_operator.go | 33 ++- 4 files changed, 213 insertions(+), 143 deletions(-) diff --git a/pkg/schedule/operator/operator_controller.go b/pkg/schedule/operator/operator_controller.go index f5e86f812c9..f05c232904f 100644 --- a/pkg/schedule/operator/operator_controller.go +++ b/pkg/schedule/operator/operator_controller.go @@ -15,10 +15,10 @@ package operator import ( - "container/heap" "context" "fmt" "strconv" + "sync" "time" "github.com/pingcap/failpoint" @@ -52,20 +52,51 @@ var ( FastOperatorFinishTime = 10 * time.Second ) +type opCounter struct { + syncutil.RWMutex + count map[OpKind]uint64 +} + +func (c *opCounter) inc(kind OpKind) { + c.Lock() + defer c.Unlock() + c.count[kind]++ +} + +func (c *opCounter) dec(kind OpKind) { + c.Lock() + defer c.Unlock() + if c.count[kind] > 0 { + c.count[kind]-- + } +} + +func (c *opCounter) getCountByKind(kind OpKind) uint64 { + c.RLock() + defer c.RUnlock() + return c.count[kind] +} + // Controller is used to limit the speed of scheduling. type Controller struct { - syncutil.RWMutex - ctx context.Context - config config.SharedConfigProvider - cluster *core.BasicCluster - operators map[uint64]*Operator - hbStreams *hbstream.HeartbeatStreams - fastOperators *cache.TTLUint64 - counts map[OpKind]uint64 - records *records - wop WaitingOperator - wopStatus *waitingOperatorStatus - opNotifierQueue operatorQueue + operators sync.Map + ctx context.Context + config config.SharedConfigProvider + cluster *core.BasicCluster + hbStreams *hbstream.HeartbeatStreams + + // fast path, TTLUint64 is safe for concurrent. + fastOperators *cache.TTLUint64 + + // opNotifierQueue is a priority queue to notify the operator to be checked. + // safe for concurrent. + opNotifierQueue *concurrentHeapOpQueue + + // states + records *records // safe for concurrent + wop WaitingOperator + wopStatus *waitingOperatorStatus + counts *opCounter } // NewController creates a Controller. @@ -74,14 +105,14 @@ func NewController(ctx context.Context, cluster *core.BasicCluster, config confi ctx: ctx, cluster: cluster, config: config, - operators: make(map[uint64]*Operator), hbStreams: hbStreams, fastOperators: cache.NewIDTTL(ctx, time.Minute, FastOperatorFinishTime), - counts: make(map[OpKind]uint64), - records: newRecords(ctx), - wop: newRandBuckets(), - wopStatus: newWaitingOperatorStatus(), - opNotifierQueue: make(operatorQueue, 0), + opNotifierQueue: newConcurrentHeapOpQueue(), + // states + records: newRecords(ctx), + wop: newRandBuckets(), + wopStatus: newWaitingOperatorStatus(), + counts: &opCounter{count: make(map[OpKind]uint64)}, } } @@ -93,8 +124,6 @@ func (oc *Controller) Ctx() context.Context { // GetCluster exports basic cluster to evict-scheduler for check store status. func (oc *Controller) GetCluster() *core.BasicCluster { - oc.RLock() - defer oc.RUnlock() return oc.cluster } @@ -206,22 +235,21 @@ func (oc *Controller) getNextPushOperatorTime(step OpStep, now time.Time) time.T // "next" is true to indicate that it may exist in next attempt, // and false is the end for the poll. func (oc *Controller) pollNeedDispatchRegion() (r *core.RegionInfo, next bool) { - oc.Lock() - defer oc.Unlock() if oc.opNotifierQueue.Len() == 0 { return nil, false } - item := heap.Pop(&oc.opNotifierQueue).(*operatorWithTime) + item, _ := oc.opNotifierQueue.Pop() regionID := item.op.RegionID() - op, ok := oc.operators[regionID] - if !ok || op == nil { + opi, ok := oc.operators.Load(regionID) + if !ok || opi.(*Operator) == nil { return nil, true } + op := opi.(*Operator) // Check the operator lightly. It cant't dispatch the op for some scenario. var reason CancelReasonType r, reason = oc.checkOperatorLightly(op) if len(reason) != 0 { - _ = oc.removeOperatorLocked(op) + _ = oc.removeOperatorInner(op) if op.Cancel(reason) { log.Warn("remove operator because region disappeared", zap.Uint64("region-id", op.RegionID()), @@ -237,13 +265,13 @@ func (oc *Controller) pollNeedDispatchRegion() (r *core.RegionInfo, next bool) { } now := time.Now() if now.Before(item.time) { - heap.Push(&oc.opNotifierQueue, item) + oc.opNotifierQueue.Push(item) return nil, false } // pushes with new notify time. item.time = oc.getNextPushOperatorTime(step, now) - heap.Push(&oc.opNotifierQueue, item) + oc.opNotifierQueue.Push(item) return r, true } @@ -264,7 +292,6 @@ func (oc *Controller) PushOperators(recordOpStepWithTTL func(regionID uint64)) { // AddWaitingOperator adds operators to waiting operators. func (oc *Controller) AddWaitingOperator(ops ...*Operator) int { - oc.Lock() added := 0 needPromoted := 0 @@ -276,13 +303,11 @@ func (oc *Controller) AddWaitingOperator(ops ...*Operator) int { if i+1 >= len(ops) { // should not be here forever log.Error("orphan merge operators found", zap.String("desc", desc), errs.ZapError(errs.ErrMergeOperator.FastGenByArgs("orphan operator found"))) - oc.Unlock() return added } if ops[i+1].Kind()&OpMerge == 0 { log.Error("merge operator should be paired", zap.String("desc", ops[i+1].Desc()), errs.ZapError(errs.ErrMergeOperator.FastGenByArgs("operator should be paired"))) - oc.Unlock() return added } isMerge = true @@ -309,12 +334,10 @@ func (oc *Controller) AddWaitingOperator(ops ...*Operator) int { oc.wop.PutOperator(ops[i]) } operatorCounter.WithLabelValues(desc, "put").Inc() - oc.wopStatus.ops[desc]++ + oc.wopStatus.incCount(desc) added++ needPromoted++ } - - oc.Unlock() operatorCounter.WithLabelValues(ops[0].Desc(), "promote-add").Add(float64(needPromoted)) for i := 0; i < needPromoted; i++ { oc.PromoteWaitingOperator() @@ -324,13 +347,10 @@ func (oc *Controller) AddWaitingOperator(ops ...*Operator) int { // AddOperator adds operators to the running operators. func (oc *Controller) AddOperator(ops ...*Operator) bool { - oc.Lock() - defer oc.Unlock() - // note: checkAddOperator uses false param for `isPromoting`. // This is used to keep check logic before fixing issue #4946, // but maybe user want to add operator when waiting queue is busy - if oc.exceedStoreLimitLocked(ops...) { + if oc.ExceedStoreLimit(ops...) { for _, op := range ops { operatorCounter.WithLabelValues(op.Desc(), "exceed-limit").Inc() _ = op.Cancel(ExceedStoreLimit) @@ -346,7 +366,7 @@ func (oc *Controller) AddOperator(ops ...*Operator) bool { return false } for _, op := range ops { - if !oc.addOperatorLocked(op) { + if !oc.addOperatorInner(op) { return false } } @@ -355,23 +375,22 @@ func (oc *Controller) AddOperator(ops ...*Operator) bool { // PromoteWaitingOperator promotes operators from waiting operators. func (oc *Controller) PromoteWaitingOperator() { - oc.Lock() - defer oc.Unlock() var ops []*Operator for { // GetOperator returns one operator or two merge operators + // need write lock ops = oc.wop.GetOperator() if ops == nil { return } operatorCounter.WithLabelValues(ops[0].Desc(), "get").Inc() - if oc.exceedStoreLimitLocked(ops...) { + if oc.ExceedStoreLimit(ops...) { for _, op := range ops { operatorCounter.WithLabelValues(op.Desc(), "exceed-limit").Inc() _ = op.Cancel(ExceedStoreLimit) oc.buryOperator(op) } - oc.wopStatus.ops[ops[0].Desc()]-- + oc.wopStatus.decCount(ops[0].Desc()) continue } @@ -381,15 +400,15 @@ func (oc *Controller) PromoteWaitingOperator() { _ = op.Cancel(reason) oc.buryOperator(op) } - oc.wopStatus.ops[ops[0].Desc()]-- + oc.wopStatus.decCount(ops[0].Desc()) continue } - oc.wopStatus.ops[ops[0].Desc()]-- + oc.wopStatus.decCount(ops[0].Desc()) break } for _, op := range ops { - if !oc.addOperatorLocked(op) { + if !oc.addOperatorInner(op) { break } } @@ -420,7 +439,8 @@ func (oc *Controller) checkAddOperator(isPromoting bool, ops ...*Operator) (bool operatorCounter.WithLabelValues(op.Desc(), "epoch-not-match").Inc() return false, EpochNotMatch } - if old := oc.operators[op.RegionID()]; old != nil && !isHigherPriorityOperator(op, old) { + if oldi, ok := oc.operators.Load(op.RegionID()); ok && oldi.(*Operator) != nil && !isHigherPriorityOperator(op, oldi.(*Operator)) { + old := oldi.(*Operator) log.Debug("already have operator, cancel add operator", zap.Uint64("region-id", op.RegionID()), zap.Reflect("old", old)) @@ -438,7 +458,7 @@ func (oc *Controller) checkAddOperator(isPromoting bool, ops ...*Operator) (bool operatorCounter.WithLabelValues(op.Desc(), "unexpected-status").Inc() return false, NotInCreateStatus } - if !isPromoting && oc.wopStatus.ops[op.Desc()] >= oc.config.GetSchedulerMaxWaitingOperator() { + if !isPromoting && oc.wopStatus.getCount(op.Desc()) >= oc.config.GetSchedulerMaxWaitingOperator() { log.Debug("exceed max return false", zap.Uint64("waiting", oc.wopStatus.ops[op.Desc()]), zap.String("desc", op.Desc()), zap.Uint64("max", oc.config.GetSchedulerMaxWaitingOperator())) operatorCounter.WithLabelValues(op.Desc(), "exceed-max-waiting").Inc() return false, ExceedWaitLimit @@ -483,7 +503,7 @@ func isHigherPriorityOperator(new, old *Operator) bool { return new.GetPriorityLevel() > old.GetPriorityLevel() } -func (oc *Controller) addOperatorLocked(op *Operator) bool { +func (oc *Controller) addOperatorInner(op *Operator) bool { regionID := op.RegionID() log.Info("add operator", zap.Uint64("region-id", regionID), @@ -492,8 +512,9 @@ func (oc *Controller) addOperatorLocked(op *Operator) bool { // If there is an old operator, replace it. The priority should be checked // already. - if old, ok := oc.operators[regionID]; ok { - _ = oc.removeOperatorLocked(old) + if oldi, ok := oc.operators.Load(regionID); ok { + old := oldi.(*Operator) + _ = oc.removeOperatorInner(old) _ = old.Replace() oc.buryOperator(old) } @@ -509,8 +530,8 @@ func (oc *Controller) addOperatorLocked(op *Operator) bool { operatorCounter.WithLabelValues(op.Desc(), "unexpected").Inc() return false } - oc.operators[regionID] = op - oc.counts[op.SchedulerKind()]++ + oc.operators.Store(regionID, op) + oc.counts.inc(op.SchedulerKind()) operatorCounter.WithLabelValues(op.Desc(), "start").Inc() operatorSizeHist.WithLabelValues(op.Desc()).Observe(float64(op.ApproximateSize)) opInfluence := NewTotalOpInfluence([]*Operator{op}, oc.cluster) @@ -538,7 +559,7 @@ func (oc *Controller) addOperatorLocked(op *Operator) bool { } } - heap.Push(&oc.opNotifierQueue, &operatorWithTime{op: op, time: oc.getNextPushOperatorTime(step, time.Now())}) + oc.opNotifierQueue.Push(&operatorWithTime{op: op, time: oc.getNextPushOperatorTime(step, time.Now())}) operatorCounter.WithLabelValues(op.Desc(), "create").Inc() for _, counter := range op.Counters { counter.Inc() @@ -562,9 +583,7 @@ func (oc *Controller) ack(op *Operator) { // RemoveOperators removes all operators from the running operators. func (oc *Controller) RemoveOperators(reasons ...CancelReasonType) { - oc.Lock() - removed := oc.removeOperatorsLocked() - oc.Unlock() + removed := oc.removeOperatorsInner() var cancelReason CancelReasonType if len(reasons) > 0 { cancelReason = reasons[0] @@ -580,26 +599,26 @@ func (oc *Controller) RemoveOperators(reasons ...CancelReasonType) { } } -func (oc *Controller) removeOperatorsLocked() []*Operator { +func (oc *Controller) removeOperatorsInner() []*Operator { var removed []*Operator - for regionID, op := range oc.operators { - delete(oc.operators, regionID) - oc.counts[op.SchedulerKind()]-- + oc.operators.Range(func(regionID, value any) bool { + op := value.(*Operator) + oc.operators.Delete(regionID) + oc.counts.dec(op.SchedulerKind()) operatorCounter.WithLabelValues(op.Desc(), "remove").Inc() oc.ack(op) if op.Kind()&OpMerge != 0 { oc.removeRelatedMergeOperator(op) } removed = append(removed, op) - } + return true + }) return removed } // RemoveOperator removes an operator from the running operators. func (oc *Controller) RemoveOperator(op *Operator, reasons ...CancelReasonType) bool { - oc.Lock() - removed := oc.removeOperatorLocked(op) - oc.Unlock() + removed := oc.removeOperatorInner(op) var cancelReason CancelReasonType if len(reasons) > 0 { cancelReason = reasons[0] @@ -617,16 +636,14 @@ func (oc *Controller) RemoveOperator(op *Operator, reasons ...CancelReasonType) } func (oc *Controller) removeOperatorWithoutBury(op *Operator) bool { - oc.Lock() - defer oc.Unlock() - return oc.removeOperatorLocked(op) + return oc.removeOperatorInner(op) } -func (oc *Controller) removeOperatorLocked(op *Operator) bool { +func (oc *Controller) removeOperatorInner(op *Operator) bool { regionID := op.RegionID() - if cur := oc.operators[regionID]; cur == op { - delete(oc.operators, regionID) - oc.counts[op.SchedulerKind()]-- + if cur, ok := oc.operators.Load(regionID); ok && cur.(*Operator) == op { + oc.operators.Delete(regionID) + oc.counts.dec(op.SchedulerKind()) operatorCounter.WithLabelValues(op.Desc(), "remove").Inc() oc.ack(op) if op.Kind()&OpMerge != 0 { @@ -639,12 +656,17 @@ func (oc *Controller) removeOperatorLocked(op *Operator) bool { func (oc *Controller) removeRelatedMergeOperator(op *Operator) { relatedID, _ := strconv.ParseUint(op.AdditionalInfos[string(RelatedMergeRegion)], 10, 64) - if relatedOp := oc.operators[relatedID]; relatedOp != nil && relatedOp.Status() != CANCELED { + relatedOpi, ok := oc.operators.Load(relatedID) + if !ok { + return + } + relatedOp := relatedOpi.(*Operator) + if relatedOp != nil && relatedOp.Status() != CANCELED { log.Info("operator canceled related merge region", zap.Uint64("region-id", relatedOp.RegionID()), zap.String("additional-info", relatedOp.GetAdditionalInfo()), zap.Duration("takes", relatedOp.RunningTime())) - oc.removeOperatorLocked(relatedOp) + oc.removeOperatorInner(relatedOp) relatedOp.Cancel(RelatedMergeRegion) oc.buryOperator(relatedOp) } @@ -712,9 +734,8 @@ func (oc *Controller) buryOperator(op *Operator) { // GetOperatorStatus gets the operator and its status with the specify id. func (oc *Controller) GetOperatorStatus(id uint64) *OpWithStatus { - oc.Lock() - defer oc.Unlock() - if op, ok := oc.operators[id]; ok { + if opi, ok := oc.operators.Load(id); ok && opi.(*Operator) != nil { + op := opi.(*Operator) return NewOpWithStatus(op) } return oc.records.Get(id) @@ -722,43 +743,39 @@ func (oc *Controller) GetOperatorStatus(id uint64) *OpWithStatus { // GetOperator gets an operator from the given region. func (oc *Controller) GetOperator(regionID uint64) *Operator { - oc.RLock() - defer oc.RUnlock() - return oc.operators[regionID] + if v, ok := oc.operators.Load(regionID); ok { + return v.(*Operator) + } + return nil } // GetOperators gets operators from the running operators. func (oc *Controller) GetOperators() []*Operator { - oc.RLock() - defer oc.RUnlock() - - operators := make([]*Operator, 0, len(oc.operators)) - for _, op := range oc.operators { - operators = append(operators, op) - } - + operators := make([]*Operator, 0, oc.opNotifierQueue.Len()) + oc.operators.Range( + func(_, value any) bool { + operators = append(operators, value.(*Operator)) + return true + }) return operators } // GetWaitingOperators gets operators from the waiting operators. func (oc *Controller) GetWaitingOperators() []*Operator { - oc.RLock() - defer oc.RUnlock() return oc.wop.ListOperator() } // GetOperatorsOfKind returns the running operators of the kind. func (oc *Controller) GetOperatorsOfKind(mask OpKind) []*Operator { - oc.RLock() - defer oc.RUnlock() - - operators := make([]*Operator, 0, len(oc.operators)) - for _, op := range oc.operators { - if op.Kind()&mask != 0 { - operators = append(operators, op) - } - } - + operators := make([]*Operator, 0, oc.opNotifierQueue.Len()) + oc.operators.Range( + func(_, value any) bool { + op := value.(*Operator) + if op.Kind()&mask != 0 { + operators = append(operators, value.(*Operator)) + } + return true + }) return operators } @@ -810,9 +827,7 @@ func (oc *Controller) GetHistory(start time.Time) []OpHistory { // OperatorCount gets the count of operators filtered by kind. // kind only has one OpKind. func (oc *Controller) OperatorCount(kind OpKind) uint64 { - oc.RLock() - defer oc.RUnlock() - return oc.counts[kind] + return oc.counts.getCountByKind(kind) } // GetOpInfluence gets OpInfluence. @@ -820,16 +835,17 @@ func (oc *Controller) GetOpInfluence(cluster *core.BasicCluster) OpInfluence { influence := OpInfluence{ StoresInfluence: make(map[uint64]*StoreInfluence), } - oc.RLock() - defer oc.RUnlock() - for _, op := range oc.operators { - if !op.CheckTimeout() && !op.CheckSuccess() { - region := cluster.GetRegion(op.RegionID()) - if region != nil { - op.UnfinishedInfluence(influence, region) + oc.operators.Range( + func(_, value any) bool { + op := value.(*Operator) + if !op.CheckTimeout() && !op.CheckSuccess() { + region := cluster.GetRegion(op.RegionID()) + if region != nil { + op.UnfinishedInfluence(influence, region) + } } - } - } + return true + }) return influence } @@ -873,10 +889,8 @@ func NewTotalOpInfluence(operators []*Operator, cluster *core.BasicCluster) OpIn // SetOperator is only used for test. func (oc *Controller) SetOperator(op *Operator) { - oc.Lock() - defer oc.Unlock() - oc.operators[op.RegionID()] = op - oc.counts[op.SchedulerKind()]++ + oc.operators.Store(op.RegionID(), op) + oc.counts.inc(op.SchedulerKind()) } // OpWithStatus records the operator and its status. @@ -932,13 +946,6 @@ func (o *records) Put(op *Operator) { // ExceedStoreLimit returns true if the store exceeds the cost limit after adding the Otherwise, returns false. func (oc *Controller) ExceedStoreLimit(ops ...*Operator) bool { - oc.Lock() - defer oc.Unlock() - return oc.exceedStoreLimitLocked(ops...) -} - -// exceedStoreLimitLocked returns true if the store exceeds the cost limit after adding the Otherwise, returns false. -func (oc *Controller) exceedStoreLimitLocked(ops ...*Operator) bool { // The operator with Urgent priority, like admin operators, should ignore the store limit check. var desc string if len(ops) != 0 { diff --git a/pkg/schedule/operator/operator_controller_test.go b/pkg/schedule/operator/operator_controller_test.go index 643dbda9d73..f2f2b7305ce 100644 --- a/pkg/schedule/operator/operator_controller_test.go +++ b/pkg/schedule/operator/operator_controller_test.go @@ -15,7 +15,6 @@ package operator import ( - "container/heap" "context" "encoding/hex" "fmt" @@ -365,10 +364,10 @@ func (suite *operatorControllerTestSuite) TestPollDispatchRegion() { oc.SetOperator(op4) re.True(op2.Start()) oc.SetOperator(op2) - heap.Push(&oc.opNotifierQueue, &operatorWithTime{op: op1, time: time.Now().Add(100 * time.Millisecond)}) - heap.Push(&oc.opNotifierQueue, &operatorWithTime{op: op3, time: time.Now().Add(300 * time.Millisecond)}) - heap.Push(&oc.opNotifierQueue, &operatorWithTime{op: op4, time: time.Now().Add(499 * time.Millisecond)}) - heap.Push(&oc.opNotifierQueue, &operatorWithTime{op: op2, time: time.Now().Add(500 * time.Millisecond)}) + oc.opNotifierQueue.Push(&operatorWithTime{op: op1, time: time.Now().Add(100 * time.Millisecond)}) + oc.opNotifierQueue.Push(&operatorWithTime{op: op3, time: time.Now().Add(300 * time.Millisecond)}) + oc.opNotifierQueue.Push(&operatorWithTime{op: op4, time: time.Now().Add(499 * time.Millisecond)}) + oc.opNotifierQueue.Push(&operatorWithTime{op: op2, time: time.Now().Add(500 * time.Millisecond)}) } // first poll got nil r, next := oc.pollNeedDispatchRegion() @@ -430,7 +429,7 @@ func (suite *operatorControllerTestSuite) TestPollDispatchRegionForMergeRegion() re.Len(ops, 2) re.Equal(2, controller.AddWaitingOperator(ops...)) // Change next push time to now, it's used to make test case faster. - controller.opNotifierQueue[0].time = time.Now() + controller.opNotifierQueue.heap[0].time = time.Now() // first poll gets source region op. r, next := controller.pollNeedDispatchRegion() @@ -438,7 +437,7 @@ func (suite *operatorControllerTestSuite) TestPollDispatchRegionForMergeRegion() re.Equal(r, source) // second poll gets target region op. - controller.opNotifierQueue[0].time = time.Now() + controller.opNotifierQueue.heap[0].time = time.Now() r, next = controller.pollNeedDispatchRegion() re.True(next) re.Equal(r, target) @@ -448,18 +447,18 @@ func (suite *operatorControllerTestSuite) TestPollDispatchRegionForMergeRegion() r, next = controller.pollNeedDispatchRegion() re.True(next) re.Nil(r) - re.Len(controller.opNotifierQueue, 1) - re.Empty(controller.operators) + re.Equal(1, controller.opNotifierQueue.Len()) + re.Empty(controller.GetOperators()) re.Empty(controller.wop.ListOperator()) re.NotNil(controller.records.Get(101)) re.NotNil(controller.records.Get(102)) // fourth poll removes target region op from opNotifierQueue - controller.opNotifierQueue[0].time = time.Now() + controller.opNotifierQueue.heap[0].time = time.Now() r, next = controller.pollNeedDispatchRegion() re.True(next) re.Nil(r) - re.Empty(controller.opNotifierQueue) + re.Equal(0, controller.opNotifierQueue.Len()) // Add the two ops to waiting operators again. source.GetMeta().RegionEpoch = &metapb.RegionEpoch{ConfVer: 0, Version: 0} @@ -471,7 +470,7 @@ func (suite *operatorControllerTestSuite) TestPollDispatchRegionForMergeRegion() // change the target RegionEpoch // first poll gets source region from opNotifierQueue target.GetMeta().RegionEpoch = &metapb.RegionEpoch{ConfVer: 0, Version: 1} - controller.opNotifierQueue[0].time = time.Now() + controller.opNotifierQueue.heap[0].time = time.Now() r, next = controller.pollNeedDispatchRegion() re.True(next) re.Equal(r, source) @@ -479,17 +478,17 @@ func (suite *operatorControllerTestSuite) TestPollDispatchRegionForMergeRegion() r, next = controller.pollNeedDispatchRegion() re.True(next) re.Nil(r) - re.Len(controller.opNotifierQueue, 1) - re.Empty(controller.operators) + re.Equal(1, controller.opNotifierQueue.Len()) + re.Empty(controller.GetOperators()) re.Empty(controller.wop.ListOperator()) re.NotNil(controller.records.Get(101)) re.NotNil(controller.records.Get(102)) - controller.opNotifierQueue[0].time = time.Now() + controller.opNotifierQueue.heap[0].time = time.Now() r, next = controller.pollNeedDispatchRegion() re.True(next) re.Nil(r) - re.Empty(controller.opNotifierQueue) + re.Equal(0, controller.opNotifierQueue.Len()) } func (suite *operatorControllerTestSuite) TestCheckOperatorLightly() { @@ -911,7 +910,7 @@ func (suite *operatorControllerTestSuite) TestAddWaitingOperator() { batch = append(batch, addPeerOp(100)) added = controller.AddWaitingOperator(batch...) re.Equal(1, added) - re.NotNil(controller.operators[uint64(100)]) + re.NotNil(controller.GetOperator(uint64(100))) source := newRegionInfo(101, "1a", "1b", 1, 1, []uint64{101, 1}, []uint64{101, 1}) cluster.PutRegion(source) @@ -952,7 +951,7 @@ func (suite *operatorControllerTestSuite) TestInvalidStoreId() { RemovePeer{FromStore: 3, PeerID: 3, IsDownStore: false}, } op := NewTestOperator(1, &metapb.RegionEpoch{}, OpRegion, steps...) - re.True(oc.addOperatorLocked(op)) + re.True(oc.AddOperator(op)) // Although store 3 does not exist in PD, PD can also send op to TiKV. re.Equal(pdpb.OperatorStatus_RUNNING, oc.GetOperatorStatus(1).Status) } diff --git a/pkg/schedule/operator/operator_queue.go b/pkg/schedule/operator/operator_queue.go index 0e7f34ecc51..2233845724e 100644 --- a/pkg/schedule/operator/operator_queue.go +++ b/pkg/schedule/operator/operator_queue.go @@ -15,6 +15,8 @@ package operator import ( + "container/heap" + "sync" "time" ) @@ -50,3 +52,34 @@ func (opn *operatorQueue) Pop() any { *opn = old[0 : n-1] return item } + +type concurrentHeapOpQueue struct { + sync.Mutex + heap operatorQueue +} + +func newConcurrentHeapOpQueue() *concurrentHeapOpQueue { + return &concurrentHeapOpQueue{heap: make(operatorQueue, 0)} +} + +func (ch *concurrentHeapOpQueue) Len() int { + ch.Lock() + defer ch.Unlock() + return len(ch.heap) +} + +func (ch *concurrentHeapOpQueue) Push(x *operatorWithTime) { + ch.Lock() + defer ch.Unlock() + heap.Push(&ch.heap, x) +} + +func (ch *concurrentHeapOpQueue) Pop() (*operatorWithTime, bool) { + ch.Lock() + defer ch.Unlock() + if len(ch.heap) == 0 { + return nil, false + } + x := heap.Pop(&ch.heap).(*operatorWithTime) + return x, true +} diff --git a/pkg/schedule/operator/waiting_operator.go b/pkg/schedule/operator/waiting_operator.go index 8f5c72b053b..b3b1b885663 100644 --- a/pkg/schedule/operator/waiting_operator.go +++ b/pkg/schedule/operator/waiting_operator.go @@ -16,6 +16,8 @@ package operator import ( "math/rand" + + "github.com/tikv/pd/pkg/utils/syncutil" ) // priorityWeight is used to represent the weight of different priorities of operators. @@ -36,6 +38,7 @@ type bucket struct { // randBuckets is an implementation of waiting operators type randBuckets struct { + mu syncutil.Mutex totalWeight float64 buckets []*bucket } @@ -53,6 +56,8 @@ func newRandBuckets() *randBuckets { // PutOperator puts an operator into the random buckets. func (b *randBuckets) PutOperator(op *Operator) { + b.mu.Lock() + defer b.mu.Unlock() priority := op.GetPriorityLevel() bucket := b.buckets[priority] if len(bucket.ops) == 0 { @@ -63,6 +68,8 @@ func (b *randBuckets) PutOperator(op *Operator) { // ListOperator lists all operator in the random buckets. func (b *randBuckets) ListOperator() []*Operator { + b.mu.Lock() + defer b.mu.Unlock() var ops []*Operator for i := range b.buckets { bucket := b.buckets[i] @@ -73,6 +80,8 @@ func (b *randBuckets) ListOperator() []*Operator { // GetOperator gets an operator from the random buckets. func (b *randBuckets) GetOperator() []*Operator { + b.mu.Lock() + defer b.mu.Unlock() if b.totalWeight == 0 { return nil } @@ -106,12 +115,34 @@ func (b *randBuckets) GetOperator() []*Operator { // waitingOperatorStatus is used to limit the count of each kind of operators. type waitingOperatorStatus struct { + mu syncutil.Mutex ops map[string]uint64 } // newWaitingOperatorStatus creates a new waitingOperatorStatus. func newWaitingOperatorStatus() *waitingOperatorStatus { return &waitingOperatorStatus{ - make(map[string]uint64), + ops: make(map[string]uint64), } } + +// incCount increments the count of the given operator kind. +func (s *waitingOperatorStatus) incCount(kind string) { + s.mu.Lock() + defer s.mu.Unlock() + s.ops[kind]++ +} + +// decCount decrements the count of the given operator kind. +func (s *waitingOperatorStatus) decCount(kind string) { + s.mu.Lock() + defer s.mu.Unlock() + s.ops[kind]-- +} + +// getCount returns the count of the given operator kind. +func (s *waitingOperatorStatus) getCount(kind string) uint64 { + s.mu.Lock() + defer s.mu.Unlock() + return s.ops[kind] +}