Skip to content

Commit

Permalink
*: check raftcluster nil (#7054)
Browse files Browse the repository at this point in the history
close #7053

Signed-off-by: husharp <[email protected]>

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
  • Loading branch information
HuSharp and ti-chi-bot[bot] committed Sep 11, 2023
1 parent 83369eb commit d03f485
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 27 deletions.
4 changes: 2 additions & 2 deletions server/api/min_resolved_ts.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type minResolvedTS struct {
// @Failure 500 {string} string "PD server failed to proceed the request."
// @Router /min-resolved-ts/{store_id} [get]
func (h *minResolvedTSHandler) GetStoreMinResolvedTS(w http.ResponseWriter, r *http.Request) {
c := h.svr.GetRaftCluster()
c := getCluster(r)
idStr := mux.Vars(r)["store_id"]
storeID, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
Expand Down Expand Up @@ -84,7 +84,7 @@ func (h *minResolvedTSHandler) GetStoreMinResolvedTS(w http.ResponseWriter, r *h
// @Failure 500 {string} string "PD server failed to proceed the request."
// @Router /min-resolved-ts [get]
func (h *minResolvedTSHandler) GetMinResolvedTS(w http.ResponseWriter, r *http.Request) {
c := h.svr.GetRaftCluster()
c := getCluster(r)
scopeMinResolvedTS := c.GetMinResolvedTS()
persistInterval := c.GetPDServerConfig().MinResolvedTSPersistenceInterval

Expand Down
4 changes: 3 additions & 1 deletion server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1600,7 +1600,6 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB
if rc == nil {
return &pdpb.ReportBatchSplitResponse{Header: s.notBootstrappedHeader()}, nil
}

_, err := rc.HandleBatchReportSplit(request)
if err != nil {
return &pdpb.ReportBatchSplitResponse{
Expand Down Expand Up @@ -2089,6 +2088,9 @@ func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.S
return rsp.(*pdpb.SplitAndScatterRegionsResponse), err
}
rc := s.GetRaftCluster()
if rc == nil {
return &pdpb.SplitAndScatterRegionsResponse{Header: s.notBootstrappedHeader()}, nil
}
splitFinishedPercentage, newRegionIDs := rc.GetRegionSplitter().SplitRegions(ctx, request.GetSplitKeys(), int(request.GetRetryLimit()))
scatterFinishedPercentage, err := scatterRegions(rc, newRegionIDs, request.GetGroup(), int(request.GetRetryLimit()), false)
if err != nil {
Expand Down
14 changes: 10 additions & 4 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,14 +961,20 @@ func (h *Handler) ResetTS(ts uint64, ignoreSmaller, skipUpperBoundCheck bool, _

// SetStoreLimitScene sets the limit values for different scenes
func (h *Handler) SetStoreLimitScene(scene *storelimit.Scene, limitType storelimit.Type) {
cluster := h.s.GetRaftCluster()
cluster.GetStoreLimiter().ReplaceStoreLimitScene(scene, limitType)
rc := h.s.GetRaftCluster()
if rc == nil {
return
}
rc.GetStoreLimiter().ReplaceStoreLimitScene(scene, limitType)
}

// GetStoreLimitScene returns the limit values for different scenes
func (h *Handler) GetStoreLimitScene(limitType storelimit.Type) *storelimit.Scene {
cluster := h.s.GetRaftCluster()
return cluster.GetStoreLimiter().StoreLimitScene(limitType)
rc := h.s.GetRaftCluster()
if rc == nil {
return nil
}
return rc.GetStoreLimiter().StoreLimitScene(limitType)
}

// GetProgressByID returns the progress details for a given store ID.
Expand Down
60 changes: 40 additions & 20 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1023,18 +1023,18 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error {
}
old := s.persistOptions.GetReplicationConfig()
if cfg.EnablePlacementRules != old.EnablePlacementRules {
raftCluster := s.GetRaftCluster()
if raftCluster == nil {
rc := s.GetRaftCluster()
if rc == nil {
return errs.ErrNotBootstrapped.GenWithStackByArgs()
}
if cfg.EnablePlacementRules {
// initialize rule manager.
if err := raftCluster.GetRuleManager().Initialize(int(cfg.MaxReplicas), cfg.LocationLabels); err != nil {
if err := rc.GetRuleManager().Initialize(int(cfg.MaxReplicas), cfg.LocationLabels); err != nil {
return err
}
} else {
// NOTE: can be removed after placement rules feature is enabled by default.
for _, s := range raftCluster.GetStores() {
for _, s := range rc.GetStores() {
if !s.IsRemoved() && s.IsTiFlash() {
return errors.New("cannot disable placement rules with TiFlash nodes")
}
Expand All @@ -1044,8 +1044,12 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error {

var rule *placement.Rule
if cfg.EnablePlacementRules {
rc := s.GetRaftCluster()
if rc == nil {
return errs.ErrNotBootstrapped.GenWithStackByArgs()
}
// replication.MaxReplicas won't work when placement rule is enabled and not only have one default rule.
defaultRule := s.GetRaftCluster().GetRuleManager().GetRule("pd", "default")
defaultRule := rc.GetRuleManager().GetRule("pd", "default")

CheckInDefaultRule := func() error {
// replication config won't work when placement rule is enabled and exceeds one default rule
Expand All @@ -1071,7 +1075,11 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error {
if rule != nil {
rule.Count = int(cfg.MaxReplicas)
rule.LocationLabels = cfg.LocationLabels
if err := s.GetRaftCluster().GetRuleManager().SetRule(rule); err != nil {
rc := s.GetRaftCluster()
if rc == nil {
return errs.ErrNotBootstrapped.GenWithStackByArgs()
}
if err := rc.GetRuleManager().SetRule(rule); err != nil {
log.Error("failed to update rule count",
errs.ZapError(err))
return err
Expand All @@ -1083,7 +1091,11 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error {
s.persistOptions.SetReplicationConfig(old)
if rule != nil {
rule.Count = int(old.MaxReplicas)
if e := s.GetRaftCluster().GetRuleManager().SetRule(rule); e != nil {
rc := s.GetRaftCluster()
if rc == nil {
return errs.ErrNotBootstrapped.GenWithStackByArgs()
}
if e := rc.GetRuleManager().SetRule(rule); e != nil {
log.Error("failed to roll back count of rule when update replication config", errs.ZapError(e))
}
}
Expand Down Expand Up @@ -1371,18 +1383,18 @@ func (s *Server) GetServerOption() *config.PersistOptions {

// GetMetaRegions gets meta regions from cluster.
func (s *Server) GetMetaRegions() []*metapb.Region {
cluster := s.GetRaftCluster()
if cluster != nil {
return cluster.GetMetaRegions()
rc := s.GetRaftCluster()
if rc != nil {
return rc.GetMetaRegions()
}
return nil
}

// GetRegions gets regions from cluster.
func (s *Server) GetRegions() []*core.RegionInfo {
cluster := s.GetRaftCluster()
if cluster != nil {
return cluster.GetRegions()
rc := s.GetRaftCluster()
if rc != nil {
return rc.GetRegions()
}
return nil
}
Expand Down Expand Up @@ -1519,9 +1531,9 @@ func (s *Server) SetReplicationModeConfig(cfg config.ReplicationModeConfig) erro
}
log.Info("replication mode config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old))

cluster := s.GetRaftCluster()
if cluster != nil {
err := cluster.GetReplicationMode().UpdateConfig(cfg)
rc := s.GetRaftCluster()
if rc != nil {
err := rc.GetReplicationMode().UpdateConfig(cfg)
if err != nil {
log.Warn("failed to update replication mode", errs.ZapError(err))
// revert to old config
Expand Down Expand Up @@ -1992,7 +2004,11 @@ func (s *Server) RecoverAllocID(ctx context.Context, id uint64) error {

// GetExternalTS returns external timestamp.
func (s *Server) GetExternalTS() uint64 {
return s.GetRaftCluster().GetExternalTS()
rc := s.GetRaftCluster()
if rc == nil {
return 0
}
return rc.GetExternalTS()
}

// SetExternalTS returns external timestamp.
Expand All @@ -2002,14 +2018,18 @@ func (s *Server) SetExternalTS(externalTS, globalTS uint64) error {
log.Error(desc, zap.Uint64("request timestamp", externalTS), zap.Uint64("global ts", globalTS))
return errors.New(desc)
}
currentExternalTS := s.GetRaftCluster().GetExternalTS()
c := s.GetRaftCluster()
if c == nil {
return errs.ErrNotBootstrapped.FastGenByArgs()
}
currentExternalTS := c.GetExternalTS()
if tsoutil.CompareTimestampUint64(externalTS, currentExternalTS) != 1 {
desc := "the external timestamp should be larger than current external timestamp"
log.Error(desc, zap.Uint64("request", externalTS), zap.Uint64("current", currentExternalTS))
return errors.New(desc)
}
s.GetRaftCluster().SetExternalTS(externalTS)
return nil

return c.SetExternalTS(externalTS)
}

// IsLocalTSOEnabled returns if the local TSO is enabled.
Expand Down
40 changes: 40 additions & 0 deletions tests/server/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,46 @@ func TestRemovingProgress(t *testing.T) {
re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs"))
}

func TestSendApiWhenRestartRaftCluster(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cluster, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, serverName string) {
conf.Replication.MaxReplicas = 1
})
re.NoError(err)
defer cluster.Destroy()

err = cluster.RunInitialServers()
re.NoError(err)
leader := cluster.GetServer(cluster.WaitLeader())

grpcPDClient := testutil.MustNewGrpcClient(re, leader.GetAddr())
clusterID := leader.GetClusterID()
req := &pdpb.BootstrapRequest{
Header: testutil.NewRequestHeader(clusterID),
Store: &metapb.Store{Id: 1, Address: "127.0.0.1:0"},
Region: &metapb.Region{Id: 2, Peers: []*metapb.Peer{{Id: 3, StoreId: 1, Role: metapb.PeerRole_Voter}}},
}
resp, err := grpcPDClient.Bootstrap(context.Background(), req)
re.NoError(err)
re.Nil(resp.GetHeader().GetError())

// Mock restart raft cluster
rc := leader.GetRaftCluster()
re.NotNil(rc)
rc.Stop()

// Mock client-go will still send request
output := sendRequest(re, leader.GetAddr()+"/pd/api/v1/min-resolved-ts", http.MethodGet, http.StatusInternalServerError)
re.Contains(string(output), "TiKV cluster not bootstrapped, please start TiKV first")

err = rc.Start(leader.GetServer())
re.NoError(err)
rc = leader.GetRaftCluster()
re.NotNil(rc)
}

func TestPreparingProgress(t *testing.T) {
re := require.New(t)
re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/highFrequencyClusterJobs", `return(true)`))
Expand Down

0 comments on commit d03f485

Please sign in to comment.