From 90273ddef5dfb41304a7b9bc7a85bd65ad4ce838 Mon Sep 17 00:00:00 2001 From: Hu# Date: Tue, 9 Jul 2024 11:42:03 +0800 Subject: [PATCH] tools/simulator: resolve simulator race (#8376) ref tikv/pd#8135 Signed-off-by: husharp Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- Makefile | 2 +- pkg/core/region.go | 4 +++- tools/pd-simulator/simulator/client.go | 29 +++++++++++++------------- tools/pd-simulator/simulator/drive.go | 23 ++++++++++---------- tools/pd-simulator/simulator/event.go | 5 +++-- tools/pd-simulator/simulator/node.go | 26 ++++++++++++++--------- tools/pd-simulator/simulator/raft.go | 17 ++++++++++++--- tools/pd-simulator/simulator/task.go | 11 ++++++++-- 8 files changed, 71 insertions(+), 46 deletions(-) diff --git a/Makefile b/Makefile index f7bf8364552..34c3be775be 100644 --- a/Makefile +++ b/Makefile @@ -121,7 +121,7 @@ pd-analysis: pd-heartbeat-bench: cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-heartbeat-bench pd-heartbeat-bench/main.go simulator: - cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-simulator pd-simulator/main.go + cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_CGO_ENABLED) go build $(BUILD_FLAGS) -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-simulator pd-simulator/main.go regions-dump: cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/regions-dump regions-dump/main.go stores-dump: diff --git a/pkg/core/region.go b/pkg/core/region.go index 73f2fdd62e7..f0c78f443bd 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -2211,8 +2211,10 @@ func NewTestRegionInfo(regionID, storeID uint64, start, end []byte, opts ...Regi } // TraverseRegions executes a function on all regions. -// ONLY for simulator now and function need to be self-locked. +// ONLY for simulator now and only for READ. func (r *RegionsInfo) TraverseRegions(lockedFunc func(*RegionInfo)) { + r.t.RLock() + defer r.t.RUnlock() for _, item := range r.regions { lockedFunc(item.RegionInfo) } diff --git a/tools/pd-simulator/simulator/client.go b/tools/pd-simulator/simulator/client.go index 77166f38674..8acf3ccd9ab 100644 --- a/tools/pd-simulator/simulator/client.go +++ b/tools/pd-simulator/simulator/client.go @@ -20,6 +20,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/pingcap/errors" @@ -50,7 +51,7 @@ type Client interface { } const ( - pdTimeout = time.Second + pdTimeout = 3 * time.Second maxInitClusterRetries = 100 // retry to get leader URL leaderChangedWaitTime = 100 * time.Millisecond @@ -62,13 +63,13 @@ var ( errFailInitClusterID = errors.New("[pd] failed to get cluster id") PDHTTPClient pdHttp.Client SD pd.ServiceDiscovery - ClusterID uint64 + ClusterID atomic.Uint64 ) // requestHeader returns a header for fixed ClusterID. 
func requestHeader() *pdpb.RequestHeader { return &pdpb.RequestHeader{ - ClusterId: ClusterID, + ClusterId: ClusterID.Load(), } } @@ -205,12 +206,11 @@ func (c *client) reportRegionHeartbeat(ctx context.Context, stream pdpb.PD_Regio defer wg.Done() for { select { - case r := <-c.reportRegionHeartbeatCh: - if r == nil { + case region := <-c.reportRegionHeartbeatCh: + if region == nil { simutil.Logger.Error("report nil regionHeartbeat error", zap.String("tag", c.tag), zap.Error(errors.New("nil region"))) } - region := r.Clone() request := &pdpb.RegionHeartbeatRequest{ Header: requestHeader(), Region: region.GetMeta(), @@ -281,9 +281,8 @@ func (c *client) PutStore(ctx context.Context, store *metapb.Store) error { return nil } -func (c *client) StoreHeartbeat(ctx context.Context, stats *pdpb.StoreStats) error { +func (c *client) StoreHeartbeat(ctx context.Context, newStats *pdpb.StoreStats) error { ctx, cancel := context.WithTimeout(ctx, pdTimeout) - newStats := typeutil.DeepClone(stats, core.StoreStatsFactory) resp, err := c.pdClient().StoreHeartbeat(ctx, &pdpb.StoreHeartbeatRequest{ Header: requestHeader(), Stats: newStats, @@ -382,8 +381,8 @@ func getLeaderURL(ctx context.Context, conn *grpc.ClientConn) (string, *grpc.Cli if members.GetHeader().GetError() != nil { return "", nil, errors.New(members.GetHeader().GetError().String()) } - ClusterID = members.GetHeader().GetClusterId() - if ClusterID == 0 { + ClusterID.Store(members.GetHeader().GetClusterId()) + if ClusterID.Load() == 0 { return "", nil, errors.New("cluster id is 0") } if members.GetLeader() == nil { @@ -413,9 +412,9 @@ func (rc *RetryClient) PutStore(ctx context.Context, store *metapb.Store) error return err } -func (rc *RetryClient) StoreHeartbeat(ctx context.Context, stats *pdpb.StoreStats) error { +func (rc *RetryClient) StoreHeartbeat(ctx context.Context, newStats *pdpb.StoreStats) error { _, err := rc.requestWithRetry(func() (any, error) { - err := rc.client.StoreHeartbeat(ctx, stats) + err := rc.client.StoreHeartbeat(ctx, newStats) return nil, err }) return err @@ -466,10 +465,10 @@ retry: break retry } } - if ClusterID == 0 { + if ClusterID.Load() == 0 { return "", nil, errors.WithStack(errFailInitClusterID) } - simutil.Logger.Info("get cluster id successfully", zap.Uint64("cluster-id", ClusterID)) + simutil.Logger.Info("get cluster id successfully", zap.Uint64("cluster-id", ClusterID.Load())) // Check if the cluster is already bootstrapped. 
ctx, cancel := context.WithTimeout(ctx, pdTimeout) @@ -543,7 +542,7 @@ func PutPDConfig(config *sc.PDConfig) error { } func ChooseToHaltPDSchedule(halt bool) { - HaltSchedule = halt + HaltSchedule.Store(halt) PDHTTPClient.SetConfig(context.Background(), map[string]any{ "schedule.halt-scheduling": strconv.FormatBool(halt), }) diff --git a/tools/pd-simulator/simulator/drive.go b/tools/pd-simulator/simulator/drive.go index 738c06533d2..8c511b5ac5c 100644 --- a/tools/pd-simulator/simulator/drive.go +++ b/tools/pd-simulator/simulator/drive.go @@ -22,6 +22,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/pingcap/errors" @@ -122,7 +123,7 @@ func (d *Driver) allocID() error { return err } ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) - rootPath := path.Join("/pd", strconv.FormatUint(ClusterID, 10)) + rootPath := path.Join("/pd", strconv.FormatUint(ClusterID.Load(), 10)) allocIDPath := path.Join(rootPath, "alloc_id") _, err = etcdClient.Put(ctx, allocIDPath, string(typeutil.Uint64ToBytes(maxID+1000))) if err != nil { @@ -176,24 +177,20 @@ func (d *Driver) Tick() { d.wg.Wait() } -var HaltSchedule = false +var HaltSchedule atomic.Bool // Check checks if the simulation is completed. func (d *Driver) Check() bool { - if !HaltSchedule { + if !HaltSchedule.Load() { return false } - length := uint64(len(d.conn.Nodes) + 1) + var stats []info.StoreStats var stores []*metapb.Store - for index, s := range d.conn.Nodes { - if index >= length { - length = index + 1 - } + for _, s := range d.conn.Nodes { + s.statsMutex.RLock() stores = append(stores, s.Store) - } - stats := make([]info.StoreStats, length) - for index, node := range d.conn.Nodes { - stats[index] = *node.stats + stats = append(stats, *s.stats) + s.statsMutex.RUnlock() } return d.simCase.Checker(stores, d.raftEngine.regionsInfo, stats) } @@ -252,11 +249,13 @@ func (d *Driver) GetBootstrapInfo(r *RaftEngine) (*metapb.Store, *metapb.Region, func (d *Driver) updateNodeAvailable() { for storeID, n := range d.conn.Nodes { + n.statsMutex.Lock() if n.hasExtraUsedSpace { n.stats.StoreStats.Available = n.stats.StoreStats.Capacity - uint64(d.raftEngine.regionsInfo.GetStoreRegionSize(storeID)) - uint64(d.simConfig.RaftStore.ExtraUsedSpace) } else { n.stats.StoreStats.Available = n.stats.StoreStats.Capacity - uint64(d.raftEngine.regionsInfo.GetStoreRegionSize(storeID)) } + n.statsMutex.Unlock() } } diff --git a/tools/pd-simulator/simulator/event.go b/tools/pd-simulator/simulator/event.go index 408da5c2e62..86da86ed20d 100644 --- a/tools/pd-simulator/simulator/event.go +++ b/tools/pd-simulator/simulator/event.go @@ -240,7 +240,8 @@ func (e *DownNode) Run(raft *RaftEngine, _ int64) bool { } node.Stop() - raft.TraverseRegions(func(region *core.RegionInfo) { + regions := raft.GetRegions() + for _, region := range regions { storeIDs := region.GetStoreIDs() if _, ok := storeIDs[node.Id]; ok { downPeer := &pdpb.PeerStats{ @@ -250,6 +251,6 @@ func (e *DownNode) Run(raft *RaftEngine, _ int64) bool { region = region.Clone(core.WithDownPeers(append(region.GetDownPeers(), downPeer))) raft.SetRegion(region) } - }) + } return true } diff --git a/tools/pd-simulator/simulator/node.go b/tools/pd-simulator/simulator/node.go index c055d345425..2059107227e 100644 --- a/tools/pd-simulator/simulator/node.go +++ b/tools/pd-simulator/simulator/node.go @@ -27,6 +27,7 @@ import ( "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/utils/syncutil" + "github.com/tikv/pd/pkg/utils/typeutil" 
"github.com/tikv/pd/tools/pd-simulator/simulator/cases" sc "github.com/tikv/pd/tools/pd-simulator/simulator/config" "github.com/tikv/pd/tools/pd-simulator/simulator/info" @@ -51,7 +52,7 @@ type Node struct { cancel context.CancelFunc raftEngine *RaftEngine limiter *ratelimit.RateLimiter - sizeMutex syncutil.Mutex + statsMutex syncutil.RWMutex hasExtraUsedSpace bool snapStats []*pdpb.SnapshotStat // PD client @@ -179,12 +180,15 @@ func (n *Node) storeHeartBeat() { if n.GetNodeState() != metapb.NodeState_Preparing && n.GetNodeState() != metapb.NodeState_Serving { return } - ctx, cancel := context.WithTimeout(n.ctx, pdTimeout) + n.statsMutex.Lock() stats := make([]*pdpb.SnapshotStat, len(n.snapStats)) copy(stats, n.snapStats) n.snapStats = n.snapStats[:0] n.stats.SnapshotStats = stats - err := n.client.StoreHeartbeat(ctx, &n.stats.StoreStats) + newStats := typeutil.DeepClone(&n.stats.StoreStats, core.StoreStatsFactory) + n.statsMutex.Unlock() + ctx, cancel := context.WithTimeout(n.ctx, pdTimeout) + err := n.client.StoreHeartbeat(ctx, newStats) if err != nil { simutil.Logger.Info("report store heartbeat error", zap.Uint64("node-id", n.GetId()), @@ -194,8 +198,8 @@ func (n *Node) storeHeartBeat() { } func (n *Node) compaction() { - n.sizeMutex.Lock() - defer n.sizeMutex.Unlock() + n.statsMutex.Lock() + defer n.statsMutex.Unlock() n.stats.Available += n.stats.ToCompactionSize n.stats.UsedSize -= n.stats.ToCompactionSize n.stats.ToCompactionSize = 0 @@ -211,7 +215,7 @@ func (n *Node) regionHeartBeat() { if region == nil { simutil.Logger.Fatal("region not found") } - err := n.client.RegionHeartbeat(ctx, region) + err := n.client.RegionHeartbeat(ctx, region.Clone()) if err != nil { simutil.Logger.Info("report region heartbeat error", zap.Uint64("node-id", n.Id), @@ -267,19 +271,21 @@ func (n *Node) Stop() { } func (n *Node) incUsedSize(size uint64) { - n.sizeMutex.Lock() - defer n.sizeMutex.Unlock() + n.statsMutex.Lock() + defer n.statsMutex.Unlock() n.stats.Available -= size n.stats.UsedSize += size } func (n *Node) decUsedSize(size uint64) { - n.sizeMutex.Lock() - defer n.sizeMutex.Unlock() + n.statsMutex.Lock() + defer n.statsMutex.Unlock() n.stats.ToCompactionSize += size } func (n *Node) registerSnapStats(generate, send, total uint64) { + n.statsMutex.Lock() + defer n.statsMutex.Unlock() stat := pdpb.SnapshotStat{ GenerateDurationSec: generate, SendDurationSec: send, diff --git a/tools/pd-simulator/simulator/raft.go b/tools/pd-simulator/simulator/raft.go index 45afc4d5216..727cc6ab805 100644 --- a/tools/pd-simulator/simulator/raft.go +++ b/tools/pd-simulator/simulator/raft.go @@ -82,10 +82,11 @@ func NewRaftEngine(conf *cases.Case, conn *Connection, storeConfig *config.SimCo } func (r *RaftEngine) stepRegions() { - r.TraverseRegions(func(region *core.RegionInfo) { + regions := r.GetRegions() + for _, region := range regions { r.stepLeader(region) r.stepSplit(region) - }) + } } func (r *RaftEngine) stepLeader(region *core.RegionInfo) { @@ -228,7 +229,10 @@ func (r *RaftEngine) electNewLeader(region *core.RegionInfo) *metapb.Peer { func (r *RaftEngine) GetRegion(regionID uint64) *core.RegionInfo { r.RLock() defer r.RUnlock() - return r.regionsInfo.GetRegion(regionID) + if region := r.regionsInfo.GetRegion(regionID); region != nil { + return region.Clone() + } + return nil } // GetRegionChange returns a list of RegionID for a given store. 
@@ -256,6 +260,13 @@ func (r *RaftEngine) TraverseRegions(lockedFunc func(*core.RegionInfo)) { r.regionsInfo.TraverseRegions(lockedFunc) } +// GetRegions gets all RegionInfo from regionMap +func (r *RaftEngine) GetRegions() []*core.RegionInfo { + r.RLock() + defer r.RUnlock() + return r.regionsInfo.GetRegions() +} + // SetRegion sets the RegionInfo with regionID func (r *RaftEngine) SetRegion(region *core.RegionInfo) []*core.RegionInfo { r.Lock() diff --git a/tools/pd-simulator/simulator/task.go b/tools/pd-simulator/simulator/task.go index c0bfa1e691b..0921838c70b 100644 --- a/tools/pd-simulator/simulator/task.go +++ b/tools/pd-simulator/simulator/task.go @@ -517,20 +517,25 @@ func processSnapshot(n *Node, stat *snapshotStat, speed uint64) bool { return true } if stat.status == pending { - if stat.action == generate && n.stats.SendingSnapCount > maxSnapGeneratorPoolSize { + n.statsMutex.RLock() + sendSnapshot, receiveSnapshot := n.stats.SendingSnapCount, n.stats.ReceivingSnapCount + n.statsMutex.RUnlock() + if stat.action == generate && sendSnapshot > maxSnapGeneratorPoolSize { return false } - if stat.action == receive && n.stats.ReceivingSnapCount > maxSnapReceivePoolSize { + if stat.action == receive && receiveSnapshot > maxSnapReceivePoolSize { return false } stat.status = running stat.generateStart = time.Now() + n.statsMutex.Lock() // If the statement is true, it will start to send or Receive the snapshot. if stat.action == generate { n.stats.SendingSnapCount++ } else { n.stats.ReceivingSnapCount++ } + n.statsMutex.Unlock() } // store should Generate/Receive snapshot by chunk size. @@ -548,11 +553,13 @@ func processSnapshot(n *Node, stat *snapshotStat, speed uint64) bool { totalSec := uint64(time.Since(stat.start).Seconds()) * speed generateSec := uint64(time.Since(stat.generateStart).Seconds()) * speed n.registerSnapStats(generateSec, 0, totalSec) + n.statsMutex.Lock() if stat.action == generate { n.stats.SendingSnapCount-- } else { n.stats.ReceivingSnapCount-- } + n.statsMutex.Unlock() } return true }
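
The patch applies one consistent pattern throughout: per-node stats are guarded by a RWMutex and cloned before they cross a goroutine boundary, globals written from several goroutines (ClusterID, HaltSchedule) become atomics, and code that mutates regions while walking them takes a snapshot via GetRegions() instead of iterating under TraverseRegions(), which now holds the regions read lock for the whole traversal. The sketch below is illustrative only and is not part of the patch; the nodeStats type and the snapshotStats helper are hypothetical stand-ins for the simulator's info.StoreStats and its heartbeat path.

// race_pattern_sketch.go - minimal sketch of the locking pattern used by the patch.
// Assumptions: nodeStats and snapshotStats are invented names; the real code uses
// info.StoreStats, typeutil.DeepClone and syncutil.RWMutex instead.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// nodeStats stands in for the simulator's per-node StoreStats.
type nodeStats struct {
	Available uint64
	UsedSize  uint64
}

type node struct {
	statsMutex sync.RWMutex // mirrors Node.statsMutex in node.go
	stats      nodeStats
}

// Writers take the write lock, as incUsedSize/compaction do after the patch.
func (n *node) incUsedSize(size uint64) {
	n.statsMutex.Lock()
	defer n.statsMutex.Unlock()
	n.stats.Available -= size
	n.stats.UsedSize += size
}

// Readers copy the stats under the read lock, so the heartbeat goroutine never
// shares the underlying struct with concurrent writers (the value copy plays the
// role of DeepClone for the proto-backed stats).
func (n *node) snapshotStats() nodeStats {
	n.statsMutex.RLock()
	defer n.statsMutex.RUnlock()
	return n.stats
}

// Globals read and written from several goroutines become atomics, as the patch
// does for ClusterID and HaltSchedule.
var (
	clusterID    atomic.Uint64
	haltSchedule atomic.Bool
)

func main() {
	n := &node{stats: nodeStats{Available: 1 << 20}}
	clusterID.Store(42)
	haltSchedule.Store(true)

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() { // concurrent writers, e.g. tasks applying snapshots
			defer wg.Done()
			n.incUsedSize(64)
		}()
	}
	wg.Wait()

	// A reader (the store heartbeat) works on its own copy of the stats.
	fmt.Println(n.snapshotStats(), clusterID.Load(), haltSchedule.Load())
}

The same reasoning explains the event.go and raft.go hunks: because TraverseRegions now acquires the read lock for the entire walk, callers that need to call SetRegion (which takes the write lock) first copy the region list with GetRegions() and then mutate clones outside the traversal.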