diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go
index a8ced3c0781..67177a0f8e7 100644
--- a/server/cluster/cluster.go
+++ b/server/cluster/cluster.go
@@ -881,7 +881,7 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error {
 	// To prevent a concurrent heartbeat of another region from overriding the up-to-date region info by a stale one,
 	// check its validation again here.
 	//
-	// However it can't solve the race condition of concurrent heartbeats from the same region.
+	// However, it can't solve the race condition of concurrent heartbeats from the same region.
 	if overlaps, err = c.core.AtomicCheckAndPutRegion(region); err != nil {
 		return err
 	}
@@ -899,10 +899,6 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error {
 		c.regionStats.Observe(region, c.getRegionStoresLocked(region))
 	}
 
-	if !c.IsPrepared() && isNew {
-		c.coordinator.prepareChecker.collect(region)
-	}
-
 	if c.storage != nil {
 		// If there are concurrent heartbeats from the same region, the last write will win even if
 		// writes to storage in the critical area. So don't use mutex to protect it.
diff --git a/server/cluster/coordinator_test.go b/server/cluster/coordinator_test.go
index 93e02b4324c..f15799498b3 100644
--- a/server/cluster/coordinator_test.go
+++ b/server/cluster/coordinator_test.go
@@ -147,7 +147,7 @@ func (c *testCluster) LoadRegion(regionID uint64, followerStoreIDs ...uint64) er
 		peer, _ := c.AllocPeer(id)
 		region.Peers = append(region.Peers, peer)
 	}
-	return c.putRegion(core.NewRegionInfo(region, nil))
+	return c.putRegion(core.NewRegionInfo(region, nil, core.SetSource(core.Storage)))
 }
 
 func TestBasic(t *testing.T) {
@@ -231,7 +231,7 @@ func TestDispatch(t *testing.T) {
 
 func dispatchHeartbeat(co *coordinator, region *core.RegionInfo, stream hbstream.HeartbeatStream) error {
 	co.hbStreams.BindStream(region.GetLeader().GetStoreId(), stream)
-	if err := co.cluster.putRegion(region.Clone()); err != nil {
+	if err := co.cluster.putRegion(region.Clone(core.SetSource(core.Heartbeat))); err != nil {
 		return err
 	}
 	co.opController.Dispatch(region, schedule.DispatchFromHeartBeat)
@@ -658,14 +658,14 @@ func TestShouldRun(t *testing.T) {
 
 	for _, testCase := range testCases {
 		r := tc.GetRegion(testCase.regionID)
-		nr := r.Clone(core.WithLeader(r.GetPeers()[0]))
+		nr := r.Clone(core.WithLeader(r.GetPeers()[0]), core.SetSource(core.Heartbeat))
 		re.NoError(tc.processRegionHeartbeat(nr))
 		re.Equal(testCase.shouldRun, co.shouldRun())
 	}
 	nr := &metapb.Region{Id: 6, Peers: []*metapb.Peer{}}
-	newRegion := core.NewRegionInfo(nr, nil)
+	newRegion := core.NewRegionInfo(nr, nil, core.SetSource(core.Heartbeat))
 	re.Error(tc.processRegionHeartbeat(newRegion))
-	re.Equal(7, co.prepareChecker.sum)
+	re.Equal(7, tc.core.GetClusterNotFromStorageRegionsCnt())
 }
 
 func TestShouldRunWithNonLeaderRegions(t *testing.T) {
@@ -701,14 +701,14 @@ func TestShouldRunWithNonLeaderRegions(t *testing.T) {
 
 	for _, testCase := range testCases {
 		r := tc.GetRegion(testCase.regionID)
-		nr := r.Clone(core.WithLeader(r.GetPeers()[0]))
+		nr := r.Clone(core.WithLeader(r.GetPeers()[0]), core.SetSource(core.Heartbeat))
 		re.NoError(tc.processRegionHeartbeat(nr))
 		re.Equal(testCase.shouldRun, co.shouldRun())
 	}
 	nr := &metapb.Region{Id: 9, Peers: []*metapb.Peer{}}
-	newRegion := core.NewRegionInfo(nr, nil)
+	newRegion := core.NewRegionInfo(nr, nil, core.SetSource(core.Heartbeat))
 	re.Error(tc.processRegionHeartbeat(newRegion))
-	re.Equal(9, co.prepareChecker.sum)
+	re.Equal(9, tc.core.GetClusterNotFromStorageRegionsCnt())
 
 	// Now, after server is prepared, there exist some regions with no leader.
 	re.Equal(uint64(0), tc.GetRegion(10).GetLeader().GetStoreId())
@@ -1003,7 +1003,6 @@ func TestRestart(t *testing.T) {
 	re.NoError(tc.addRegionStore(3, 3))
 	re.NoError(tc.addLeaderRegion(1, 1))
 	region := tc.GetRegion(1)
-	co.prepareChecker.collect(region)
 
 	// Add 1 replica on store 2.
 	stream := mockhbstream.NewHeartbeatStream()
@@ -1016,7 +1015,6 @@ func TestRestart(t *testing.T) {
 
 	// Recreate coordinator then add another replica on store 3.
 	co = newCoordinator(ctx, tc.RaftCluster, hbStreams)
-	co.prepareChecker.collect(region)
 	co.run()
 	re.NoError(dispatchHeartbeat(co, region, stream))
 	region = waitAddLearner(re, stream, region, 3)
diff --git a/server/cluster/prepare_checker.go b/server/cluster/prepare_checker.go
index 6d20503ef55..c330a58c94a 100644
--- a/server/cluster/prepare_checker.go
+++ b/server/cluster/prepare_checker.go
@@ -25,16 +25,13 @@ import (
 
 type prepareChecker struct {
 	syncutil.RWMutex
-	reactiveRegions map[uint64]int
-	start           time.Time
-	sum             int
-	prepared        bool
+	start    time.Time
+	prepared bool
 }
 
 func newPrepareChecker() *prepareChecker {
 	return &prepareChecker{
-		start:           time.Now(),
-		reactiveRegions: make(map[uint64]int),
+		start: time.Now(),
 	}
 }
 
@@ -51,14 +48,8 @@ func (checker *prepareChecker) check(c *core.BasicCluster) bool {
 	}
 	notLoadedFromRegionsCnt := c.GetClusterNotFromStorageRegionsCnt()
 	totalRegionsCnt := c.GetRegionCount()
-	if float64(notLoadedFromRegionsCnt) > float64(totalRegionsCnt)*collectFactor {
-		log.Info("meta not loaded from region number is satisfied, finish prepare checker",
-			zap.Int("not-from-storage-region", notLoadedFromRegionsCnt), zap.Int("total-region", totalRegionsCnt))
-		checker.prepared = true
-		return true
-	}
 	// The number of active regions should be more than total region of all stores * collectFactor
-	if float64(totalRegionsCnt)*collectFactor > float64(checker.sum) {
+	if float64(totalRegionsCnt)*collectFactor > float64(notLoadedFromRegionsCnt) {
 		return false
 	}
 	for _, store := range c.GetStores() {
@@ -67,23 +58,15 @@ func (checker *prepareChecker) check(c *core.BasicCluster) bool {
 		}
 		storeID := store.GetID()
 		// For each store, the number of active regions should be more than total region of the store * collectFactor
-		if float64(c.GetStoreRegionCount(storeID))*collectFactor > float64(checker.reactiveRegions[storeID]) {
+		if float64(c.GetStoreRegionCount(storeID))*collectFactor > float64(c.GetNotFromStorageRegionsCntByStore(storeID)) {
 			return false
 		}
 	}
+	log.Info("not loaded from storage region number is satisfied, finish prepare checker", zap.Int("not-from-storage-region", notLoadedFromRegionsCnt), zap.Int("total-region", totalRegionsCnt))
 	checker.prepared = true
 	return true
 }
 
-func (checker *prepareChecker) collect(region *core.RegionInfo) {
-	checker.Lock()
-	defer checker.Unlock()
-	for _, p := range region.GetPeers() {
-		checker.reactiveRegions[p.GetStoreId()]++
-	}
-	checker.sum++
-}
-
func (checker *prepareChecker) isPrepared() bool {
 	checker.RLock()
 	defer checker.RUnlock()
diff --git a/server/core/region.go b/server/core/region.go
index 391b65879d1..a198c3d9133 100644
--- a/server/core/region.go
+++ b/server/core/region.go
@@ -1187,11 +1187,23 @@ func (r *RegionsInfo) GetStoreWriteRate(storeID uint64) (bytesRate, keysRate flo
 	return
 }
 
-// GetClusterNotFromStorageRegionsCnt gets the total count of regions that not loaded from storage anymore
+// GetClusterNotFromStorageRegionsCnt gets the `NotFromStorageRegionsCnt` count of regions that not loaded from storage anymore.
 func (r *RegionsInfo) GetClusterNotFromStorageRegionsCnt() int {
 	r.t.RLock()
 	defer r.t.RUnlock()
-	return r.tree.notFromStorageRegionsCnt
+	return r.tree.notFromStorageRegionsCount()
+}
+
+// GetNotFromStorageRegionsCntByStore gets the `NotFromStorageRegionsCnt` count of a store's leader, follower and learner by storeID.
+func (r *RegionsInfo) GetNotFromStorageRegionsCntByStore(storeID uint64) int {
+	r.st.RLock()
+	defer r.st.RUnlock()
+	return r.getNotFromStorageRegionsCntByStoreLocked(storeID)
+}
+
+// getNotFromStorageRegionsCntByStoreLocked gets the `NotFromStorageRegionsCnt` count of a store's leader, follower and learner by storeID.
+func (r *RegionsInfo) getNotFromStorageRegionsCntByStoreLocked(storeID uint64) int {
+	return r.leaders[storeID].notFromStorageRegionsCount() + r.followers[storeID].notFromStorageRegionsCount() + r.learners[storeID].notFromStorageRegionsCount()
+}
 
 // GetMetaRegions gets a set of metapb.Region from regionMap
@@ -1227,7 +1239,7 @@ func (r *RegionsInfo) GetStoreRegionCount(storeID uint64) int {
 	return r.getStoreRegionCountLocked(storeID)
 }
 
-// GetStoreRegionCount gets the total count of a store's leader, follower and learner RegionInfo by storeID
+// getStoreRegionCountLocked gets the total count of a store's leader, follower and learner RegionInfo by storeID
 func (r *RegionsInfo) getStoreRegionCountLocked(storeID uint64) int {
 	return r.leaders[storeID].length() + r.followers[storeID].length() + r.learners[storeID].length()
 }
diff --git a/server/core/region_tree.go b/server/core/region_tree.go
index 5bf590740e0..cf2da1362ee 100644
--- a/server/core/region_tree.go
+++ b/server/core/region_tree.go
@@ -83,6 +83,13 @@ func (t *regionTree) length() int {
 	return t.tree.Len()
 }
 
+func (t *regionTree) notFromStorageRegionsCount() int {
+	if t == nil {
+		return 0
+	}
+	return t.notFromStorageRegionsCnt
+}
+
 // getOverlaps gets the regions which are overlapped with the specified region range.
 func (t *regionTree) getOverlaps(region *RegionInfo) []*RegionInfo {
 	item := &regionItem{RegionInfo: region}
diff --git a/tests/pdctl/helper.go b/tests/pdctl/helper.go
index 5340357528f..56262393514 100644
--- a/tests/pdctl/helper.go
+++ b/tests/pdctl/helper.go
@@ -116,6 +116,7 @@ func MustPutRegion(re *require.Assertions, cluster *tests.TestCluster, regionID,
 		Peers:       []*metapb.Peer{leader},
 		RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1},
 	}
+	opts = append(opts, core.SetSource(core.Heartbeat))
 	r := core.NewRegionInfo(metaRegion, leader, opts...)
 	err := cluster.HandleRegionHeartbeat(r)
 	re.NoError(err)
diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go
index 73bf2fb32d3..43640e7a103 100644
--- a/tests/server/cluster/cluster_test.go
+++ b/tests/server/cluster/cluster_test.go
@@ -1414,7 +1414,7 @@ func putRegionWithLeader(re *require.Assertions, rc *cluster.RaftCluster, id id.
 			StartKey: []byte{byte(i)},
 			EndKey:   []byte{byte(i + 1)},
 		}
-		rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0]))
+		rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0], core.SetSource(core.Heartbeat)))
 	}
 	time.Sleep(50 * time.Millisecond)
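
Note: the patch drops the prepareChecker's own heartbeat bookkeeping (collect / reactiveRegions / sum) and instead derives readiness from the notFromStorageRegionsCnt counters that the region tree already maintains per cluster and per store. The following standalone sketch shows the readiness rule the new check() reduces to; it is not part of the patch, the storeStats type is a hypothetical stand-in for what core.BasicCluster tracks, and collectFactor is assumed to be 0.8 as defined in the coordinator:

package main

import "fmt"

// collectFactor mirrors the fraction the real checker requires (assumed 0.8 here).
const collectFactor = 0.8

// storeStats is a hypothetical stand-in for the per-store counts kept by core.BasicCluster.
type storeStats struct {
	regionCount       int // total regions known on the store
	notFromStorageCnt int // regions whose latest info came from a heartbeat, not a storage reload
}

// isPrepared reproduces the shape of the new prepareChecker.check: the cluster is
// considered prepared only when heartbeat-sourced regions cover at least
// collectFactor of the total, both cluster-wide and on every individual store.
func isPrepared(totalRegions, notFromStorage int, stores []storeStats) bool {
	if float64(totalRegions)*collectFactor > float64(notFromStorage) {
		return false // too few regions have reported a heartbeat yet
	}
	for _, s := range stores {
		if float64(s.regionCount)*collectFactor > float64(s.notFromStorageCnt) {
			return false // this store's regions are still mostly storage-loaded
		}
	}
	return true
}

func main() {
	stores := []storeStats{
		{regionCount: 10, notFromStorageCnt: 9}, // 90% reported via heartbeat: fine
		{regionCount: 10, notFromStorageCnt: 7}, // 70% reported: below 0.8, blocks readiness
	}
	fmt.Println(isPrepared(20, 16, stores)) // false: the second store has not caught up
}

This is also why the tests above tag every test region with core.SetSource(core.Heartbeat) or core.SetSource(core.Storage): with collect() gone, the region's recorded source is the only thing that feeds these counters.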