diff --git a/pkg/mcs/resourcemanager/server/apis/v1/api.go b/pkg/mcs/resourcemanager/server/apis/v1/api.go index 411933e55c34..7c5e3e010dc2 100644 --- a/pkg/mcs/resourcemanager/server/apis/v1/api.go +++ b/pkg/mcs/resourcemanager/server/apis/v1/api.go @@ -73,7 +73,7 @@ func NewService(srv *rmserver.Service) *Service { manager := srv.GetManager() apiHandlerEngine.Use(func(c *gin.Context) { // manager implements the interface of basicserver.Service. - c.Set("service", manager.GetBasicServer()) + c.Set(multiservicesapi.ServiceContextKey, manager.GetBasicServer()) c.Next() }) apiHandlerEngine.Use(multiservicesapi.ServiceRedirector()) diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index e8c4faa5d559..48d473b7b1b1 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -34,7 +34,7 @@ import ( ) // APIPathPrefix is the prefix of the API path. -const APIPathPrefix = "/scheduling/api/v1/" +const APIPathPrefix = "/scheduling/api/v1" var ( once sync.Once @@ -101,19 +101,19 @@ func NewService(srv *scheserver.Service) *Service { // RegisterSchedulersRouter registers the router of the schedulers handler. func (s *Service) RegisterSchedulersRouter() { - router := s.root.Group("schedulers") + router := s.root.Group("schedulers", gzip.Gzip(gzip.DefaultCompression)) router.GET("", getSchedulers) } // RegisterCheckersRouter registers the router of the checkers handler. func (s *Service) RegisterCheckersRouter() { - router := s.root.Group("checkers") + router := s.root.Group("checkers", gzip.Gzip(gzip.DefaultCompression)) router.GET("/:name", getCheckerByName) } // RegisterOperatorsRouter registers the router of the operators handler. func (s *Service) RegisterOperatorsRouter() { - router := s.root.Group("operators") + router := s.root.Group("operators", gzip.Gzip(gzip.DefaultCompression)) router.GET("", getOperators) router.GET("/:id", getOperatorByID) } diff --git a/pkg/mcs/tso/server/apis/v1/api.go b/pkg/mcs/tso/server/apis/v1/api.go index c2cbca005d71..f3f39e470631 100644 --- a/pkg/mcs/tso/server/apis/v1/api.go +++ b/pkg/mcs/tso/server/apis/v1/api.go @@ -105,7 +105,7 @@ func NewService(srv *tsoserver.Service) *Service { // RegisterAdminRouter registers the router of the TSO admin handler. func (s *Service) RegisterAdminRouter() { - router := s.root.Group("admin") + router := s.root.Group("admin", gzip.Gzip(gzip.DefaultCompression)) router.POST("/reset-ts", ResetTS) } diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 7d403ecef13d..d7850029e2ed 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -17,6 +17,7 @@ package serverapi import ( "net/http" "net/url" + "strings" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" @@ -108,13 +109,19 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri return false, "" } for _, rule := range h.microserviceRedirectRules { - if rule.matchPath == r.URL.Path { + if strings.HasPrefix(r.URL.Path, rule.matchPath) { addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName) if !ok || addr == "" { log.Warn("failed to get the service primary addr when try match redirect rules", zap.String("path", r.URL.Path)) } - r.URL.Path = rule.targetPath + // Extract parameters from the URL path + pathParams := strings.TrimPrefix(r.URL.Path, rule.matchPath) + if len(pathParams) > 0 && pathParams[0] == '/' { + pathParams = pathParams[1:] // Remove leading '/' + } + r.URL.Path = rule.targetPath + "/" + pathParams + r.URL.Path = strings.TrimRight(r.URL.Path, "/") return true, addr } } @@ -122,10 +129,10 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri } func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - matchedFlag, targetAddr := h.matchMicroServiceRedirectRules(r) + needRedirectToMicroService, targetAddr := h.matchMicroServiceRedirectRules(r) allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0 isLeader := h.s.GetMember().IsLeader() - if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !matchedFlag { + if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !needRedirectToMicroService { next(w, r) return } @@ -150,7 +157,7 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http } var clientUrls []string - if matchedFlag { + if needRedirectToMicroService { if len(targetAddr) == 0 { http.Error(w, apiutil.ErrRedirectFailed, http.StatusInternalServerError) return diff --git a/server/api/server.go b/server/api/server.go index 1d881022c042..5df5d22f2d03 100644 --- a/server/api/server.go +++ b/server/api/server.go @@ -19,6 +19,7 @@ import ( "net/http" "github.com/gorilla/mux" + scheapi "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" tsoapi "github.com/tikv/pd/pkg/mcs/tso/server/apis/v1" mcs "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/utils/apiutil" @@ -35,14 +36,29 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP Name: "core", IsCore: true, } - router := mux.NewRouter() + prefix := apiPrefix + "/api/v1" r := createRouter(apiPrefix, svr) + router := mux.NewRouter() router.PathPrefix(apiPrefix).Handler(negroni.New( serverapi.NewRuntimeServiceValidator(svr, group), - serverapi.NewRedirector(svr, serverapi.MicroserviceRedirectRule( - apiPrefix+"/api/v1"+"/admin/reset-ts", - tsoapi.APIPathPrefix+"/admin/reset-ts", - mcs.TSOServiceName)), + serverapi.NewRedirector(svr, + serverapi.MicroserviceRedirectRule( + prefix+"/admin/reset-ts", + tsoapi.APIPathPrefix+"/admin/reset-ts", + mcs.TSOServiceName), + serverapi.MicroserviceRedirectRule( + prefix+"/operators", + scheapi.APIPathPrefix+"/operators", + mcs.SchedulingServiceName), + serverapi.MicroserviceRedirectRule( + prefix+"/checker", // Note: this is a typo in the original code + scheapi.APIPathPrefix+"/checkers", + mcs.SchedulingServiceName), + serverapi.MicroserviceRedirectRule( + prefix+"/schedulers", + scheapi.APIPathPrefix+"/schedulers", + mcs.SchedulingServiceName), + ), negroni.Wrap(r)), ) diff --git a/tests/integrations/mcs/scheduling/server_test.go b/tests/integrations/mcs/scheduling/server_test.go index 54994bbc34b0..a28814a75a12 100644 --- a/tests/integrations/mcs/scheduling/server_test.go +++ b/tests/integrations/mcs/scheduling/server_test.go @@ -16,10 +16,15 @@ package scheduling import ( "context" + "encoding/json" + "fmt" + "io" + "net/http" "testing" "time" "github.com/pingcap/failpoint" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" mcs "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/utils/testutil" @@ -126,3 +131,55 @@ func (suite *serverTestSuite) TestPrimaryChange() { return ok && newPrimaryAddr == watchedAddr }) } + +func (suite *serverTestSuite) TestAPIForward() { + re := suite.Require() + tc, err := tests.NewTestSchedulingCluster(suite.ctx, 2, suite.backendEndpoints) + re.NoError(err) + defer tc.Destroy() + tc.WaitForPrimaryServing(re) + + urlPrefix := fmt.Sprintf("%s/pd/api/v1", suite.backendEndpoints) + cli := &http.Client{} + defer cli.CloseIdleConnections() + var list []string + var res map[string]interface{} + + // opeartor + resp := checkAPIForward(re, cli, http.MethodGet, urlPrefix+"/operators") + defer resp.Body.Close() + data, _ := io.ReadAll(resp.Body) + re.NoError(json.Unmarshal(data, &list)) + re.Len(list, 0) + + resp = checkAPIForward(re, cli, http.MethodGet, urlPrefix+"/operators/2") + defer resp.Body.Close() + data, _ = io.ReadAll(resp.Body) + re.NoError(json.Unmarshal(data, &res)) + re.Nil(res) + + // checker + resp = checkAPIForward(re, cli, http.MethodGet, urlPrefix+"/checker/merge") + defer resp.Body.Close() + data, _ = io.ReadAll(resp.Body) + re.NoError(json.Unmarshal(data, &res)) + re.False(res["paused"].(bool)) + + // scheduler + resp = checkAPIForward(re, cli, http.MethodGet, urlPrefix+"/schedulers") + defer resp.Body.Close() + data, _ = io.ReadAll(resp.Body) + re.NoError(json.Unmarshal(data, &list)) + re.Contains(list, "balance-leader-scheduler") +} + +func checkAPIForward(re *require.Assertions, cli *http.Client, method string, urlPrefix string) (resp *http.Response) { + var err error + switch method { + case http.MethodGet: + resp, err = cli.Get(urlPrefix) + } + re.NoError(err) + re.Equal(http.StatusOK, resp.StatusCode) + return resp +} diff --git a/tests/integrations/mcs/tso/api_test.go b/tests/integrations/mcs/tso/api_test.go index fde6bcb8da05..ade933ad38cf 100644 --- a/tests/integrations/mcs/tso/api_test.go +++ b/tests/integrations/mcs/tso/api_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "io" "net/http" + "strings" "testing" "time" @@ -47,10 +48,11 @@ var dialClient = &http.Client{ type tsoAPITestSuite struct { suite.Suite - ctx context.Context - cancel context.CancelFunc - pdCluster *tests.TestCluster - tsoCluster *tests.TestTSOCluster + ctx context.Context + cancel context.CancelFunc + pdCluster *tests.TestCluster + tsoCluster *tests.TestTSOCluster + backendEndpoints string } func TestTSOAPI(t *testing.T) { @@ -69,7 +71,8 @@ func (suite *tsoAPITestSuite) SetupTest() { leaderName := suite.pdCluster.WaitLeader() pdLeaderServer := suite.pdCluster.GetServer(leaderName) re.NoError(pdLeaderServer.BootstrapCluster()) - suite.tsoCluster, err = tests.NewTestTSOCluster(suite.ctx, 1, pdLeaderServer.GetAddr()) + suite.backendEndpoints = pdLeaderServer.GetAddr() + suite.tsoCluster, err = tests.NewTestTSOCluster(suite.ctx, 1, suite.backendEndpoints) re.NoError(err) } @@ -95,6 +98,31 @@ func (suite *tsoAPITestSuite) TestGetKeyspaceGroupMembers() { re.Equal(primaryMember.GetLeaderID(), defaultGroupMember.PrimaryID) } +func (suite *tsoAPITestSuite) TestResetTS() { + re := suite.Require() + primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) + re.NotNil(primary) + url := suite.backendEndpoints + "/pd/api/v1/admin/reset-ts" + + // reset ts + resetJSON := `{"tso":"121312", "force-use-larger":true}` + resp, err := http.Post(url, "application/json", strings.NewReader(resetJSON)) + re.NoError(err) + defer resp.Body.Close() + re.Equal(http.StatusOK, resp.StatusCode) + data, _ := io.ReadAll(resp.Body) + re.Equal("Reset ts successfully.", string(data)) + + // reset ts with invalid tso + resetJSON = `{}` + resp, err = http.Post(url, "application/json", strings.NewReader(resetJSON)) + re.NoError(err) + defer resp.Body.Close() + re.Equal(http.StatusBadRequest, resp.StatusCode) + data, _ = io.ReadAll(resp.Body) + re.Equal("invalid tso value", string(data)) +} + func mustGetKeyspaceGroupMembers(re *require.Assertions, server *tso.Server) map[uint32]*apis.KeyspaceGroupMember { httpReq, err := http.NewRequest(http.MethodGet, server.GetAddr()+tsoKeyspaceGroupsPrefix+"/members", nil) re.NoError(err) diff --git a/tests/integrations/mcs/tso/server_test.go b/tests/integrations/mcs/tso/server_test.go index d87f15421799..cdf28c0ec6cf 100644 --- a/tests/integrations/mcs/tso/server_test.go +++ b/tests/integrations/mcs/tso/server_test.go @@ -109,7 +109,6 @@ func (suite *tsoServerTestSuite) TestTSOServerStartAndStopNormally() { url := s.GetAddr() + tsoapi.APIPathPrefix { resetJSON := `{"tso":"121312", "force-use-larger":true}` - re.NoError(err) resp, err := http.Post(url+"/admin/reset-ts", "application/json", strings.NewReader(resetJSON)) re.NoError(err) defer resp.Body.Close() @@ -117,7 +116,6 @@ func (suite *tsoServerTestSuite) TestTSOServerStartAndStopNormally() { } { resetJSON := `{}` - re.NoError(err) resp, err := http.Post(url+"/admin/reset-ts", "application/json", strings.NewReader(resetJSON)) re.NoError(err) defer resp.Body.Close()