diff --git a/errors.toml b/errors.toml index 1d10d40d294..b6123058310 100644 --- a/errors.toml +++ b/errors.toml @@ -551,6 +551,11 @@ error = ''' build rule list failed, %s ''' +["PD:placement:ErrKeyFormat"] +error = ''' +key should be in hex format, %s +''' + ["PD:placement:ErrLoadRule"] error = ''' load rule failed @@ -561,11 +566,21 @@ error = ''' load rule group failed ''' +["PD:placement:ErrPlacementDisabled"] +error = ''' +placement rules feature is disabled +''' + ["PD:placement:ErrRuleContent"] error = ''' invalid rule content, %s ''' +["PD:placement:ErrRuleNotFound"] +error = ''' +rule not found +''' + ["PD:plugin:ErrLoadPlugin"] error = ''' failed to load plugin @@ -616,6 +631,11 @@ error = ''' region %v has abnormal peer ''' +["PD:region:ErrRegionInvalidID"] +error = ''' +invalid region id +''' + ["PD:region:ErrRegionNotAdjacent"] error = ''' two regions are not adjacent diff --git a/pkg/election/leadership.go b/pkg/election/leadership.go index d5d73e90b58..8cfdcf423ac 100644 --- a/pkg/election/leadership.go +++ b/pkg/election/leadership.go @@ -260,6 +260,7 @@ func (ls *Leadership) Watch(serverCtx context.Context, revision int64) { continue } } + lastReceivedResponseTime = time.Now() log.Info("watch channel is created", zap.Int64("revision", revision), zap.String("leader-key", ls.leaderKey), zap.String("purpose", ls.purpose)) watchChanLoop: diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index e5bac8519be..b8a882cd187 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -102,6 +102,8 @@ var ( // region errors var ( + // ErrRegionInvalidID is error info for region id invalid. + ErrRegionInvalidID = errors.Normalize("invalid region id", errors.RFCCodeText("PD:region:ErrRegionInvalidID")) // ErrRegionNotAdjacent is error info for region not adjacent. ErrRegionNotAdjacent = errors.Normalize("two regions are not adjacent", errors.RFCCodeText("PD:region:ErrRegionNotAdjacent")) // ErrRegionNotFound is error info for region not found. 
@@ -153,10 +155,13 @@ var ( // placement errors var ( - ErrRuleContent = errors.Normalize("invalid rule content, %s", errors.RFCCodeText("PD:placement:ErrRuleContent")) - ErrLoadRule = errors.Normalize("load rule failed", errors.RFCCodeText("PD:placement:ErrLoadRule")) - ErrLoadRuleGroup = errors.Normalize("load rule group failed", errors.RFCCodeText("PD:placement:ErrLoadRuleGroup")) - ErrBuildRuleList = errors.Normalize("build rule list failed, %s", errors.RFCCodeText("PD:placement:ErrBuildRuleList")) + ErrRuleContent = errors.Normalize("invalid rule content, %s", errors.RFCCodeText("PD:placement:ErrRuleContent")) + ErrLoadRule = errors.Normalize("load rule failed", errors.RFCCodeText("PD:placement:ErrLoadRule")) + ErrLoadRuleGroup = errors.Normalize("load rule group failed", errors.RFCCodeText("PD:placement:ErrLoadRuleGroup")) + ErrBuildRuleList = errors.Normalize("build rule list failed, %s", errors.RFCCodeText("PD:placement:ErrBuildRuleList")) + ErrPlacementDisabled = errors.Normalize("placement rules feature is disabled", errors.RFCCodeText("PD:placement:ErrPlacementDisabled")) + ErrKeyFormat = errors.Normalize("key should be in hex format, %s", errors.RFCCodeText("PD:placement:ErrKeyFormat")) + ErrRuleNotFound = errors.Normalize("rule not found", errors.RFCCodeText("PD:placement:ErrRuleNotFound")) ) // region label errors diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index 47fdb95543f..172515d8620 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -15,6 +15,7 @@ package apis import ( + "encoding/hex" "net/http" "strconv" "sync" @@ -127,12 +128,6 @@ func (s *Service) RegisterAdminRouter() { router.DELETE("cache/regions/:id", deleteRegionCacheByID) } -// RegisterConfigRouter registers the router of the config handler. -func (s *Service) RegisterConfigRouter() { - router := s.root.Group("config") - router.GET("", getConfig) -} - // RegisterSchedulersRouter registers the router of the schedulers handler. func (s *Service) RegisterSchedulersRouter() { router := s.root.Group("schedulers") @@ -172,6 +167,32 @@ func (s *Service) RegisterOperatorsRouter() { router.GET("/records", getOperatorRecords) } +// RegisterConfigRouter registers the router of the config handler. +func (s *Service) RegisterConfigRouter() { + router := s.root.Group("config") + router.GET("", getConfig) + + rules := router.Group("rules") + rules.GET("", getAllRules) + rules.GET("/group/:group", getRuleByGroup) + rules.GET("/region/:region", getRulesByRegion) + rules.GET("/region/:region/detail", checkRegionPlacementRule) + rules.GET("/key/:key", getRulesByKey) + + // We cannot merge `/rule` and `/rules`, because we allow `group_id` to be "group", + // which is the same as the prefix of `/rules/group/:group`. + rule := router.Group("rule") + rule.GET("/:group/:id", getRuleByGroupAndID) + + groups := router.Group("rule_groups") + groups.GET("", getAllGroupConfigs) + groups.GET("/:id", getRuleGroupConfig) + + placementRule := router.Group("placement-rule") + placementRule.GET("", getPlacementRules) + placementRule.GET("/:group", getPlacementRuleByGroup) +} + // @Tags admin // @Summary Change the log level. // @Produce json @@ -671,3 +692,266 @@ func getHistoryHotRegions(c *gin.Context) { var res storage.HistoryHotRegions c.IndentedJSON(http.StatusOK, res) } + +// @Tags rule +// @Summary List all rules of cluster. 
+// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rules [get] +func getAllRules(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + rules := manager.GetAllRules() + c.IndentedJSON(http.StatusOK, rules) +} + +// @Tags rule +// @Summary List all rules of cluster by group. +// @Param group path string true "The name of group" +// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rules/group/{group} [get] +func getRuleByGroup(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + group := c.Param("group") + rules := manager.GetRulesByGroup(group) + c.IndentedJSON(http.StatusOK, rules) +} + +// @Tags rule +// @Summary List all rules of cluster by region. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {array} placement.Rule +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rules/region/{region} [get] +func getRulesByRegion(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + regionStr := c.Param("region") + region, code, err := handler.PreCheckForRegion(regionStr) + if err != nil { + c.String(code, err.Error()) + return + } + rules := manager.GetRulesForApplyRegion(region) + c.IndentedJSON(http.StatusOK, rules) +} + +// @Tags rule +// @Summary List rules and matched peers related to the given region. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {object} placement.RegionFit +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." 
+// @Router /config/rules/region/{region}/detail [get]
+func checkRegionPlacementRule(c *gin.Context) {
+	handler := c.MustGet(handlerKey).(*handler.Handler)
+	regionStr := c.Param("region")
+	region, code, err := handler.PreCheckForRegion(regionStr)
+	if err != nil {
+		c.String(code, err.Error())
+		return
+	}
+	regionFit, err := handler.CheckRegionPlacementRule(region)
+	if err == errs.ErrPlacementDisabled {
+		c.String(http.StatusPreconditionFailed, err.Error())
+		return
+	}
+	if err != nil {
+		c.String(http.StatusInternalServerError, err.Error())
+		return
+	}
+	c.IndentedJSON(http.StatusOK, regionFit)
+}
+
+// @Tags rule
+// @Summary List all rules of cluster by key.
+// @Param key path string true "The name of key"
+// @Produce json
+// @Success 200 {array} placement.Rule
+// @Failure 400 {string} string "The input is invalid."
+// @Failure 412 {string} string "Placement rules feature is disabled."
+// @Failure 500 {string} string "PD server failed to proceed the request."
+// @Router /config/rules/key/{key} [get]
+func getRulesByKey(c *gin.Context) {
+	handler := c.MustGet(handlerKey).(*handler.Handler)
+	manager, err := handler.GetRuleManager()
+	if err == errs.ErrPlacementDisabled {
+		c.String(http.StatusPreconditionFailed, err.Error())
+		return
+	}
+	if err != nil {
+		c.String(http.StatusInternalServerError, err.Error())
+		return
+	}
+	keyHex := c.Param("key")
+	key, err := hex.DecodeString(keyHex)
+	if err != nil {
+		c.String(http.StatusBadRequest, errs.ErrKeyFormat.FastGenByArgs(err).Error())
+		return
+	}
+	rules := manager.GetRulesByKey(key)
+	c.IndentedJSON(http.StatusOK, rules)
+}
+
+// @Tags rule
+// @Summary Get rule of cluster by group and id.
+// @Param group path string true "The name of group"
+// @Param id path string true "Rule Id"
+// @Produce json
+// @Success 200 {object} placement.Rule
+// @Failure 404 {string} string "The rule does not exist."
+// @Failure 412 {string} string "Placement rules feature is disabled."
+// @Failure 500 {string} string "PD server failed to proceed the request."
+// @Router /config/rule/{group}/{id} [get]
+func getRuleByGroupAndID(c *gin.Context) {
+	handler := c.MustGet(handlerKey).(*handler.Handler)
+	manager, err := handler.GetRuleManager()
+	if err == errs.ErrPlacementDisabled {
+		c.String(http.StatusPreconditionFailed, err.Error())
+		return
+	}
+	if err != nil {
+		c.String(http.StatusInternalServerError, err.Error())
+		return
+	}
+	group, id := c.Param("group"), c.Param("id")
+	rule := manager.GetRule(group, id)
+	if rule == nil {
+		c.String(http.StatusNotFound, errs.ErrRuleNotFound.Error())
+		return
+	}
+	c.IndentedJSON(http.StatusOK, rule)
+}
+
+// @Tags rule
+// @Summary List all rule group configs.
+// @Produce json
+// @Success 200 {array} placement.RuleGroup
+// @Failure 412 {string} string "Placement rules feature is disabled."
+// @Failure 500 {string} string "PD server failed to proceed the request."
+// @Router /config/rule_groups [get]
+func getAllGroupConfigs(c *gin.Context) {
+	handler := c.MustGet(handlerKey).(*handler.Handler)
+	manager, err := handler.GetRuleManager()
+	if err == errs.ErrPlacementDisabled {
+		c.String(http.StatusPreconditionFailed, err.Error())
+		return
+	}
+	if err != nil {
+		c.String(http.StatusInternalServerError, err.Error())
+		return
+	}
+	ruleGroups := manager.GetRuleGroups()
+	c.IndentedJSON(http.StatusOK, ruleGroups)
+}
+
+// @Tags rule
+// @Summary Get rule group config by group id.
+// @Param id path string true "Group Id"
+// @Produce json
+// @Success 200 {object} placement.RuleGroup
+// @Failure 404 {string} string "The RuleGroup does not exist."
+// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/rule_groups/{id} [get] +func getRuleGroupConfig(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + id := c.Param("id") + group := manager.GetRuleGroup(id) + if group == nil { + c.String(http.StatusNotFound, errs.ErrRuleNotFound.Error()) + return + } + c.IndentedJSON(http.StatusOK, group) +} + +// @Tags rule +// @Summary List all rules and groups configuration. +// @Produce json +// @Success 200 {array} placement.GroupBundle +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/placement-rules [get] +func getPlacementRules(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + bundles := manager.GetAllGroupBundles() + c.IndentedJSON(http.StatusOK, bundles) +} + +// @Tags rule +// @Summary Get group config and all rules belong to the group. +// @Param group path string true "The name of group" +// @Produce json +// @Success 200 {object} placement.GroupBundle +// @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/placement-rules/{group} [get] +func getPlacementRuleByGroup(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + manager, err := handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + c.String(http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + g := c.Param("group") + group := manager.GetGroupBundle(g) + c.IndentedJSON(http.StatusOK, group) +} diff --git a/pkg/schedule/handler/handler.go b/pkg/schedule/handler/handler.go index 45b0eaf502f..3f9f4f96622 100644 --- a/pkg/schedule/handler/handler.go +++ b/pkg/schedule/handler/handler.go @@ -18,6 +18,7 @@ import ( "bytes" "encoding/hex" "net/http" + "strconv" "strings" "time" @@ -1061,3 +1062,45 @@ func (h *Handler) GetHotBuckets(regionIDs ...uint64) (HotBucketsResponse, error) } return ret, nil } + +// GetRuleManager returns the rule manager. +func (h *Handler) GetRuleManager() (*placement.RuleManager, error) { + c := h.GetCluster() + if c == nil { + return nil, errs.ErrNotBootstrapped + } + if !c.GetSharedConfig().IsPlacementRulesEnabled() { + return nil, errs.ErrPlacementDisabled + } + return c.GetRuleManager(), nil +} + +// PreCheckForRegion checks if the region is valid. 
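+// It returns the parsed region information, the HTTP status code to respond with on failure, and the error, if any.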
+func (h *Handler) PreCheckForRegion(regionStr string) (*core.RegionInfo, int, error) { + c := h.GetCluster() + if c == nil { + return nil, http.StatusInternalServerError, errs.ErrNotBootstrapped.GenWithStackByArgs() + } + regionID, err := strconv.ParseUint(regionStr, 10, 64) + if err != nil { + return nil, http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs() + } + region := c.GetRegion(regionID) + if region == nil { + return nil, http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID) + } + return region, http.StatusOK, nil +} + +// CheckRegionPlacementRule checks if the region matches the placement rules. +func (h *Handler) CheckRegionPlacementRule(region *core.RegionInfo) (*placement.RegionFit, error) { + c := h.GetCluster() + if c == nil { + return nil, errs.ErrNotBootstrapped.GenWithStackByArgs() + } + manager, err := h.GetRuleManager() + if err != nil { + return nil, err + } + return manager.FitRegion(c, region), nil +} diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 19438ad0f91..2bb742ccbba 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -117,6 +117,7 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri r.URL.Path = strings.TrimRight(r.URL.Path, "/") for _, rule := range h.microserviceRedirectRules { if strings.HasPrefix(r.URL.Path, rule.matchPath) && slice.Contains(rule.matchMethods, r.Method) { + origin := r.URL.Path addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName) if !ok || addr == "" { log.Warn("failed to get the service primary addr when trying to match redirect rules", @@ -134,8 +135,8 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri } else { r.URL.Path = rule.targetPath } - log.Debug("redirect to micro service", zap.String("path", r.URL.Path), zap.String("target", addr), - zap.String("method", r.Method)) + log.Debug("redirect to micro service", zap.String("path", r.URL.Path), zap.String("origin-path", origin), + zap.String("target", addr), zap.String("method", r.Method)) return true, addr } } diff --git a/pkg/utils/etcdutil/etcdutil.go b/pkg/utils/etcdutil/etcdutil.go index 1432b6e37c3..e004247c6d0 100644 --- a/pkg/utils/etcdutil/etcdutil.go +++ b/pkg/utils/etcdutil/etcdutil.go @@ -707,7 +707,6 @@ func (lw *LoopWatcher) watch(ctx context.Context, revision int64) (nextRevision }() ticker := time.NewTicker(RequestProgressInterval) defer ticker.Stop() - lastReceivedResponseTime := time.Now() for { if watcherCancel != nil { @@ -736,8 +735,10 @@ func (lw *LoopWatcher) watch(ctx context.Context, revision int64) (nextRevision continue } } + lastReceivedResponseTime := time.Now() log.Info("watch channel is created in watch loop", zap.Int64("revision", revision), zap.String("name", lw.name), zap.String("key", lw.key)) + watchChanLoop: select { case <-ctx.Done(): diff --git a/server/api/region_test.go b/server/api/region_test.go index a39a1e5c5fd..379fcf7d463 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -241,14 +241,14 @@ func (suite *regionTestSuite) TestRegions() { mustRegionHeartbeat(re, suite.svr, r) } url := fmt.Sprintf("%s/regions", suite.urlPrefix) - RegionsInfo := &RegionsInfo{} - err := tu.ReadGetJSON(re, testDialClient, url, RegionsInfo) + regionsInfo := &RegionsInfo{} + err := tu.ReadGetJSON(re, testDialClient, url, regionsInfo) suite.NoError(err) - suite.Len(regions, RegionsInfo.Count) - sort.Slice(RegionsInfo.Regions, 
func(i, j int) bool { - return RegionsInfo.Regions[i].ID < RegionsInfo.Regions[j].ID + suite.Len(regions, regionsInfo.Count) + sort.Slice(regionsInfo.Regions, func(i, j int) bool { + return regionsInfo.Regions[i].ID < regionsInfo.Regions[j].ID }) - for i, r := range RegionsInfo.Regions { + for i, r := range regionsInfo.Regions { suite.Equal(regions[i].ID, r.ID) suite.Equal(regions[i].ApproximateSize, r.ApproximateSize) suite.Equal(regions[i].ApproximateKeys, r.ApproximateKeys) diff --git a/server/api/rule.go b/server/api/rule.go index b3a720ece41..77aad42eb42 100644 --- a/server/api/rule.go +++ b/server/api/rule.go @@ -19,30 +19,26 @@ import ( "fmt" "net/http" "net/url" - "strconv" "github.com/gorilla/mux" - "github.com/pingcap/errors" - "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/server" - "github.com/tikv/pd/server/cluster" "github.com/unrolled/render" ) -var errPlacementDisabled = errors.New("placement rules feature is disabled") - type ruleHandler struct { + *server.Handler svr *server.Server rd *render.Render } func newRulesHandler(svr *server.Server, rd *render.Render) *ruleHandler { return &ruleHandler{ - svr: svr, - rd: rd, + Handler: svr.GetHandler(), + svr: svr, + rd: rd, } } @@ -51,14 +47,19 @@ func newRulesHandler(svr *server.Server, rd *render.Render) *ruleHandler { // @Produce json // @Success 200 {array} placement.Rule // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules [get] func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - rules := cluster.GetRuleManager().GetAllRules() + rules := manager.GetAllRules() h.rd.JSON(w, http.StatusOK, rules) } @@ -72,9 +73,13 @@ func (h *ruleHandler) GetAllRules(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules [post] func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var rules []*placement.Rule @@ -87,7 +92,7 @@ func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { return } } - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). 
SetRules(rules); err != nil { if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -105,15 +110,20 @@ func (h *ruleHandler) SetAllRules(w http.ResponseWriter, r *http.Request) { // @Produce json // @Success 200 {array} placement.Rule // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/group/{group} [get] func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } group := mux.Vars(r)["group"] - rules := cluster.GetRuleManager().GetRulesByGroup(group) + rules := manager.GetRulesByGroup(group) h.rd.JSON(w, http.StatusOK, rules) } @@ -125,13 +135,25 @@ func (h *ruleHandler) GetRuleByGroup(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "The input is invalid." // @Failure 404 {string} string "The region does not exist." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/region/{region} [get] func (h *ruleHandler) GetRulesByRegion(w http.ResponseWriter, r *http.Request) { - cluster, region := h.preCheckForRegionAndRule(w, r) - if cluster == nil || region == nil { + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - rules := cluster.GetRuleManager().GetRulesForApplyRegion(region) + regionStr := mux.Vars(r)["region"] + region, code, err := h.PreCheckForRegion(regionStr) + if err != nil { + h.rd.JSON(w, code, err.Error()) + return + } + rules := manager.GetRulesForApplyRegion(region) h.rd.JSON(w, http.StatusOK, rules) } @@ -143,34 +165,25 @@ func (h *ruleHandler) GetRulesByRegion(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "The input is invalid." // @Failure 404 {string} string "The region does not exist." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/rules/region/{region}/detail [get] func (h *ruleHandler) CheckRegionPlacementRule(w http.ResponseWriter, r *http.Request) { - cluster, region := h.preCheckForRegionAndRule(w, r) - if cluster == nil || region == nil { + regionStr := mux.Vars(r)["region"] + region, code, err := h.PreCheckForRegion(regionStr) + if err != nil { + h.rd.JSON(w, code, err.Error()) return } - regionFit := cluster.GetRuleManager().FitRegion(cluster, region) - h.rd.JSON(w, http.StatusOK, regionFit) -} - -func (h *ruleHandler) preCheckForRegionAndRule(w http.ResponseWriter, r *http.Request) (*cluster.RaftCluster, *core.RegionInfo) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) - return cluster, nil + regionFit, err := h.Handler.CheckRegionPlacementRule(region) + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return } - regionStr := mux.Vars(r)["region"] - regionID, err := strconv.ParseUint(regionStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusBadRequest, "invalid region id") - return cluster, nil - } - region := cluster.GetRegion(regionID) - if region == nil { - h.rd.JSON(w, http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error()) - return cluster, nil + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return } - return cluster, region + h.rd.JSON(w, http.StatusOK, regionFit) } // @Tags rule @@ -180,20 +193,25 @@ func (h *ruleHandler) preCheckForRegionAndRule(w http.ResponseWriter, r *http.Re // @Success 200 {array} placement.Rule // @Failure 400 {string} string "The input is invalid." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/key/{key} [get] func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } keyHex := mux.Vars(r)["key"] key, err := hex.DecodeString(keyHex) if err != nil { - h.rd.JSON(w, http.StatusBadRequest, "key should be in hex format") + h.rd.JSON(w, http.StatusBadRequest, errs.ErrKeyFormat.FastGenByArgs(err).Error()) return } - rules := cluster.GetRuleManager().GetRulesByKey(key) + rules := manager.GetRulesByKey(key) h.rd.JSON(w, http.StatusOK, rules) } @@ -207,15 +225,19 @@ func (h *ruleHandler) GetRulesByKey(w http.ResponseWriter, r *http.Request) { // @Failure 412 {string} string "Placement rules feature is disabled." 
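+// @Failure 500 {string} string "PD server failed to proceed the request."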
// @Router /config/rule/{group}/{id} [get] func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } group, id := mux.Vars(r)["group"], mux.Vars(r)["id"] - rule := cluster.GetRuleManager().GetRule(group, id) + rule := manager.GetRule(group, id) if rule == nil { - h.rd.JSON(w, http.StatusNotFound, nil) + h.rd.JSON(w, http.StatusNotFound, errs.ErrRuleNotFound.Error()) return } h.rd.JSON(w, http.StatusOK, rule) @@ -232,21 +254,25 @@ func (h *ruleHandler) GetRuleByGroupAndID(w http.ResponseWriter, r *http.Request // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule [post] func (h *ruleHandler) SetRule(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var rule placement.Rule if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &rule); err != nil { return } - oldRule := cluster.GetRuleManager().GetRule(rule.GroupID, rule.ID) + oldRule := manager.GetRule(rule.GroupID, rule.ID) if err := h.syncReplicateConfigWithDefaultRule(&rule); err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). SetRule(&rule); err != nil { if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -255,6 +281,7 @@ func (h *ruleHandler) SetRule(w http.ResponseWriter, r *http.Request) { } return } + cluster := getCluster(r) cluster.AddSuspectKeyRange(rule.StartKey, rule.EndKey) if oldRule != nil { cluster.AddSuspectKeyRange(oldRule.StartKey, oldRule.EndKey) @@ -285,18 +312,23 @@ func (h *ruleHandler) syncReplicateConfigWithDefaultRule(rule *placement.Rule) e // @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/rule/{group}/{id} [delete] func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } group, id := mux.Vars(r)["group"], mux.Vars(r)["id"] - rule := cluster.GetRuleManager().GetRule(group, id) - if err := cluster.GetRuleManager().DeleteRule(group, id); err != nil { + rule := manager.GetRule(group, id) + if err := manager.DeleteRule(group, id); err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } if rule != nil { + cluster := getCluster(r) cluster.AddSuspectKeyRange(rule.StartKey, rule.EndKey) } @@ -313,16 +345,20 @@ func (h *ruleHandler) DeleteRuleByGroup(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rules/batch [post] func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var opts []placement.RuleOp if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &opts); err != nil { return } - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). Batch(opts); err != nil { if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -341,15 +377,20 @@ func (h *ruleHandler) BatchRules(w http.ResponseWriter, r *http.Request) { // @Success 200 {object} placement.RuleGroup // @Failure 404 {string} string "The RuleGroup does not exist." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_group/{id} [get] func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } id := mux.Vars(r)["id"] - group := cluster.GetRuleManager().GetRuleGroup(id) + group := manager.GetRuleGroup(id) if group == nil { h.rd.JSON(w, http.StatusNotFound, nil) return @@ -368,21 +409,26 @@ func (h *ruleHandler) GetGroupConfig(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/rule_group [post] func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var ruleGroup placement.RuleGroup if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &ruleGroup); err != nil { return } - if err := cluster.GetRuleManager().SetRuleGroup(&ruleGroup); err != nil { + if err := manager.SetRuleGroup(&ruleGroup); err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - for _, r := range cluster.GetRuleManager().GetRulesByGroup(ruleGroup.ID) { - cluster.AddSuspectKeyRange(r.StartKey, r.EndKey) + cluster := getCluster(r) + for _, rule := range manager.GetRulesByGroup(ruleGroup.ID) { + cluster.AddSuspectKeyRange(rule.StartKey, rule.EndKey) } h.rd.JSON(w, http.StatusOK, "Update rule group successfully.") } @@ -396,18 +442,23 @@ func (h *ruleHandler) SetGroupConfig(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_group/{id} [delete] func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } id := mux.Vars(r)["id"] - err := cluster.GetRuleManager().DeleteRuleGroup(id) + err = manager.DeleteRuleGroup(id) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - for _, r := range cluster.GetRuleManager().GetRulesByGroup(id) { + cluster := getCluster(r) + for _, r := range manager.GetRulesByGroup(id) { cluster.AddSuspectKeyRange(r.StartKey, r.EndKey) } h.rd.JSON(w, http.StatusOK, "Delete rule group successfully.") @@ -418,14 +469,19 @@ func (h *ruleHandler) DeleteGroupConfig(w http.ResponseWriter, r *http.Request) // @Produce json // @Success 200 {array} placement.RuleGroup // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/rule_groups [get] func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) return } - ruleGroups := cluster.GetRuleManager().GetRuleGroups() + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + ruleGroups := manager.GetRuleGroups() h.rd.JSON(w, http.StatusOK, ruleGroups) } @@ -434,14 +490,19 @@ func (h *ruleHandler) GetAllGroupConfigs(w http.ResponseWriter, r *http.Request) // @Produce json // @Success 200 {array} placement.GroupBundle // @Failure 412 {string} string "Placement rules feature is disabled." 
+// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule [get] func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } - bundles := cluster.GetRuleManager().GetAllGroupBundles() + bundles := manager.GetAllGroupBundles() h.rd.JSON(w, http.StatusOK, bundles) } @@ -455,9 +516,13 @@ func (h *ruleHandler) GetPlacementRules(w http.ResponseWriter, r *http.Request) // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule [post] func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } var groups []placement.GroupBundle @@ -465,7 +530,7 @@ func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) return } _, partial := r.URL.Query()["partial"] - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). SetAllGroupBundles(groups, !partial); err != nil { if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) @@ -483,14 +548,20 @@ func (h *ruleHandler) SetPlacementRules(w http.ResponseWriter, r *http.Request) // @Produce json // @Success 200 {object} placement.GroupBundle // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule/{group} [get] func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) return } - group := cluster.GetRuleManager().GetGroupBundle(mux.Vars(r)["group"]) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + g := mux.Vars(r)["group"] + group := manager.GetGroupBundle(g) h.rd.JSON(w, http.StatusOK, group) } @@ -502,21 +573,26 @@ func (h *ruleHandler) GetPlacementRuleByGroup(w http.ResponseWriter, r *http.Req // @Success 200 {string} string "Delete group and rules successfully." // @Failure 400 {string} string "Bad request." // @Failure 412 {string} string "Placement rules feature is disabled." +// @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /config/placement-rule [delete] func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } group := mux.Vars(r)["group"] - group, err := url.PathUnescape(group) + group, err = url.PathUnescape(group) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } _, regex := r.URL.Query()["regexp"] - if err := cluster.GetRuleManager().DeleteGroupBundle(group, regex); err != nil { + if err := manager.DeleteGroupBundle(group, regex); err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } @@ -532,9 +608,13 @@ func (h *ruleHandler) DeletePlacementRuleByGroup(w http.ResponseWriter, r *http. // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /config/placement-rule/{group} [post] func (h *ruleHandler) SetPlacementRuleByGroup(w http.ResponseWriter, r *http.Request) { - cluster := getCluster(r) - if !cluster.GetOpts().IsPlacementRulesEnabled() { - h.rd.JSON(w, http.StatusPreconditionFailed, errPlacementDisabled.Error()) + manager, err := h.Handler.GetRuleManager() + if err == errs.ErrPlacementDisabled { + h.rd.JSON(w, http.StatusPreconditionFailed, err.Error()) + return + } + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return } groupID := mux.Vars(r)["group"] @@ -549,7 +629,7 @@ func (h *ruleHandler) SetPlacementRuleByGroup(w http.ResponseWriter, r *http.Req h.rd.JSON(w, http.StatusBadRequest, fmt.Sprintf("group id %s does not match request URI %s", group.ID, groupID)) return } - if err := cluster.GetRuleManager().SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). + if err := manager.SetKeyType(h.svr.GetConfig().PDServerCfg.KeyType). 
 		SetGroupBundle(group); err != nil {
 		if errs.ErrRuleContent.Equal(err) || errs.ErrHexDecodingString.Equal(err) {
 			h.rd.JSON(w, http.StatusBadRequest, err.Error())
diff --git a/server/api/server.go b/server/api/server.go
index ae877b8407c..77a51eb04e5 100644
--- a/server/api/server.go
+++ b/server/api/server.go
@@ -84,6 +84,31 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP
 			scheapi.APIPathPrefix+"/hotspot",
 			mcs.SchedulingServiceName,
 			[]string{http.MethodGet}),
+		serverapi.MicroserviceRedirectRule(
+			prefix+"/config/rules",
+			scheapi.APIPathPrefix+"/config/rules",
+			mcs.SchedulingServiceName,
+			[]string{http.MethodGet}),
+		serverapi.MicroserviceRedirectRule(
+			prefix+"/config/rule/",
+			scheapi.APIPathPrefix+"/config/rule",
+			mcs.SchedulingServiceName,
+			[]string{http.MethodGet}),
+		serverapi.MicroserviceRedirectRule(
+			prefix+"/config/rule_group/",
+			scheapi.APIPathPrefix+"/config/rule_groups", // Note: the scheduling server only registers the plural `rule_groups` route, so the singular `rule_group` path is deliberately redirected to it.
+			mcs.SchedulingServiceName,
+			[]string{http.MethodGet}),
+		serverapi.MicroserviceRedirectRule(
+			prefix+"/config/rule_groups",
+			scheapi.APIPathPrefix+"/config/rule_groups",
+			mcs.SchedulingServiceName,
+			[]string{http.MethodGet}),
+		serverapi.MicroserviceRedirectRule(
+			prefix+"/config/placement-rule",
+			scheapi.APIPathPrefix+"/config/placement-rule",
+			mcs.SchedulingServiceName,
+			[]string{http.MethodGet}),
 		// because the writing of all the meta information of the scheduling service is in the API server,
 		// we should not post and delete the scheduler directly in the scheduling service.
 		serverapi.MicroserviceRedirectRule(
diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go
index 15c66ce5829..cfeaa4db033 100644
--- a/tests/integrations/mcs/scheduling/api_test.go
+++ b/tests/integrations/mcs/scheduling/api_test.go
@@ -15,10 +15,10 @@ import (
 	_ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1"
 	"github.com/tikv/pd/pkg/mcs/scheduling/server/config"
 	"github.com/tikv/pd/pkg/schedule/handler"
+	"github.com/tikv/pd/pkg/schedule/placement"
 	"github.com/tikv/pd/pkg/statistics"
 	"github.com/tikv/pd/pkg/storage"
 	"github.com/tikv/pd/pkg/utils/apiutil"
-	"github.com/tikv/pd/pkg/utils/tempurl"
 	"github.com/tikv/pd/pkg/utils/testutil"
 	"github.com/tikv/pd/tests"
 )
@@ -43,7 +43,7 @@ func TestAPI(t *testing.T) {
 	suite.Run(t, &apiTestSuite{})
 }
 
-func (suite *apiTestSuite) SetupSuite() {
+func (suite *apiTestSuite) SetupTest() {
 	ctx, cancel := context.WithCancel(context.Background())
 	suite.ctx = ctx
 	cluster, err := tests.NewTestAPICluster(suite.ctx, 1)
@@ -62,14 +62,19 @@ func (suite *apiTestSuite) SetupSuite() {
 	suite.cleanupFunc = func() {
 		cancel()
 	}
+	tc, err := tests.NewTestSchedulingCluster(suite.ctx, 2, suite.backendEndpoints)
+	suite.NoError(err)
+	suite.cluster.SetSchedulingCluster(tc)
+	tc.WaitForPrimaryServing(suite.Require())
 }
 
-func (suite *apiTestSuite) TearDownSuite() {
+func (suite *apiTestSuite) TearDownTest() {
 	suite.cluster.Destroy()
 	suite.cleanupFunc()
 }
 
 func (suite *apiTestSuite) TestGetCheckerByName() {
+	re := suite.Require()
 	testCases := []struct {
 		name string
 	}{
@@ -81,14 +86,8 @@
 		{name: "joint-state"},
 	}
 
-	re := suite.Require()
-	s, cleanup := tests.StartSingleSchedulingTestServer(suite.ctx, re, suite.backendEndpoints, tempurl.Alloc())
-	defer cleanup()
-	testutil.Eventually(re, func() bool {
-		return s.IsServing()
-	}, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond))
-	addr := s.GetAddr()
-	urlPrefix := fmt.Sprintf("%s/scheduling/api/v1/checkers", addr)
+	s := suite.cluster.GetSchedulingPrimaryServer()
+	urlPrefix := fmt.Sprintf("%s/scheduling/api/v1/checkers", s.GetAddr())
 	co := s.GetCoordinator()
 
 	for _, testCase := range testCases {
@@ -123,17 +122,12 @@ func (suite *apiTestSuite) TestAPIForward() {
 		re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader"))
 	}()
 
-	tc, err := tests.NewTestSchedulingCluster(suite.ctx, 2, suite.backendEndpoints)
-	re.NoError(err)
-	defer tc.Destroy()
-	tc.WaitForPrimaryServing(re)
-
 	urlPrefix := fmt.Sprintf("%s/pd/api/v1", suite.backendEndpoints)
 	var slice []string
 	var resp map[string]interface{}
 
 	// Test operators
-	err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice,
+	err := testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice,
 		testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
 	re.NoError(err)
 	re.Len(slice, 0)
@@ -241,6 +235,80 @@
 	err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/history"), &history,
 		testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
 	re.NoError(err)
+
+	// Test rules: only forward `GET` requests
+	var rules []*placement.Rule
+	tests.MustPutRegion(re, suite.cluster, 2, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60))
+	rules = []*placement.Rule{
+		{
+			GroupID:        "pd",
+			ID:             "default",
+			Role:           "voter",
+			Count:          3,
+			LocationLabels: []string{},
+		},
+	}
+	rulesArgs, err := json.Marshal(rules)
+	suite.NoError(err)
+
+	err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "/config/rules"), &rules,
+		testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
+	re.NoError(err)
+	err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), rulesArgs,
+		testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader))
+	re.NoError(err)
+	err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/batch"), rulesArgs,
+		testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader))
+	re.NoError(err)
+	err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/group/pd"), &rules,
+		testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
+	re.NoError(err)
+	err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2"), &rules,
+		testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
+	re.NoError(err)
+	var fit placement.RegionFit
+	err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2/detail"), &fit,
+		testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
+	re.NoError(err)
+	err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/key/0000000000000001"), &rules,
+		testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
+	re.NoError(err)
+	err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), nil,
+		testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
+	re.NoError(err)
+	err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"),
+		testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader))
+	re.NoError(err)
+	err = testutil.CheckPostJSON(testDialClient,
fmt.Sprintf("%s/%s", urlPrefix, "config/rule"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_groups"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), rulesArgs, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) } func (suite *apiTestSuite) TestConfig() { diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index 26d70bb955f..2cc8427911a 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -19,6 +19,7 @@ import ( "encoding/json" "os" "reflect" + "strings" "testing" "time" @@ -409,9 +410,12 @@ func (suite *configTestSuite) checkPlacementRuleGroups(cluster *tests.TestCluste // test show var group placement.RuleGroup - output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "show", "pd") - re.NoError(err) - re.NoError(json.Unmarshal(output, &group)) + testutil.Eventually(re, func() bool { // wait for the config to be synced to the scheduling server + output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "rule-group", "show", "pd") + re.NoError(err) + return !strings.Contains(string(output), "404") + }) + re.NoError(json.Unmarshal(output, &group), string(output)) re.Equal(placement.RuleGroup{ID: "pd"}, group) // test set diff --git a/server/api/rule_test.go b/tests/server/api/rule_test.go similarity index 67% rename from server/api/rule_test.go rename to tests/server/api/rule_test.go index d2dc50f1119..3ee3357e031 100644 --- a/server/api/rule_test.go +++ b/tests/server/api/rule_test.go @@ -25,57 +25,37 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/schedule/placement" tu "github.com/tikv/pd/pkg/utils/testutil" - "github.com/tikv/pd/server" 
"github.com/tikv/pd/server/config" + "github.com/tikv/pd/tests" ) type ruleTestSuite struct { suite.Suite - svr *server.Server - cleanup tu.CleanupFunc - urlPrefix string } func TestRuleTestSuite(t *testing.T) { suite.Run(t, new(ruleTestSuite)) } -func (suite *ruleTestSuite) SetupSuite() { - re := suite.Require() - suite.svr, suite.cleanup = mustNewServer(re) - server.MustWaitLeader(re, []*server.Server{suite.svr}) - - addr := suite.svr.GetAddr() - suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/config", addr, apiPrefix) - - mustBootstrapCluster(re, suite.svr) - PDServerCfg := suite.svr.GetConfig().PDServerCfg - PDServerCfg.KeyType = "raw" - err := suite.svr.SetPDServerConfig(PDServerCfg) - suite.NoError(err) - suite.NoError(tu.CheckPostJSON(testDialClient, suite.urlPrefix, []byte(`{"enable-placement-rules":"true"}`), tu.StatusOK(re))) -} - -func (suite *ruleTestSuite) TearDownSuite() { - suite.cleanup() -} - -func (suite *ruleTestSuite) TearDownTest() { - def := placement.GroupBundle{ - ID: "pd", - Rules: []*placement.Rule{ - {GroupID: "pd", ID: "default", Role: "voter", Count: 3}, +func (suite *ruleTestSuite) TestSet() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true }, } - data, err := json.Marshal([]placement.GroupBundle{def}) - suite.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule", data, tu.StatusOK(suite.Require())) - suite.NoError(err) + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + env.RunTestInTwoModes(suite.checkSet) } -func (suite *ruleTestSuite) TestSet() { +func (suite *ruleTestSuite) checkSet(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule := placement.Rule{GroupID: "a", ID: "10", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} successData, err := json.Marshal(rule) suite.NoError(err) @@ -159,12 +139,12 @@ func (suite *ruleTestSuite) TestSet() { for _, testCase := range testCases { suite.T().Log(testCase.name) // clear suspect keyRanges to prevent test case from others - suite.svr.GetRaftCluster().ClearSuspectKeyRanges() + leaderServer.GetRaftCluster().ClearSuspectKeyRanges() if testCase.success { - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", testCase.rawData, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusOK(re)) popKeyRangeMap := map[string]struct{}{} for i := 0; i < len(testCase.popKeyRange)/2; i++ { - v, got := suite.svr.GetRaftCluster().PopOneSuspectKeyRange() + v, got := leaderServer.GetRaftCluster().PopOneSuspectKeyRange() suite.True(got) popKeyRangeMap[hex.EncodeToString(v[0])] = struct{}{} popKeyRangeMap[hex.EncodeToString(v[1])] = struct{}{} @@ -175,7 +155,7 @@ func (suite *ruleTestSuite) TestSet() { suite.True(ok) } } else { - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", testCase.rawData, + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusNotOK(re), tu.StringEqual(re, testCase.response)) } @@ -184,11 +164,26 @@ func (suite *ruleTestSuite) TestSet() { } func (suite *ruleTestSuite) TestGet() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), 
opts...) + env.RunTestInTwoModes(suite.checkGet) +} + +func (suite *ruleTestSuite) checkGet(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule := placement.Rule{GroupID: "a", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) suite.NoError(err) re := suite.Require() - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) suite.NoError(err) testCases := []struct { @@ -213,7 +208,7 @@ func (suite *ruleTestSuite) TestGet() { for _, testCase := range testCases { suite.T().Log(testCase.name) var resp placement.Rule - url := fmt.Sprintf("%s/rule/%s/%s", suite.urlPrefix, testCase.rule.GroupID, testCase.rule.ID) + url := fmt.Sprintf("%s/rule/%s/%s", urlPrefix, testCase.rule.GroupID, testCase.rule.ID) if testCase.found { err = tu.ReadGetJSON(re, testDialClient, url, &resp) suite.compareRule(&resp, &testCase.rule) @@ -225,20 +220,50 @@ func (suite *ruleTestSuite) TestGet() { } func (suite *ruleTestSuite) TestGetAll() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) + env.RunTestInTwoModes(suite.checkGetAll) +} + +func (suite *ruleTestSuite) checkGetAll(cluster *tests.TestCluster) { + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix) + rule := placement.Rule{GroupID: "b", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1} data, err := json.Marshal(rule) suite.NoError(err) re := suite.Require() - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) suite.NoError(err) var resp2 []*placement.Rule - err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/rules", &resp2) + err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/rules", &resp2) suite.NoError(err) suite.GreaterOrEqual(len(resp2), 1) } func (suite *ruleTestSuite) TestSetAll() { + opts := []tests.ConfigOption{ + func(conf *config.Config, serverName string) { + conf.PDServerCfg.KeyType = "raw" + conf.Replication.EnablePlacementRules = true + }, + } + env := tests.NewSchedulingTestEnvironment(suite.T(), opts...) 
+	env.RunTestInTwoModes(suite.checkSetAll)
+}
+
+func (suite *ruleTestSuite) checkSetAll(cluster *tests.TestCluster) {
+	leaderServer := cluster.GetLeaderServer()
+	pdAddr := leaderServer.GetAddr()
+	urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix)
+
 	rule1 := placement.Rule{GroupID: "a", ID: "12", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1}
 	rule2 := placement.Rule{GroupID: "b", ID: "12", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1}
 	rule3 := placement.Rule{GroupID: "a", ID: "12", StartKeyHex: "XXXX", EndKeyHex: "3333", Role: "voter", Count: 1}
@@ -247,10 +272,10 @@ func (suite *ruleTestSuite) TestSetAll() {
 		LocationLabels: []string{"host"}}
 	rule6 := placement.Rule{GroupID: "pd", ID: "default", StartKeyHex: "", EndKeyHex: "", Role: "voter", Count: 3}
 
-	suite.svr.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"}
-	defaultRule := suite.svr.GetRaftCluster().GetRuleManager().GetRule("pd", "default")
+	leaderServer.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"}
+	defaultRule := leaderServer.GetRaftCluster().GetRuleManager().GetRule("pd", "default")
 	defaultRule.LocationLabels = []string{"host"}
-	suite.svr.GetRaftCluster().GetRuleManager().SetRule(defaultRule)
+	leaderServer.GetRaftCluster().GetRuleManager().SetRule(defaultRule)
 
 	successData, err := json.Marshal([]*placement.Rule{&rule1, &rule2})
 	suite.NoError(err)
@@ -333,13 +358,13 @@ func (suite *ruleTestSuite) TestSetAll() {
 	for _, testCase := range testCases {
 		suite.T().Log(testCase.name)
 		if testCase.success {
-			err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules", testCase.rawData, tu.StatusOK(re))
+			err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules", testCase.rawData, tu.StatusOK(re))
 			suite.NoError(err)
 			if testCase.isDefaultRule {
-				suite.Equal(int(suite.svr.GetPersistOptions().GetReplicationConfig().MaxReplicas), testCase.count)
+				suite.Equal(int(leaderServer.GetPersistOptions().GetReplicationConfig().MaxReplicas), testCase.count)
 			}
 		} else {
-			err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules", testCase.rawData,
+			err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules", testCase.rawData,
 				tu.StringEqual(re, testCase.response))
 			suite.NoError(err)
 		}
@@ -347,17 +372,32 @@
 func (suite *ruleTestSuite) TestGetAllByGroup() {
+	opts := []tests.ConfigOption{
+		func(conf *config.Config, serverName string) {
+			conf.PDServerCfg.KeyType = "raw"
+			conf.Replication.EnablePlacementRules = true
+		},
+	}
+	env := tests.NewSchedulingTestEnvironment(suite.T(), opts...)
+	env.RunTestInTwoModes(suite.checkGetAllByGroup)
+}
+
+func (suite *ruleTestSuite) checkGetAllByGroup(cluster *tests.TestCluster) {
+	leaderServer := cluster.GetLeaderServer()
+	pdAddr := leaderServer.GetAddr()
+	urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix)
+
 	re := suite.Require()
 	rule := placement.Rule{GroupID: "c", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1}
 	data, err := json.Marshal(rule)
 	suite.NoError(err)
-	err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re))
+	err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re))
 	suite.NoError(err)
 	rule1 := placement.Rule{GroupID: "c", ID: "30", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1}
 	data, err = json.Marshal(rule1)
 	suite.NoError(err)
-	err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re))
+	err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re))
 	suite.NoError(err)
 
 	testCases := []struct {
@@ -380,7 +420,7 @@ func (suite *ruleTestSuite) TestGetAllByGroup() {
 	for _, testCase := range testCases {
 		suite.T().Log(testCase.name)
 		var resp []*placement.Rule
-		url := fmt.Sprintf("%s/rules/group/%s", suite.urlPrefix, testCase.groupID)
+		url := fmt.Sprintf("%s/rules/group/%s", urlPrefix, testCase.groupID)
 		err = tu.ReadGetJSON(re, testDialClient, url, &resp)
 		suite.NoError(err)
 		suite.Len(resp, testCase.count)
@@ -392,15 +432,30 @@
 }
 
 func (suite *ruleTestSuite) TestGetAllByRegion() {
+	opts := []tests.ConfigOption{
+		func(conf *config.Config, serverName string) {
+			conf.PDServerCfg.KeyType = "raw"
+			conf.Replication.EnablePlacementRules = true
+		},
+	}
+	env := tests.NewSchedulingTestEnvironment(suite.T(), opts...)
+	env.RunTestInTwoModes(suite.checkGetAllByRegion)
+}
+
+func (suite *ruleTestSuite) checkGetAllByRegion(cluster *tests.TestCluster) {
+	leaderServer := cluster.GetLeaderServer()
+	pdAddr := leaderServer.GetAddr()
+	urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix)
+
 	rule := placement.Rule{GroupID: "e", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1}
 	data, err := json.Marshal(rule)
 	suite.NoError(err)
 	re := suite.Require()
-	err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re))
+	err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re))
 	suite.NoError(err)
 
 	r := core.NewTestRegionInfo(4, 1, []byte{0x22, 0x22}, []byte{0x33, 0x33})
-	mustRegionHeartbeat(re, suite.svr, r)
+	tests.MustPutRegionInfo(re, cluster, r)
 
 	testCases := []struct {
 		name     string
@@ -429,7 +484,7 @@ func (suite *ruleTestSuite) TestGetAllByRegion() {
 	for _, testCase := range testCases {
 		suite.T().Log(testCase.name)
 		var resp []*placement.Rule
-		url := fmt.Sprintf("%s/rules/region/%s", suite.urlPrefix, testCase.regionID)
+		url := fmt.Sprintf("%s/rules/region/%s", urlPrefix, testCase.regionID)
 
 		if testCase.success {
 			err = tu.ReadGetJSON(re, testDialClient, url, &resp)
@@ -446,11 +501,26 @@
 }
 
 func (suite *ruleTestSuite) TestGetAllByKey() {
+	opts := []tests.ConfigOption{
+		func(conf *config.Config, serverName string) {
+			conf.PDServerCfg.KeyType = "raw"
+			conf.Replication.EnablePlacementRules = true
+		},
+	}
+	env := tests.NewSchedulingTestEnvironment(suite.T(), opts...)
+	env.RunTestInTwoModes(suite.checkGetAllByKey)
+}
+
+func (suite *ruleTestSuite) checkGetAllByKey(cluster *tests.TestCluster) {
+	leaderServer := cluster.GetLeaderServer()
+	pdAddr := leaderServer.GetAddr()
+	urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix)
+
 	rule := placement.Rule{GroupID: "f", ID: "40", StartKeyHex: "8888", EndKeyHex: "9111", Role: "voter", Count: 1}
 	data, err := json.Marshal(rule)
 	suite.NoError(err)
 	re := suite.Require()
-	err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(re))
+	err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re))
 	suite.NoError(err)
 
 	testCases := []struct {
@@ -483,7 +553,7 @@
 	for _, testCase := range testCases {
 		suite.T().Log(testCase.name)
 		var resp []*placement.Rule
-		url := fmt.Sprintf("%s/rules/key/%s", suite.urlPrefix, testCase.key)
+		url := fmt.Sprintf("%s/rules/key/%s", urlPrefix, testCase.key)
 		if testCase.success {
 			err = tu.ReadGetJSON(re, testDialClient, url, &resp)
 			suite.Len(resp, testCase.respSize)
@@ -495,10 +565,25 @@
 }
 
 func (suite *ruleTestSuite) TestDelete() {
+	opts := []tests.ConfigOption{
+		func(conf *config.Config, serverName string) {
+			conf.PDServerCfg.KeyType = "raw"
+			conf.Replication.EnablePlacementRules = true
+		},
+	}
+	env := tests.NewSchedulingTestEnvironment(suite.T(), opts...)
+	env.RunTestInTwoModes(suite.checkDelete)
+}
+
+func (suite *ruleTestSuite) checkDelete(cluster *tests.TestCluster) {
+	leaderServer := cluster.GetLeaderServer()
+	pdAddr := leaderServer.GetAddr()
+	urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix)
+
 	rule := placement.Rule{GroupID: "g", ID: "10", StartKeyHex: "8888", EndKeyHex: "9111", Role: "voter", Count: 1}
 	data, err := json.Marshal(rule)
 	suite.NoError(err)
-	err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rule", data, tu.StatusOK(suite.Require()))
+	err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(suite.Require()))
 	suite.NoError(err)
 	oldStartKey, err := hex.DecodeString(rule.StartKeyHex)
 	suite.NoError(err)
@@ -529,15 +614,15 @@
 	}
 	for _, testCase := range testCases {
 		suite.T().Log(testCase.name)
-		url := fmt.Sprintf("%s/rule/%s/%s", suite.urlPrefix, testCase.groupID, testCase.id)
+		url := fmt.Sprintf("%s/rule/%s/%s", urlPrefix, testCase.groupID, testCase.id)
 		// clear suspect keyRanges to prevent test case from others
-		suite.svr.GetRaftCluster().ClearSuspectKeyRanges()
+		leaderServer.GetRaftCluster().ClearSuspectKeyRanges()
 		err = tu.CheckDelete(testDialClient, url, tu.StatusOK(suite.Require()))
 		suite.NoError(err)
 		if len(testCase.popKeyRange) > 0 {
 			popKeyRangeMap := map[string]struct{}{}
 			for i := 0; i < len(testCase.popKeyRange)/2; i++ {
-				v, got := suite.svr.GetRaftCluster().PopOneSuspectKeyRange()
+				v, got := leaderServer.GetRaftCluster().PopOneSuspectKeyRange()
 				suite.True(got)
 				popKeyRangeMap[hex.EncodeToString(v[0])] = struct{}{}
 				popKeyRangeMap[hex.EncodeToString(v[1])] = struct{}{}
@@ -551,16 +636,22 @@
 	}
 }
 
-func (suite *ruleTestSuite) compareRule(r1 *placement.Rule, r2 *placement.Rule) {
-	suite.Equal(r2.GroupID, r1.GroupID)
-	suite.Equal(r2.ID, r1.ID)
-	suite.Equal(r2.StartKeyHex, r1.StartKeyHex)
-	suite.Equal(r2.EndKeyHex, r1.EndKeyHex)
-	suite.Equal(r2.Role, r1.Role)
-	suite.Equal(r2.Count, r1.Count)
+func (suite *ruleTestSuite) TestBatch() {
+	opts := []tests.ConfigOption{
+		func(conf *config.Config, serverName string) {
+			conf.PDServerCfg.KeyType = "raw"
+			conf.Replication.EnablePlacementRules = true
+		},
+	}
+	env := tests.NewSchedulingTestEnvironment(suite.T(), opts...)
+	env.RunTestInTwoModes(suite.checkBatch)
 }
 
-func (suite *ruleTestSuite) TestBatch() {
+func (suite *ruleTestSuite) checkBatch(cluster *tests.TestCluster) {
+	leaderServer := cluster.GetLeaderServer()
+	pdAddr := leaderServer.GetAddr()
+	urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix)
+
 	opt1 := placement.RuleOp{
 		Action: placement.RuleOpAdd,
 		Rule:   &placement.Rule{GroupID: "a", ID: "13", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1},
@@ -670,10 +761,10 @@ func (suite *ruleTestSuite) TestBatch() {
 	for _, testCase := range testCases {
 		suite.T().Log(testCase.name)
 		if testCase.success {
-			err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules/batch", testCase.rawData, tu.StatusOK(re))
+			err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules/batch", testCase.rawData, tu.StatusOK(re))
 			suite.NoError(err)
 		} else {
-			err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/rules/batch", testCase.rawData,
+			err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules/batch", testCase.rawData,
 				tu.StatusNotOK(re),
 				tu.StringEqual(re, testCase.response))
 			suite.NoError(err)
@@ -682,6 +773,21 @@
 }
 
 func (suite *ruleTestSuite) TestBundle() {
+	opts := []tests.ConfigOption{
+		func(conf *config.Config, serverName string) {
+			conf.PDServerCfg.KeyType = "raw"
+			conf.Replication.EnablePlacementRules = true
+		},
+	}
+	env := tests.NewSchedulingTestEnvironment(suite.T(), opts...)
+	env.RunTestInTwoModes(suite.checkBundle)
+}
+
+func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) {
+	leaderServer := cluster.GetLeaderServer()
+	pdAddr := leaderServer.GetAddr()
+	urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix)
+
 	re := suite.Require()
 	// GetAll
 	b1 := placement.GroupBundle{
@@ -691,7 +797,7 @@ func (suite *ruleTestSuite) TestBundle() {
 		},
 	}
 	var bundles []placement.GroupBundle
-	err := tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles)
+	err := tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles)
 	suite.NoError(err)
 	suite.Len(bundles, 1)
 	suite.compareBundle(bundles[0], b1)
@@ -707,28 +813,28 @@
 	}
 	data, err := json.Marshal(b2)
 	suite.NoError(err)
-	err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule/foo", data, tu.StatusOK(re))
+	err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule/foo", data, tu.StatusOK(re))
 	suite.NoError(err)
 
 	// Get
 	var bundle placement.GroupBundle
-	err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule/foo", &bundle)
+	err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule/foo", &bundle)
 	suite.NoError(err)
 	suite.compareBundle(bundle, b2)
 
 	// GetAll again
-	err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles)
+	err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles)
 	suite.NoError(err)
 	suite.Len(bundles, 2)
 	suite.compareBundle(bundles[0], b1)
 	suite.compareBundle(bundles[1], b2)
 
 	// Delete
-	err = tu.CheckDelete(testDialClient, suite.urlPrefix+"/placement-rule/pd", tu.StatusOK(suite.Require()))
+	err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/pd", tu.StatusOK(suite.Require()))
 	suite.NoError(err)
 
 	// GetAll again
-	err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles)
+	err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles)
 	suite.NoError(err)
 	suite.Len(bundles, 1)
 	suite.compareBundle(bundles[0], b2)
@@ -739,11 +845,11 @@ func (suite *ruleTestSuite) TestBundle() {
 	b3 := placement.GroupBundle{ID: "foobar", Index: 100}
 	data, err = json.Marshal([]placement.GroupBundle{b1, b2, b3})
 	suite.NoError(err)
-	err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule", data, tu.StatusOK(re))
+	err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re))
 	suite.NoError(err)
 
 	// GetAll again
-	err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles)
+	err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles)
 	suite.NoError(err)
 	suite.Len(bundles, 3)
 	suite.compareBundle(bundles[0], b2)
@@ -751,11 +857,11 @@
 	suite.compareBundle(bundles[2], b3)
 
 	// Delete using regexp
-	err = tu.CheckDelete(testDialClient, suite.urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(suite.Require()))
+	err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(suite.Require()))
 	suite.NoError(err)
 
 	// GetAll again
-	err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles)
+	err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles)
 	suite.NoError(err)
 	suite.Len(bundles, 1)
 	suite.compareBundle(bundles[0], b1)
@@ -770,19 +876,19 @@
 	}
 	data, err = json.Marshal(b4)
 	suite.NoError(err)
-	err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(re))
+	err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(re))
 	suite.NoError(err)
 
 	b4.ID = id
 	b4.Rules[0].GroupID = b4.ID
 
 	// Get
-	err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule/"+id, &bundle)
+	err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule/"+id, &bundle)
 	suite.NoError(err)
 	suite.compareBundle(bundle, b4)
 
 	// GetAll again
-	err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles)
+	err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles)
 	suite.NoError(err)
 	suite.Len(bundles, 2)
 	suite.compareBundle(bundles[0], b1)
@@ -798,13 +904,13 @@
 	}
 	data, err = json.Marshal([]placement.GroupBundle{b1, b4, b5})
 	suite.NoError(err)
-	err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/placement-rule", data, tu.StatusOK(re))
+	err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re))
 	suite.NoError(err)
 
 	b5.Rules[0].GroupID = b5.ID
 
 	// GetAll again
-	err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"/placement-rule", &bundles)
+	err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles)
 	suite.NoError(err)
 	suite.Len(bundles, 3)
 	suite.compareBundle(bundles[0], b1)
@@ -813,6 +919,21 @@
 }
 
 func (suite *ruleTestSuite) TestBundleBadRequest() {
+	opts := []tests.ConfigOption{
+		func(conf *config.Config, serverName string) {
+			conf.PDServerCfg.KeyType = "raw"
+			conf.Replication.EnablePlacementRules = true
+		},
+	}
+	env := tests.NewSchedulingTestEnvironment(suite.T(), opts...)
+	env.RunTestInTwoModes(suite.checkBundleBadRequest)
+}
+
+func (suite *ruleTestSuite) checkBundleBadRequest(cluster *tests.TestCluster) {
+	leaderServer := cluster.GetLeaderServer()
+	pdAddr := leaderServer.GetAddr()
+	urlPrefix := fmt.Sprintf("%s%s/api/v1/config", pdAddr, apiPrefix)
+
 	testCases := []struct {
 		uri  string
 		data string
@@ -826,7 +947,7 @@ func (suite *ruleTestSuite) TestBundleBadRequest() {
 		{"/placement-rule", `[{"group_id":"foo", "rules": [{"group_id":"bar", "id":"baz", "role":"voter", "count":1}]}]`, false},
 	}
 	for _, testCase := range testCases {
-		err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+testCase.uri, []byte(testCase.data),
+		err := tu.CheckPostJSON(testDialClient, urlPrefix+testCase.uri, []byte(testCase.data),
 			func(_ []byte, code int, _ http.Header) {
 				suite.Equal(testCase.ok, code == http.StatusOK)
 			})
@@ -844,22 +965,42 @@ func (suite *ruleTestSuite) compareBundle(b1, b2 placement.GroupBundle) {
 	}
 }
 
+func (suite *ruleTestSuite) compareRule(r1 *placement.Rule, r2 *placement.Rule) {
+	suite.Equal(r2.GroupID, r1.GroupID)
+	suite.Equal(r2.ID, r1.ID)
+	suite.Equal(r2.StartKeyHex, r1.StartKeyHex)
+	suite.Equal(r2.EndKeyHex, r1.EndKeyHex)
+	suite.Equal(r2.Role, r1.Role)
+	suite.Equal(r2.Count, r1.Count)
+}
+
 type regionRuleTestSuite struct {
 	suite.Suite
-	svr       *server.Server
-	grpcSvr   *server.GrpcServer
-	cleanup   tu.CleanupFunc
-	urlPrefix string
-	stores    []*metapb.Store
-	regions   []*core.RegionInfo
 }
 
 func TestRegionRuleTestSuite(t *testing.T) {
 	suite.Run(t, new(regionRuleTestSuite))
}
 
-func (suite *regionRuleTestSuite) SetupSuite() {
-	suite.stores = []*metapb.Store{
+func (suite *regionRuleTestSuite) TestRegionPlacementRule() {
+	opts := []tests.ConfigOption{
+		func(conf *config.Config, serverName string) {
+			conf.Replication.EnablePlacementRules = true
+			conf.Replication.MaxReplicas = 1
+		},
+	}
+	env := tests.NewSchedulingTestEnvironment(suite.T(), opts...)
+	// FIXME: enable this test in two modes after we support region label forward.
+	env.RunTestInPDMode(suite.checkRegionPlacementRule)
+}
+
+func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCluster) {
+	re := suite.Require()
+	leaderServer := cluster.GetLeaderServer()
+	pdAddr := leaderServer.GetAddr()
+	urlPrefix := fmt.Sprintf("%s%s/api/v1", pdAddr, apiPrefix)
+
+	stores := []*metapb.Store{
 		{
 			Id:      1,
 			Address: "tikv1",
@@ -875,49 +1016,30 @@ func (suite *regionRuleTestSuite) SetupSuite() {
 			Version: "2.0.0",
 		},
 	}
-	re := suite.Require()
-	suite.svr, suite.cleanup = mustNewServer(re, func(cfg *config.Config) {
-		cfg.Replication.EnablePlacementRules = true
-		cfg.Replication.MaxReplicas = 1
-	})
-	server.MustWaitLeader(re, []*server.Server{suite.svr})
-
-	addr := suite.svr.GetAddr()
-	suite.grpcSvr = &server.GrpcServer{Server: suite.svr}
-	suite.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix)
-
-	mustBootstrapCluster(re, suite.svr)
-
-	for _, store := range suite.stores {
-		mustPutStore(re, suite.svr, store.Id, store.State, store.NodeState, nil)
+	for _, store := range stores {
+		tests.MustPutStore(re, cluster, store)
 	}
-	suite.regions = make([]*core.RegionInfo, 0)
+	regions := make([]*core.RegionInfo, 0)
 	peers1 := []*metapb.Peer{
 		{Id: 102, StoreId: 1, Role: metapb.PeerRole_Voter},
 		{Id: 103, StoreId: 2, Role: metapb.PeerRole_Voter}}
-	suite.regions = append(suite.regions, core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers1, RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}}, peers1[0],
+	regions = append(regions, core.NewRegionInfo(&metapb.Region{Id: 1, Peers: peers1, RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}}, peers1[0],
 		core.WithStartKey([]byte("abc")), core.WithEndKey([]byte("def"))))
 	peers2 := []*metapb.Peer{
 		{Id: 104, StoreId: 1, Role: metapb.PeerRole_Voter},
 		{Id: 105, StoreId: 2, Role: metapb.PeerRole_Learner}}
-	suite.regions = append(suite.regions, core.NewRegionInfo(&metapb.Region{Id: 2, Peers: peers2, RegionEpoch: &metapb.RegionEpoch{ConfVer: 2, Version: 2}}, peers2[0],
+	regions = append(regions, core.NewRegionInfo(&metapb.Region{Id: 2, Peers: peers2, RegionEpoch: &metapb.RegionEpoch{ConfVer: 2, Version: 2}}, peers2[0],
 		core.WithStartKey([]byte("ghi")), core.WithEndKey([]byte("jkl"))))
 	peers3 := []*metapb.Peer{
 		{Id: 106, StoreId: 1, Role: metapb.PeerRole_Voter},
 		{Id: 107, StoreId: 2, Role: metapb.PeerRole_Learner}}
-	suite.regions = append(suite.regions, core.NewRegionInfo(&metapb.Region{Id: 3, Peers: peers3, RegionEpoch: &metapb.RegionEpoch{ConfVer: 3, Version: 3}}, peers3[0],
+	regions = append(regions, core.NewRegionInfo(&metapb.Region{Id: 3, Peers: peers3, RegionEpoch: &metapb.RegionEpoch{ConfVer: 3, Version: 3}}, peers3[0],
 		core.WithStartKey([]byte("mno")), core.WithEndKey([]byte("pqr"))))
-	for _, rg := range suite.regions {
-		suite.svr.GetBasicCluster().PutRegion(rg)
+	for _, rg := range regions {
+		tests.MustPutRegionInfo(re, cluster, rg)
 	}
-}
-
-func (suite *regionRuleTestSuite) TearDownSuite() {
-	suite.cleanup()
-}
 
-func (suite *regionRuleTestSuite) TestRegionPlacementRule() {
-	ruleManager := suite.svr.GetRaftCluster().GetRuleManager()
+	ruleManager := leaderServer.GetRaftCluster().GetRuleManager()
 	ruleManager.SetRule(&placement.Rule{
 		GroupID: "test",
 		ID:      "test2",
@@ -934,38 +1056,38 @@
 		Role:    placement.Learner,
 		Count:   1,
 	})
-	re := suite.Require()
-	url := fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 1)
 	fit := &placement.RegionFit{}
+
+	url := fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 1)
 	err := tu.ReadGetJSON(re, testDialClient, url, fit)
+	suite.NoError(err)
 	suite.Equal(len(fit.RuleFits), 1)
 	suite.Equal(len(fit.OrphanPeers), 1)
-	suite.NoError(err)
 
-	url = fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 2)
+	url = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 2)
 	fit = &placement.RegionFit{}
 	err = tu.ReadGetJSON(re, testDialClient, url, fit)
+	suite.NoError(err)
 	suite.Equal(len(fit.RuleFits), 2)
 	suite.Equal(len(fit.OrphanPeers), 0)
-	suite.NoError(err)
 
-	url = fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 3)
+	url = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 3)
 	fit = &placement.RegionFit{}
 	err = tu.ReadGetJSON(re, testDialClient, url, fit)
+	suite.NoError(err)
 	suite.Equal(len(fit.RuleFits), 0)
 	suite.Equal(len(fit.OrphanPeers), 2)
-	suite.NoError(err)
 
-	url = fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 4)
+	url = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 4)
 	err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusNotFound), tu.StringContain(
 		re, "region 4 not found"))
 	suite.NoError(err)
 
-	url = fmt.Sprintf("%s/config/rules/region/%s/detail", suite.urlPrefix, "id")
+	url = fmt.Sprintf("%s/config/rules/region/%s/detail", urlPrefix, "id")
 	err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest), tu.StringContain(
-		re, "invalid region id"))
+		re, errs.ErrRegionInvalidID.Error()))
 	suite.NoError(err)
 
-	suite.svr.GetRaftCluster().GetReplicationConfig().EnablePlacementRules = false
-	url = fmt.Sprintf("%s/config/rules/region/%d/detail", suite.urlPrefix, 1)
+	leaderServer.GetRaftCluster().GetReplicationConfig().EnablePlacementRules = false
+	url = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 1)
 	err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusPreconditionFailed), tu.StringContain(
		re, "placement rules feature is disabled"))
 	suite.NoError(err)
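
Review note: every suite in this file follows the same migration pattern, sketched below. This is a minimal illustration assembled only from helpers that already appear in this diff (`tests.ConfigOption`, `tests.NewSchedulingTestEnvironment`, `RunTestInTwoModes`, `tests.TestCluster`); the `TestExample`/`checkExample` names are hypothetical and not part of the change.

```go
// Each Test* entry point now only assembles the cluster config and hands a
// check* callback to the shared test environment, which (per its use in this
// diff) runs the callback once against plain PD and once with the scheduling
// microservice enabled.
func (suite *ruleTestSuite) TestExample() {
	opts := []tests.ConfigOption{
		func(conf *config.Config, serverName string) {
			conf.PDServerCfg.KeyType = "raw"
			conf.Replication.EnablePlacementRules = true
		},
	}
	env := tests.NewSchedulingTestEnvironment(suite.T(), opts...)
	env.RunTestInTwoModes(suite.checkExample)
}

// The callback derives everything from the cluster it is handed instead of
// suite-level fields, so the same body works unchanged in both modes.
func (suite *ruleTestSuite) checkExample(cluster *tests.TestCluster) {
	leaderServer := cluster.GetLeaderServer()
	urlPrefix := fmt.Sprintf("%s%s/api/v1/config", leaderServer.GetAddr(), apiPrefix)
	_ = urlPrefix // per-test HTTP assertions against urlPrefix go here
}
```

This also explains why the per-suite SetupSuite/TearDownSuite/TearDownTest bodies could be deleted: cluster lifecycle now belongs to the shared environment rather than to each suite.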