From d03f485c952463bf54f5cb137766884f2f706219 Mon Sep 17 00:00:00 2001 From: Hu# Date: Mon, 11 Sep 2023 15:42:13 +0800 Subject: [PATCH] *: check raftcluster nil (#7054) close tikv/pd#7053 Signed-off-by: husharp Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- server/api/min_resolved_ts.go | 4 +-- server/grpc_service.go | 4 ++- server/handler.go | 14 +++++--- server/server.go | 60 +++++++++++++++++++++++------------ tests/server/api/api_test.go | 40 +++++++++++++++++++++++ 5 files changed, 95 insertions(+), 27 deletions(-) diff --git a/server/api/min_resolved_ts.go b/server/api/min_resolved_ts.go index ef05e91b9f7..1edf924370f 100644 --- a/server/api/min_resolved_ts.go +++ b/server/api/min_resolved_ts.go @@ -53,7 +53,7 @@ type minResolvedTS struct { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /min-resolved-ts/{store_id} [get] func (h *minResolvedTSHandler) GetStoreMinResolvedTS(w http.ResponseWriter, r *http.Request) { - c := h.svr.GetRaftCluster() + c := getCluster(r) idStr := mux.Vars(r)["store_id"] storeID, err := strconv.ParseUint(idStr, 10, 64) if err != nil { @@ -84,7 +84,7 @@ func (h *minResolvedTSHandler) GetStoreMinResolvedTS(w http.ResponseWriter, r *h // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /min-resolved-ts [get] func (h *minResolvedTSHandler) GetMinResolvedTS(w http.ResponseWriter, r *http.Request) { - c := h.svr.GetRaftCluster() + c := getCluster(r) scopeMinResolvedTS := c.GetMinResolvedTS() persistInterval := c.GetPDServerConfig().MinResolvedTSPersistenceInterval diff --git a/server/grpc_service.go b/server/grpc_service.go index 0563371cdc3..973c45a622f 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -1600,7 +1600,6 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB if rc == nil { return &pdpb.ReportBatchSplitResponse{Header: s.notBootstrappedHeader()}, nil } - _, err := rc.HandleBatchReportSplit(request) if err != nil { return &pdpb.ReportBatchSplitResponse{ @@ -2089,6 +2088,9 @@ func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.S return rsp.(*pdpb.SplitAndScatterRegionsResponse), err } rc := s.GetRaftCluster() + if rc == nil { + return &pdpb.SplitAndScatterRegionsResponse{Header: s.notBootstrappedHeader()}, nil + } splitFinishedPercentage, newRegionIDs := rc.GetRegionSplitter().SplitRegions(ctx, request.GetSplitKeys(), int(request.GetRetryLimit())) scatterFinishedPercentage, err := scatterRegions(rc, newRegionIDs, request.GetGroup(), int(request.GetRetryLimit()), false) if err != nil { diff --git a/server/handler.go b/server/handler.go index 2e4b88b20e2..adc1e8ecd31 100644 --- a/server/handler.go +++ b/server/handler.go @@ -961,14 +961,20 @@ func (h *Handler) ResetTS(ts uint64, ignoreSmaller, skipUpperBoundCheck bool, _ // SetStoreLimitScene sets the limit values for different scenes func (h *Handler) SetStoreLimitScene(scene *storelimit.Scene, limitType storelimit.Type) { - cluster := h.s.GetRaftCluster() - cluster.GetStoreLimiter().ReplaceStoreLimitScene(scene, limitType) + rc := h.s.GetRaftCluster() + if rc == nil { + return + } + rc.GetStoreLimiter().ReplaceStoreLimitScene(scene, limitType) } // GetStoreLimitScene returns the limit values for different scenes func (h *Handler) GetStoreLimitScene(limitType storelimit.Type) *storelimit.Scene { - cluster := h.s.GetRaftCluster() - return cluster.GetStoreLimiter().StoreLimitScene(limitType) + rc := h.s.GetRaftCluster() + if rc == nil { + return nil + } + return rc.GetStoreLimiter().StoreLimitScene(limitType) } // GetProgressByID returns the progress details for a given store ID. diff --git a/server/server.go b/server/server.go index 34a7883cccd..b74ca5b57b3 100644 --- a/server/server.go +++ b/server/server.go @@ -1023,18 +1023,18 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error { } old := s.persistOptions.GetReplicationConfig() if cfg.EnablePlacementRules != old.EnablePlacementRules { - raftCluster := s.GetRaftCluster() - if raftCluster == nil { + rc := s.GetRaftCluster() + if rc == nil { return errs.ErrNotBootstrapped.GenWithStackByArgs() } if cfg.EnablePlacementRules { // initialize rule manager. - if err := raftCluster.GetRuleManager().Initialize(int(cfg.MaxReplicas), cfg.LocationLabels); err != nil { + if err := rc.GetRuleManager().Initialize(int(cfg.MaxReplicas), cfg.LocationLabels); err != nil { return err } } else { // NOTE: can be removed after placement rules feature is enabled by default. - for _, s := range raftCluster.GetStores() { + for _, s := range rc.GetStores() { if !s.IsRemoved() && s.IsTiFlash() { return errors.New("cannot disable placement rules with TiFlash nodes") } @@ -1044,8 +1044,12 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error { var rule *placement.Rule if cfg.EnablePlacementRules { + rc := s.GetRaftCluster() + if rc == nil { + return errs.ErrNotBootstrapped.GenWithStackByArgs() + } // replication.MaxReplicas won't work when placement rule is enabled and not only have one default rule. - defaultRule := s.GetRaftCluster().GetRuleManager().GetRule("pd", "default") + defaultRule := rc.GetRuleManager().GetRule("pd", "default") CheckInDefaultRule := func() error { // replication config won't work when placement rule is enabled and exceeds one default rule @@ -1071,7 +1075,11 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error { if rule != nil { rule.Count = int(cfg.MaxReplicas) rule.LocationLabels = cfg.LocationLabels - if err := s.GetRaftCluster().GetRuleManager().SetRule(rule); err != nil { + rc := s.GetRaftCluster() + if rc == nil { + return errs.ErrNotBootstrapped.GenWithStackByArgs() + } + if err := rc.GetRuleManager().SetRule(rule); err != nil { log.Error("failed to update rule count", errs.ZapError(err)) return err @@ -1083,7 +1091,11 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error { s.persistOptions.SetReplicationConfig(old) if rule != nil { rule.Count = int(old.MaxReplicas) - if e := s.GetRaftCluster().GetRuleManager().SetRule(rule); e != nil { + rc := s.GetRaftCluster() + if rc == nil { + return errs.ErrNotBootstrapped.GenWithStackByArgs() + } + if e := rc.GetRuleManager().SetRule(rule); e != nil { log.Error("failed to roll back count of rule when update replication config", errs.ZapError(e)) } } @@ -1371,18 +1383,18 @@ func (s *Server) GetServerOption() *config.PersistOptions { // GetMetaRegions gets meta regions from cluster. func (s *Server) GetMetaRegions() []*metapb.Region { - cluster := s.GetRaftCluster() - if cluster != nil { - return cluster.GetMetaRegions() + rc := s.GetRaftCluster() + if rc != nil { + return rc.GetMetaRegions() } return nil } // GetRegions gets regions from cluster. func (s *Server) GetRegions() []*core.RegionInfo { - cluster := s.GetRaftCluster() - if cluster != nil { - return cluster.GetRegions() + rc := s.GetRaftCluster() + if rc != nil { + return rc.GetRegions() } return nil } @@ -1519,9 +1531,9 @@ func (s *Server) SetReplicationModeConfig(cfg config.ReplicationModeConfig) erro } log.Info("replication mode config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) - cluster := s.GetRaftCluster() - if cluster != nil { - err := cluster.GetReplicationMode().UpdateConfig(cfg) + rc := s.GetRaftCluster() + if rc != nil { + err := rc.GetReplicationMode().UpdateConfig(cfg) if err != nil { log.Warn("failed to update replication mode", errs.ZapError(err)) // revert to old config @@ -1992,7 +2004,11 @@ func (s *Server) RecoverAllocID(ctx context.Context, id uint64) error { // GetExternalTS returns external timestamp. func (s *Server) GetExternalTS() uint64 { - return s.GetRaftCluster().GetExternalTS() + rc := s.GetRaftCluster() + if rc == nil { + return 0 + } + return rc.GetExternalTS() } // SetExternalTS returns external timestamp. @@ -2002,14 +2018,18 @@ func (s *Server) SetExternalTS(externalTS, globalTS uint64) error { log.Error(desc, zap.Uint64("request timestamp", externalTS), zap.Uint64("global ts", globalTS)) return errors.New(desc) } - currentExternalTS := s.GetRaftCluster().GetExternalTS() + c := s.GetRaftCluster() + if c == nil { + return errs.ErrNotBootstrapped.FastGenByArgs() + } + currentExternalTS := c.GetExternalTS() if tsoutil.CompareTimestampUint64(externalTS, currentExternalTS) != 1 { desc := "the external timestamp should be larger than current external timestamp" log.Error(desc, zap.Uint64("request", externalTS), zap.Uint64("current", currentExternalTS)) return errors.New(desc) } - s.GetRaftCluster().SetExternalTS(externalTS) - return nil + + return c.SetExternalTS(externalTS) } // IsLocalTSOEnabled returns if the local TSO is enabled. diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 623af9b0a82..61d47d7790c 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -807,6 +807,46 @@ func TestRemovingProgress(t *testing.T) { re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs")) } +func TestSendApiWhenRestartRaftCluster(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, serverName string) { + conf.Replication.MaxReplicas = 1 + }) + re.NoError(err) + defer cluster.Destroy() + + err = cluster.RunInitialServers() + re.NoError(err) + leader := cluster.GetServer(cluster.WaitLeader()) + + grpcPDClient := testutil.MustNewGrpcClient(re, leader.GetAddr()) + clusterID := leader.GetClusterID() + req := &pdpb.BootstrapRequest{ + Header: testutil.NewRequestHeader(clusterID), + Store: &metapb.Store{Id: 1, Address: "127.0.0.1:0"}, + Region: &metapb.Region{Id: 2, Peers: []*metapb.Peer{{Id: 3, StoreId: 1, Role: metapb.PeerRole_Voter}}}, + } + resp, err := grpcPDClient.Bootstrap(context.Background(), req) + re.NoError(err) + re.Nil(resp.GetHeader().GetError()) + + // Mock restart raft cluster + rc := leader.GetRaftCluster() + re.NotNil(rc) + rc.Stop() + + // Mock client-go will still send request + output := sendRequest(re, leader.GetAddr()+"/pd/api/v1/min-resolved-ts", http.MethodGet, http.StatusInternalServerError) + re.Contains(string(output), "TiKV cluster not bootstrapped, please start TiKV first") + + err = rc.Start(leader.GetServer()) + re.NoError(err) + rc = leader.GetRaftCluster() + re.NotNil(rc) +} + func TestPreparingProgress(t *testing.T) { re := require.New(t) re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs", `return(true)`))