diff --git a/pkg/schedule/checker/replica_strategy.go b/pkg/schedule/checker/replica_strategy.go
index 27e6301c3dc..797d8bb2853 100644
--- a/pkg/schedule/checker/replica_strategy.go
+++ b/pkg/schedule/checker/replica_strategy.go
@@ -93,6 +93,9 @@ func (s *ReplicaStrategy) SelectStoreToAdd(coLocationStores []*core.StoreInfo, e
 // SelectStoreToFix returns a store to replace down/offline old peer. The location
 // placement after scheduling is allowed to be worse than original.
 func (s *ReplicaStrategy) SelectStoreToFix(coLocationStores []*core.StoreInfo, old uint64) (uint64, bool) {
+	if len(coLocationStores) == 0 {
+		return 0, false
+	}
 	// trick to avoid creating a slice with `old` removed.
 	s.swapStoreToFirst(coLocationStores, old)
 	return s.SelectStoreToAdd(coLocationStores[1:])
@@ -101,6 +104,9 @@ func (s *ReplicaStrategy) SelectStoreToFix(coLocationStores []*core.StoreInfo, o
 // SelectStoreToImprove returns a store to replace oldStore. The location
 // placement after scheduling should be better than original.
 func (s *ReplicaStrategy) SelectStoreToImprove(coLocationStores []*core.StoreInfo, old uint64) (uint64, bool) {
+	if len(coLocationStores) == 0 {
+		return 0, false
+	}
 	// trick to avoid creating a slice with `old` removed.
 	s.swapStoreToFirst(coLocationStores, old)
 	oldStore := s.cluster.GetStore(old)
diff --git a/pkg/schedule/checker/rule_checker.go b/pkg/schedule/checker/rule_checker.go
index aad87ec54b3..d405941300e 100644
--- a/pkg/schedule/checker/rule_checker.go
+++ b/pkg/schedule/checker/rule_checker.go
@@ -391,19 +391,26 @@ func (c *RuleChecker) allowLeader(fit *placement.RegionFit, peer *metapb.Peer) b
 }
 
 func (c *RuleChecker) fixBetterLocation(region *core.RegionInfo, rf *placement.RuleFit) (*operator.Operator, error) {
-	if len(rf.Rule.LocationLabels) == 0 || rf.Rule.Count <= 1 {
+	if len(rf.Rule.LocationLabels) == 0 {
 		return nil, nil
 	}
 
 	isWitness := rf.Rule.IsWitness && c.isWitnessEnabled()
 	// If the peer to be moved is a witness, since no snapshot is needed, we also reuse the fast failover logic.
 	strategy := c.strategy(region, rf.Rule, isWitness)
-	ruleStores := c.getRuleFitStores(rf)
-	oldStore := strategy.SelectStoreToRemove(ruleStores)
+	regionStores := c.cluster.GetRegionStores(region)
+	oldStore := strategy.SelectStoreToRemove(regionStores)
 	if oldStore == 0 {
 		return nil, nil
 	}
-	newStore, filterByTempState := strategy.SelectStoreToImprove(ruleStores, oldStore)
+	var coLocationStores []*core.StoreInfo
+	for _, s := range regionStores {
+		if placement.MatchLabelConstraints(s, rf.Rule.LabelConstraints) {
+			coLocationStores = append(coLocationStores, s)
+		}
+	}
+
+	newStore, filterByTempState := strategy.SelectStoreToImprove(coLocationStores, oldStore)
 	if newStore == 0 {
 		log.Debug("no replacement store", zap.Uint64("region-id", region.GetID()))
 		c.handleFilterState(region, filterByTempState)
diff --git a/pkg/schedule/checker/rule_checker_test.go b/pkg/schedule/checker/rule_checker_test.go
index c204345faa3..ebe1d1aadaf 100644
--- a/pkg/schedule/checker/rule_checker_test.go
+++ b/pkg/schedule/checker/rule_checker_test.go
@@ -1332,3 +1332,62 @@ func (suite *ruleCheckerTestSuite) TestPendingList() {
 	_, exist = suite.rc.pendingList.Get(1)
 	suite.False(exist)
 }
+
+func (suite *ruleCheckerTestSuite) TestLocationLabels() {
+	suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1", "rack": "r1", "host": "h1"})
+	suite.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1", "rack": "r1", "host": "h1"})
+	suite.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z1", "rack": "r2", "host": "h1"})
+	suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z1", "rack": "r2", "host": "h1"})
+	suite.cluster.AddLabelsStore(5, 1, map[string]string{"zone": "z2", "rack": "r3", "host": "h2"})
+	suite.cluster.AddLabelsStore(6, 1, map[string]string{"zone": "z2", "rack": "r3", "host": "h2"})
+	suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 2, 5)
+	rule1 := &placement.Rule{
+		GroupID: "pd",
+		ID:      "test1",
+		Role:    placement.Leader,
+		Count:   1,
+		LabelConstraints: []placement.LabelConstraint{
+			{
+				Key:    "zone",
+				Op:     placement.In,
+				Values: []string{"z1"},
+			},
+		},
+		LocationLabels: []string{"rack"},
+	}
+	rule2 := &placement.Rule{
+		GroupID: "pd",
+		ID:      "test2",
+		Role:    placement.Voter,
+		Count:   1,
+		LabelConstraints: []placement.LabelConstraint{
+			{
+				Key:    "zone",
+				Op:     placement.In,
+				Values: []string{"z1"},
+			},
+		},
+		LocationLabels: []string{"rack"},
+	}
+	rule3 := &placement.Rule{
+		GroupID: "pd",
+		ID:      "test3",
+		Role:    placement.Voter,
+		Count:   1,
+		LabelConstraints: []placement.LabelConstraint{
+			{
+				Key:    "zone",
+				Op:     placement.In,
+				Values: []string{"z2"},
+			},
+		},
+		LocationLabels: []string{"rack"},
+	}
+	suite.ruleManager.SetRule(rule1)
+	suite.ruleManager.SetRule(rule2)
+	suite.ruleManager.SetRule(rule3)
+	suite.ruleManager.DeleteRule("pd", "default")
+	op := suite.rc.Check(suite.cluster.GetRegion(1))
+	suite.NotNil(op)
+	suite.Equal("move-to-better-location", op.Desc())
+}
diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go
index fb5d20f930e..653ede75e7a 100644
--- a/pkg/utils/apiutil/serverapi/middleware.go
+++ b/pkg/utils/apiutil/serverapi/middleware.go
@@ -30,6 +30,7 @@ import (
 const (
 	PDRedirectorHeader    = "PD-Redirector"
 	PDAllowFollowerHandle = "PD-Allow-follower-handle"
+	ForwardedForHeader    = "X-Forwarded-For"
 )
 
 type runtimeServiceValidator struct {
@@ -144,6 +145,7 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http
 	}
 
 	r.Header.Set(PDRedirectorHeader, h.s.Name())
+	r.Header.Add(ForwardedForHeader, r.RemoteAddr)
 
 	var clientUrls []string
 	if matchedFlag {
diff --git a/server/api/router.go b/server/api/router.go
index 87e59941dee..222e2b6f5cb 100644
--- a/server/api/router.go
+++ b/server/api/router.go
@@ -240,7 +240,7 @@ func createRouter(prefix string, svr *server.Server) *mux.Router {
 
 	srd := createStreamingRender()
 	regionsAllHandler := newRegionsHandler(svr, srd)
-	registerFunc(clusterRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
+	registerFunc(clusterRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus))
 
 	regionsHandler := newRegionsHandler(svr, rd)
 	registerFunc(clusterRouter, "/regions/key", regionsHandler.ScanRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
@@ -289,7 +289,7 @@ func createRouter(prefix string, svr *server.Server) *mux.Router {
 	registerFunc(apiRouter, "/leader/transfer/{next_leader}", leaderHandler.TransferLeader, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus))
 
 	statsHandler := newStatsHandler(svr, rd)
-	registerFunc(clusterRouter, "/stats/region", statsHandler.GetRegionStatus, setMethods(http.MethodGet), setAuditBackend(prometheus))
+	registerFunc(clusterRouter, "/stats/region", statsHandler.GetRegionStatus, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus))
 
 	trendHandler := newTrendHandler(svr, rd)
 	registerFunc(apiRouter, "/trend", trendHandler.GetTrend, setMethods(http.MethodGet), setAuditBackend(prometheus))
diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go
index 7fb7e7d1236..1e2a5531ba6 100644
--- a/tests/server/api/api_test.go
+++ b/tests/server/api/api_test.go
@@ -445,6 +445,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() {
 
 func (suite *middlewareTestSuite) TestAuditLocalLogBackend() {
 	tempStdoutFile, _ := os.CreateTemp("/tmp", "pd_tests")
+	defer os.Remove(tempStdoutFile.Name())
 	cfg := &log.Config{}
 	cfg.File.Filename = tempStdoutFile.Name()
 	cfg.Level = "info"
@@ -471,8 +472,6 @@ func (suite *middlewareTestSuite) TestAuditLocalLogBackend() {
 	suite.Contains(string(b), "audit log")
 	suite.NoError(err)
 	suite.Equal(http.StatusOK, resp.StatusCode)
-
-	os.Remove(tempStdoutFile.Name())
 }
 
 func BenchmarkDoRequestWithLocalLogAudit(b *testing.B) {
@@ -656,6 +655,32 @@ func (suite *redirectorTestSuite) TestNotLeader() {
 	suite.NoError(err)
 }
 
+func (suite *redirectorTestSuite) TestXForwardedFor() {
+	leader := suite.cluster.GetServer(suite.cluster.GetLeader())
+	suite.NoError(leader.BootstrapCluster())
+	tempStdoutFile, _ := os.CreateTemp("/tmp", "pd_tests")
+	defer os.Remove(tempStdoutFile.Name())
+	cfg := &log.Config{}
+	cfg.File.Filename = tempStdoutFile.Name()
+	cfg.Level = "info"
+	lg, p, _ := log.InitLogger(cfg)
+	log.ReplaceGlobals(lg, p)
+
+	follower := suite.cluster.GetServer(suite.cluster.GetFollower())
+	addr := follower.GetAddr() + "/pd/api/v1/regions"
+	request, err := http.NewRequest(http.MethodGet, addr, nil)
+	suite.NoError(err)
+	resp, err := dialClient.Do(request)
+	suite.NoError(err)
+	defer resp.Body.Close()
+	suite.Equal(http.StatusOK, resp.StatusCode)
+	time.Sleep(1 * time.Second)
+	b, _ := os.ReadFile(tempStdoutFile.Name())
+	l := string(b)
+	suite.Contains(l, "/pd/api/v1/regions")
+	suite.NotContains(l, suite.cluster.GetConfig().GetClientURLs())
+}
+
 func mustRequestSuccess(re *require.Assertions, s *server.Server) http.Header {
 	resp, err := dialClient.Get(s.GetAddr() + "/pd/api/v1/version")
 	re.NoError(err)
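A note on the header handling above: the middleware uses Header.Add rather than Header.Set, so a chain already recorded by an upstream proxy is preserved and the redirecting member's view of the caller (r.RemoteAddr) is appended to it, which is what lets the follower's audit log keep the real client address. A minimal standalone sketch with plain net/http (not PD code; the URL and addresses are made-up examples):

package main

import (
	"fmt"
	"net/http"
)

func main() {
	// Build a request as a redirecting server might see it (hypothetical target).
	req, _ := http.NewRequest(http.MethodGet, "http://example.com/pd/api/v1/regions", nil)

	// Suppose an upstream proxy already recorded the original client address.
	req.Header.Set("X-Forwarded-For", "10.0.0.1")

	// Header.Add appends instead of overwriting, so the existing chain survives.
	req.Header.Add("X-Forwarded-For", "10.0.0.2:40120")

	fmt.Println(req.Header.Values("X-Forwarded-For")) // [10.0.0.1 10.0.0.2:40120]
}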
diff --git a/tools/pd-simulator/main.go b/tools/pd-simulator/main.go
index 6b638edb365..60d8874d083 100644
--- a/tools/pd-simulator/main.go
+++ b/tools/pd-simulator/main.go
@@ -43,7 +43,7 @@ import (
 )
 
 var (
-	pdAddr         = flag.String("pd", "", "pd address")
+	pdAddr         = flag.String("pd-endpoints", "", "pd address")
 	configFile     = flag.String("config", "conf/simconfig.toml", "config file")
 	caseName       = flag.String("case", "", "case name")
 	serverLogLevel = flag.String("serverLog", "info", "pd server log level")
diff --git a/tools/pd-simulator/simulator/cases/balance_region.go b/tools/pd-simulator/simulator/cases/balance_region.go
index 39f3ef29379..0a013cf3876 100644
--- a/tools/pd-simulator/simulator/cases/balance_region.go
+++ b/tools/pd-simulator/simulator/cases/balance_region.go
@@ -17,7 +17,6 @@ package cases
 import (
 	"time"
 
-	"github.com/docker/go-units"
 	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/tikv/pd/pkg/core"
 	"github.com/tikv/pd/tools/pd-simulator/simulator/info"
@@ -55,8 +54,6 @@ func newRedundantBalanceRegion() *Case {
 			ID:     IDAllocator.nextID(),
 			Peers:  peers,
 			Leader: peers[0],
-			Size:   96 * units.MiB,
-			Keys:   960000,
 		})
 	}
 
diff --git a/tools/pd-simulator/simulator/config.go b/tools/pd-simulator/simulator/config.go
index 0ea26528837..df3fb714462 100644
--- a/tools/pd-simulator/simulator/config.go
+++ b/tools/pd-simulator/simulator/config.go
@@ -117,6 +117,9 @@ func (sc *SimConfig) Adjust(meta *toml.MetaData) error {
 	return sc.ServerConfig.Adjust(meta, false)
 }
 
+func (sc *SimConfig) speed() uint64 {
+	return uint64(time.Second / sc.SimTickInterval.Duration)
+}
 
 // PDConfig saves some config which may be changed in PD.
 type PDConfig struct {
diff --git a/tools/pd-simulator/simulator/node.go b/tools/pd-simulator/simulator/node.go
index cd76d80b3c4..b8fb422d6dd 100644
--- a/tools/pd-simulator/simulator/node.go
+++ b/tools/pd-simulator/simulator/node.go
@@ -52,6 +52,7 @@ type Node struct {
 	limiter           *ratelimit.RateLimiter
 	sizeMutex         sync.Mutex
 	hasExtraUsedSpace bool
+	snapStats         []*pdpb.SnapshotStat
 }
 
 // NewNode returns a Node.
@@ -91,8 +92,8 @@ func NewNode(s *cases.Store, pdAddr string, config *SimConfig) (*Node, error) {
 		cancel()
 		return nil, err
 	}
-	ratio := int64(time.Second) / config.SimTickInterval.Milliseconds()
-	speed := config.StoreIOMBPerSecond * units.MiB * ratio
+	ratio := config.speed()
+	speed := config.StoreIOMBPerSecond * units.MiB * int64(ratio)
 	return &Node{
 		Store:             store,
 		stats:             stats,
@@ -104,6 +105,7 @@ func NewNode(s *cases.Store, pdAddr string, config *SimConfig) (*Node, error) {
 		limiter:           ratelimit.NewRateLimiter(float64(speed), int(speed)),
 		tick:              uint64(rand.Intn(storeHeartBeatPeriod)),
 		hasExtraUsedSpace: s.HasExtraUsedSpace,
+		snapStats:         make([]*pdpb.SnapshotStat, 0),
 	}, nil
 }
 
@@ -191,6 +193,10 @@ func (n *Node) storeHeartBeat() {
 		return
 	}
 	ctx, cancel := context.WithTimeout(n.ctx, pdTimeout)
+	stats := make([]*pdpb.SnapshotStat, len(n.snapStats))
+	copy(stats, n.snapStats)
+	n.snapStats = n.snapStats[:0]
+	n.stats.SnapshotStats = stats
 	err := n.client.StoreHeartbeat(ctx, &n.stats.StoreStats)
 	if err != nil {
 		simutil.Logger.Info("report heartbeat error",
@@ -279,3 +285,12 @@ func (n *Node) decUsedSize(size uint64) {
 	defer n.sizeMutex.Unlock()
 	n.stats.ToCompactionSize += size
 }
+
+func (n *Node) registerSnapStats(generate, send, total uint64) {
+	stat := pdpb.SnapshotStat{
+		GenerateDurationSec: generate,
+		SendDurationSec:     send,
+		TotalDurationSec:    total,
+	}
+	n.snapStats = append(n.snapStats, &stat)
+}
diff --git a/tools/pd-simulator/simulator/task.go b/tools/pd-simulator/simulator/task.go
index 8e55902615f..b1c609b503d 100644
--- a/tools/pd-simulator/simulator/task.go
+++ b/tools/pd-simulator/simulator/task.go
@@ -415,13 +415,14 @@ func (a *addPeer) tick(engine *RaftEngine, region *core.RegionInfo) (newRegion *
 		pendingPeers := append(region.GetPendingPeers(), a.peer)
 		return region.Clone(core.WithAddPeer(a.peer), core.WithIncConfVer(), core.WithPendingPeers(pendingPeers)), false
 	}
+	speed := engine.storeConfig.speed()
 	// Step 2: Process Snapshot
-	if !processSnapshot(sendNode, a.sendingStat) {
+	if !processSnapshot(sendNode, a.sendingStat, speed) {
 		return nil, false
 	}
 	sendStoreID := fmt.Sprintf("store-%d", sendNode.Id)
 	snapshotCounter.WithLabelValues(sendStoreID, "send").Inc()
-	if !processSnapshot(recvNode, a.receivingStat) {
+	if !processSnapshot(recvNode, a.receivingStat, speed) {
 		return nil, false
 	}
 	recvStoreID := fmt.Sprintf("store-%d", recvNode.Id)
@@ -492,10 +493,11 @@ func removeDownPeers(region *core.RegionInfo, removePeer *metapb.Peer) core.Regi
 }
 
 type snapshotStat struct {
-	action     snapAction
-	remainSize int64
-	status     snapStatus
-	start      time.Time
+	action        snapAction
+	remainSize    int64
+	status        snapStatus
+	start         time.Time
+	generateStart time.Time
 }
 
 func newSnapshotState(size int64, action snapAction) *snapshotStat {
@@ -510,7 +512,7 @@ func newSnapshotState(size int64, action snapAction) *snapshotStat {
 	}
 }
 
-func processSnapshot(n *Node, stat *snapshotStat) bool {
+func processSnapshot(n *Node, stat *snapshotStat, speed uint64) bool {
 	if stat.status == finished {
 		return true
 	}
@@ -522,6 +524,7 @@ func processSnapshot(n *Node, stat *snapshotStat) bool {
 		return false
 	}
 	stat.status = running
+	stat.generateStart = time.Now()
 	// If the statement is true, it will start to send or Receive the snapshot.
 	if stat.action == generate {
 		n.stats.SendingSnapCount++
@@ -542,6 +545,9 @@ func processSnapshot(n *Node, stat *snapshotStat) bool {
 	}
 	if stat.status == running {
 		stat.status = finished
+		totalSec := uint64(time.Since(stat.start).Seconds()) * speed
+		generateSec := uint64(time.Since(stat.generateStart).Seconds()) * speed
+		n.registerSnapStats(generateSec, 0, totalSec)
 		if stat.action == generate {
 			n.stats.SendingSnapCount--
 		} else {
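For reference, a small standalone sketch (not simulator code) of the tick-based time scaling that speed() and processSnapshot rely on above: the simulator compresses time, so wall-clock durations measured with time.Since are multiplied by the number of ticks per wall-clock second before being reported as simulated seconds in the snapshot stats. The tick interval below is an assumed example value.

package main

import (
	"fmt"
	"time"
)

func main() {
	// Assumed tick interval; in the simulator this comes from SimConfig.SimTickInterval.
	simTickInterval := 100 * time.Millisecond

	// Same conversion as SimConfig.speed(): ticks per wall-clock second.
	speed := uint64(time.Second / simTickInterval)

	// Pretend a snapshot finished after 2s of wall-clock time.
	start := time.Now().Add(-2 * time.Second)

	// Same scaling as processSnapshot(): wall-clock seconds * speed = simulated seconds.
	totalSec := uint64(time.Since(start).Seconds()) * speed

	fmt.Println(speed, totalSec) // 10 20
}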