diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index 39be00ef9a0..ed998b9b62d 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -17,6 +17,7 @@ package apis import ( "fmt" "net/http" + "net/url" "strconv" "sync" @@ -26,6 +27,7 @@ import ( "github.com/gin-gonic/gin" "github.com/joho/godotenv" "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" scheserver "github.com/tikv/pd/pkg/mcs/scheduling/server" mcsutils "github.com/tikv/pd/pkg/mcs/utils" sche "github.com/tikv/pd/pkg/schedule/core" @@ -114,6 +116,8 @@ func NewService(srv *scheserver.Service) *Service { s.RegisterSchedulersRouter() s.RegisterCheckersRouter() s.RegisterHotspotRouter() + s.RegisterRegionsRouter() + s.RegisterRegionLabelRouter() return s } @@ -160,6 +164,21 @@ func (s *Service) RegisterOperatorsRouter() { router.GET("/records", getOperatorRecords) } +// RegisterRegionsRouter registers the router of the regions handler. +func (s *Service) RegisterRegionsRouter() { + router := s.root.Group("regions") + router.GET("/:id/label/:key", getRegionLabelByKey) + router.GET("/:id/labels", getRegionLabels) +} + +// RegisterRegionLabelRouter registers the router of the region label handler. +func (s *Service) RegisterRegionLabelRouter() { + router := s.root.Group("config/region-label") + router.GET("rules", getAllRegionLabelRules) + router.GET("rules/ids", getRegionLabelRulesByIDs) + router.GET("rule/:id", getRegionLabelRuleByID) +} + func changeLogLevel(c *gin.Context) { svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server) var level string @@ -548,3 +567,144 @@ func getHistoryHotRegions(c *gin.Context) { var res storage.HistoryHotRegions c.IndentedJSON(http.StatusOK, res) } + +// @Tags region_label +// @Summary Get label of a region. +// @Param id path integer true "Region Id" +// @Param key path string true "Label key" +// @Produce json +// @Success 200 {string} string +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Router /regions/{id}/label/{key} [get] +func getRegionLabelByKey(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + + idStr := c.Param("id") + labelKey := c.Param("key") // TODO: test https://github.com/tikv/pd/pull/4004 + + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + c.String(http.StatusBadRequest, err.Error()) + return + } + + region := handler.GetRegion(id) + if region == nil { + c.String(http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs().Error()) + return + } + + l, err := handler.GetRegionLabeler() + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + labelValue := l.GetRegionLabel(region, labelKey) + c.IndentedJSON(http.StatusOK, labelValue) +} + +// @Tags region_label +// @Summary Get labels of a region. +// @Param id path integer true "Region Id" +// @Produce json +// @Success 200 {string} string +// @Failure 400 {string} string "The input is invalid." +// @Failure 404 {string} string "The region does not exist." +// @Router /regions/{id}/labels [get] +func getRegionLabels(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + + idStr := c.Param("id") + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + c.String(http.StatusBadRequest, err.Error()) + return + } + + region := handler.GetRegion(id) + if region == nil { + c.String(http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs().Error()) + return + } + l, err := handler.GetRegionLabeler() + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + labels := l.GetRegionLabels(region) + c.IndentedJSON(http.StatusOK, labels) +} + +// @Tags region_label +// @Summary List all label rules of cluster. +// @Produce json +// @Success 200 {array} labeler.LabelRule +// @Router /config/region-label/rules [get] +func getAllRegionLabelRules(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + l, err := handler.GetRegionLabeler() + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + rules := l.GetAllLabelRules() + c.IndentedJSON(http.StatusOK, rules) +} + +// @Tags region_label +// @Summary Get label rules of cluster by ids. +// @Param body body []string true "IDs of query rules" +// @Produce json +// @Success 200 {array} labeler.LabelRule +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /config/region-label/rules/ids [get] +func getRegionLabelRulesByIDs(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + l, err := handler.GetRegionLabeler() + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + var ids []string + if err := c.BindJSON(&ids); err != nil { + c.String(http.StatusBadRequest, err.Error()) + return + } + rules, err := l.GetLabelRules(ids) + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + c.IndentedJSON(http.StatusOK, rules) +} + +// @Tags region_label +// @Summary Get label rule of cluster by id. +// @Param id path string true "Rule Id" +// @Produce json +// @Success 200 {object} labeler.LabelRule +// @Failure 404 {string} string "The rule does not exist." +// @Router /config/region-label/rule/{id} [get] +func getRegionLabelRuleByID(c *gin.Context) { + handler := c.MustGet(handlerKey).(*handler.Handler) + + id, err := url.PathUnescape(c.Param("id")) + if err != nil { + c.String(http.StatusBadRequest, err.Error()) + return + } + + l, err := handler.GetRegionLabeler() + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + rule := l.GetLabelRule(id) + if rule == nil { + c.String(http.StatusNotFound, errs.ErrRegionRuleNotFound.FastGenByArgs().Error()) + return + } + c.IndentedJSON(http.StatusOK, rule) +} diff --git a/pkg/schedule/handler/handler.go b/pkg/schedule/handler/handler.go index fca43f3eeeb..c0cee81d27e 100644 --- a/pkg/schedule/handler/handler.go +++ b/pkg/schedule/handler/handler.go @@ -30,6 +30,7 @@ import ( "github.com/tikv/pd/pkg/schedule" sche "github.com/tikv/pd/pkg/schedule/core" "github.com/tikv/pd/pkg/schedule/filter" + "github.com/tikv/pd/pkg/schedule/labeler" "github.com/tikv/pd/pkg/schedule/operator" "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/pkg/schedule/scatter" @@ -1040,3 +1041,21 @@ func (h *Handler) GetHotBuckets(regionIDs ...uint64) (HotBucketsResponse, error) } return ret, nil } + +// GetRegion returns the region labeler. +func (h *Handler) GetRegion(id uint64) *core.RegionInfo { + c := h.GetCluster() + if c == nil { + return nil + } + return c.GetRegion(id) +} + +// GetRegionLabeler returns the region labeler. +func (h *Handler) GetRegionLabeler() (*labeler.RegionLabeler, error) { + c := h.GetCluster() + if c == nil || c.GetRegionLabeler() == nil { + return nil, errs.ErrNotBootstrapped + } + return c.GetRegionLabeler(), nil +} diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 19438ad0f91..061c65329c9 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -79,6 +79,7 @@ type microserviceRedirectRule struct { targetPath string targetServiceName string matchMethods []string + filter func(*http.Request) bool } // NewRedirector redirects request to the leader if needs to be handled in the leader. @@ -94,14 +95,19 @@ func NewRedirector(s *server.Server, opts ...RedirectorOption) negroni.Handler { type RedirectorOption func(*redirector) // MicroserviceRedirectRule new a microservice redirect rule option -func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string, methods []string) RedirectorOption { +func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string, + methods []string, filters ...func(*http.Request) bool) RedirectorOption { return func(s *redirector) { - s.microserviceRedirectRules = append(s.microserviceRedirectRules, µserviceRedirectRule{ - matchPath, - targetPath, - targetServiceName, - methods, - }) + rule := µserviceRedirectRule{ + matchPath: matchPath, + targetPath: targetPath, + targetServiceName: targetServiceName, + matchMethods: methods, + } + if len(filters) > 0 { + rule.filter = filters[0] + } + s.microserviceRedirectRules = append(s.microserviceRedirectRules, rule) } } @@ -117,6 +123,9 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri r.URL.Path = strings.TrimRight(r.URL.Path, "/") for _, rule := range h.microserviceRedirectRules { if strings.HasPrefix(r.URL.Path, rule.matchPath) && slice.Contains(rule.matchMethods, r.Method) { + if rule.filter != nil && !rule.filter(r) { + continue + } addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName) if !ok || addr == "" { log.Warn("failed to get the service primary addr when trying to match redirect rules", diff --git a/pkg/utils/testutil/api_check.go b/pkg/utils/testutil/api_check.go index 84af97f828d..786530b1567 100644 --- a/pkg/utils/testutil/api_check.go +++ b/pkg/utils/testutil/api_check.go @@ -88,7 +88,7 @@ func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data i } // ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully. -func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}) error { +func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.GetJSON(client, url, input) if err != nil { return err diff --git a/server/api/region_label.go b/server/api/region_label.go index 003dfb1132f..7958bacd371 100644 --- a/server/api/region_label.go +++ b/server/api/region_label.go @@ -83,7 +83,7 @@ func (h *regionLabelHandler) PatchRegionLabelRules(w http.ResponseWriter, r *htt // @Success 200 {array} labeler.LabelRule // @Failure 400 {string} string "The input is invalid." // @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /config/region-label/rule/ids [get] +// @Router /config/region-label/rules/ids [get] func (h *regionLabelHandler) GetRegionLabelRulesByIDs(w http.ResponseWriter, r *http.Request) { cluster := getCluster(r) var ids []string diff --git a/server/api/server.go b/server/api/server.go index ee301ea54c8..992cd42d796 100644 --- a/server/api/server.go +++ b/server/api/server.go @@ -17,6 +17,7 @@ package api import ( "context" "net/http" + "strings" "github.com/gorilla/mux" scheapi "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" @@ -78,6 +79,21 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP scheapi.APIPathPrefix+"/checkers", mcs.SchedulingServiceName, []string{http.MethodPost, http.MethodGet}), + serverapi.MicroserviceRedirectRule( + prefix+"/region/id", + scheapi.APIPathPrefix+"/regions", + mcs.SchedulingServiceName, + []string{http.MethodGet}, + func(r *http.Request) bool { + // The original code uses the path "/region/id" to get the region id. + // However, the path "/region/id" is used to get the region by id, which is not what we want. + return strings.Contains(r.URL.Path, "label") + }), + serverapi.MicroserviceRedirectRule( + prefix+"/config/region-label", + scheapi.APIPathPrefix+"/config/region-label", + mcs.SchedulingServiceName, + []string{http.MethodGet}), serverapi.MicroserviceRedirectRule( prefix+"/hotspot", scheapi.APIPathPrefix+"/hotspot", diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 5284913813c..3873ca0f96d 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/suite" _ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" "github.com/tikv/pd/pkg/schedule/handler" + "github.com/tikv/pd/pkg/schedule/labeler" "github.com/tikv/pd/pkg/statistics" "github.com/tikv/pd/pkg/storage" "github.com/tikv/pd/pkg/utils/apiutil" @@ -217,4 +218,26 @@ func (suite *apiTestSuite) TestAPIForward() { err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/history"), &history, testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) + + // Test region label + var rules []*labeler.LabelRule + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules"), &rules, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.ReadGetJSONWithBody(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules/ids"), []byte(`["rule1", "rule3"]`), + &rules, testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rule/rule1"), nil, + testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1"), nil, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/label/key"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader,"true")) + re.NoError(err) + err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/labels"), nil, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader,"true")) + re.NoError(err) }