From c6fbf6055621b7ae4b2f3b3ff524bd946a337a98 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Mon, 18 Sep 2023 16:49:41 +0800 Subject: [PATCH] mcs: forward current http request to mcs (#7078) ref tikv/pd#5839 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/resourcemanager/server/apis/v1/api.go | 2 +- pkg/mcs/scheduling/server/apis/v1/api.go | 2 +- pkg/utils/apiutil/apiutil.go | 17 +++ pkg/utils/apiutil/serverapi/middleware.go | 35 ++++-- pkg/utils/testutil/api_check.go | 75 +++++++----- server/api/hot_status_test.go | 5 +- server/api/region_test.go | 2 +- server/api/rule_test.go | 2 +- server/api/server.go | 34 +++++- tests/integrations/mcs/scheduling/api_test.go | 58 +++++++++ tests/integrations/mcs/tso/api_test.go | 32 ++++- tests/integrations/mcs/tso/server_test.go | 30 ++--- tests/pdctl/operator/operator_test.go | 30 +++++ tests/pdctl/scheduler/scheduler_test.go | 115 ++++++++++++++++++ 14 files changed, 369 insertions(+), 70 deletions(-) diff --git a/pkg/mcs/resourcemanager/server/apis/v1/api.go b/pkg/mcs/resourcemanager/server/apis/v1/api.go index 411933e55c3..7c5e3e010dc 100644 --- a/pkg/mcs/resourcemanager/server/apis/v1/api.go +++ b/pkg/mcs/resourcemanager/server/apis/v1/api.go @@ -73,7 +73,7 @@ func NewService(srv *rmserver.Service) *Service { manager := srv.GetManager() apiHandlerEngine.Use(func(c *gin.Context) { // manager implements the interface of basicserver.Service. - c.Set("service", manager.GetBasicServer()) + c.Set(multiservicesapi.ServiceContextKey, manager.GetBasicServer()) c.Next() }) apiHandlerEngine.Use(multiservicesapi.ServiceRedirector()) diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index e8c4faa5d55..3d1c3921470 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -34,7 +34,7 @@ import ( ) // APIPathPrefix is the prefix of the API path. -const APIPathPrefix = "/scheduling/api/v1/" +const APIPathPrefix = "/scheduling/api/v1" var ( once sync.Once diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 269a256cff3..0467bac9b04 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -57,6 +57,8 @@ const ( XForwardedPortHeader = "X-Forwarded-Port" // XRealIPHeader is used to mark the real client IP. XRealIPHeader = "X-Real-Ip" + // ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service. + ForwardToMicroServiceHeader = "Forward-To-Micro-Service" // ErrRedirectFailed is the error message for redirect failed. ErrRedirectFailed = "redirect failed" @@ -435,8 +437,17 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) reader = resp.Body } + // We need to copy the response headers before we write the header. + // Otherwise, we cannot set the header after w.WriteHeader() is called. + // And we need to write the header before we copy the response body. + // Otherwise, we cannot set the status code after w.Write() is called. + // In other words, we must perform the following steps strictly in order: + // 1. Set the response headers. + // 2. Write the response header. + // 3. Write the response body. copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) + for { if _, err = io.CopyN(w, reader, chunkSize); err != nil { if err == io.EOF { @@ -455,8 +466,14 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) http.Error(w, ErrRedirectFailed, http.StatusInternalServerError) } +// copyHeader duplicates the HTTP headers from the source `src` to the destination `dst`. +// It skips the "Content-Encoding" and "Content-Length" headers because they should be set by `http.ResponseWriter`. +// These headers may be modified after a redirect when gzip compression is enabled. func copyHeader(dst, src http.Header) { for k, vv := range src { + if k == "Content-Encoding" || k == "Content-Length" { + continue + } values := dst[k] for _, v := range vv { if !slice.Contains(values, v) { diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 7d403ecef13..063ad042dbb 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -17,9 +17,12 @@ package serverapi import ( "net/http" "net/url" + "strings" + "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/slice" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/server" "github.com/urfave/negroni" @@ -75,6 +78,7 @@ type microserviceRedirectRule struct { matchPath string targetPath string targetServiceName string + matchMethods []string } // NewRedirector redirects request to the leader if needs to be handled in the leader. @@ -90,12 +94,13 @@ func NewRedirector(s *server.Server, opts ...RedirectorOption) negroni.Handler { type RedirectorOption func(*redirector) // MicroserviceRedirectRule new a microservice redirect rule option -func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string) RedirectorOption { +func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string, methods []string) RedirectorOption { return func(s *redirector) { s.microserviceRedirectRules = append(s.microserviceRedirectRules, µserviceRedirectRule{ matchPath, targetPath, targetServiceName, + methods, }) } } @@ -108,13 +113,24 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri return false, "" } for _, rule := range h.microserviceRedirectRules { - if rule.matchPath == r.URL.Path { + if strings.HasPrefix(r.URL.Path, rule.matchPath) && slice.Contains(rule.matchMethods, r.Method) { addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName) if !ok || addr == "" { - log.Warn("failed to get the service primary addr when try match redirect rules", + log.Warn("failed to get the service primary addr when trying to match redirect rules", zap.String("path", r.URL.Path)) } - r.URL.Path = rule.targetPath + // Extract parameters from the URL path + // e.g. r.URL.Path = /pd/api/v1/operators/1 (before redirect) + // matchPath = /pd/api/v1/operators + // targetPath = /scheduling/api/v1/operators + // r.URL.Path = /scheduling/api/v1/operator/1 (after redirect) + pathParams := strings.TrimPrefix(r.URL.Path, rule.matchPath) + pathParams = strings.Trim(pathParams, "/") // Remove leading and trailing '/' + if len(pathParams) > 0 { + r.URL.Path = rule.targetPath + "/" + pathParams + } else { + r.URL.Path = rule.targetPath + } return true, addr } } @@ -122,10 +138,10 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri } func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - matchedFlag, targetAddr := h.matchMicroServiceRedirectRules(r) + redirectToMicroService, targetAddr := h.matchMicroServiceRedirectRules(r) allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0 isLeader := h.s.GetMember().IsLeader() - if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !matchedFlag { + if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !redirectToMicroService { next(w, r) return } @@ -150,12 +166,17 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http } var clientUrls []string - if matchedFlag { + if redirectToMicroService { if len(targetAddr) == 0 { http.Error(w, apiutil.ErrRedirectFailed, http.StatusInternalServerError) return } clientUrls = append(clientUrls, targetAddr) + failpoint.Inject("checkHeader", func() { + // add a header to the response, this is not a failure injection + // it is used for testing, to check whether the request is forwarded to the micro service + w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true") + }) } else { leader := h.s.GetMember().GetLeader() if leader == nil { diff --git a/pkg/utils/testutil/api_check.go b/pkg/utils/testutil/api_check.go index c17c6970ab7..d11d575967d 100644 --- a/pkg/utils/testutil/api_check.go +++ b/pkg/utils/testutil/api_check.go @@ -23,56 +23,71 @@ import ( "github.com/tikv/pd/pkg/utils/apiutil" ) -// Status is used to check whether http response code is equal given code -func Status(re *require.Assertions, code int) func([]byte, int) { - return func(_ []byte, i int) { - re.Equal(code, i) +// Status is used to check whether http response code is equal given code. +func Status(re *require.Assertions, code int) func([]byte, int, http.Header) { + return func(resp []byte, i int, _ http.Header) { + re.Equal(code, i, "resp: "+string(resp)) } } -// StatusOK is used to check whether http response code is equal http.StatusOK -func StatusOK(re *require.Assertions) func([]byte, int) { +// StatusOK is used to check whether http response code is equal http.StatusOK. +func StatusOK(re *require.Assertions) func([]byte, int, http.Header) { return Status(re, http.StatusOK) } -// StatusNotOK is used to check whether http response code is not equal http.StatusOK -func StatusNotOK(re *require.Assertions) func([]byte, int) { - return func(_ []byte, i int) { +// StatusNotOK is used to check whether http response code is not equal http.StatusOK. +func StatusNotOK(re *require.Assertions) func([]byte, int, http.Header) { + return func(_ []byte, i int, _ http.Header) { re.NotEqual(http.StatusOK, i) } } -// ExtractJSON is used to check whether given data can be extracted successfully -func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int) { - return func(res []byte, _ int) { +// ExtractJSON is used to check whether given data can be extracted successfully. +func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int, http.Header) { + return func(res []byte, _ int, _ http.Header) { re.NoError(json.Unmarshal(res, data)) } } -// StringContain is used to check whether response context contains given string -func StringContain(re *require.Assertions, sub string) func([]byte, int) { - return func(res []byte, _ int) { +// StringContain is used to check whether response context contains given string. +func StringContain(re *require.Assertions, sub string) func([]byte, int, http.Header) { + return func(res []byte, _ int, _ http.Header) { re.Contains(string(res), sub) } } -// StringEqual is used to check whether response context equal given string -func StringEqual(re *require.Assertions, str string) func([]byte, int) { - return func(res []byte, _ int) { +// StringEqual is used to check whether response context equal given string. +func StringEqual(re *require.Assertions, str string) func([]byte, int, http.Header) { + return func(res []byte, _ int, _ http.Header) { re.Contains(string(res), str) } } -// ReadGetJSON is used to do get request and check whether given data can be extracted successfully -func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}) error { +// WithHeader is used to check whether response header contains given key and value. +func WithHeader(re *require.Assertions, key, value string) func([]byte, int, http.Header) { + return func(_ []byte, _ int, header http.Header) { + re.Equal(value, header.Get(key)) + } +} + +// WithoutHeader is used to check whether response header does not contain given key. +func WithoutHeader(re *require.Assertions, key string) func([]byte, int, http.Header) { + return func(_ []byte, _ int, header http.Header) { + re.Empty(header.Get(key)) + } +} + +// ReadGetJSON is used to do get request and check whether given data can be extracted successfully. +func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.GetJSON(client, url, nil) if err != nil { return err } - return checkResp(resp, StatusOK(re), ExtractJSON(re, data)) + checkOpts = append(checkOpts, StatusOK(re), ExtractJSON(re, data)) + return checkResp(resp, checkOpts...) } -// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully +// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully. func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}) error { resp, err := apiutil.GetJSON(client, url, input) if err != nil { @@ -81,8 +96,8 @@ func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string return checkResp(resp, StatusOK(re), ExtractJSON(re, data)) } -// CheckPostJSON is used to do post request and do check options -func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error { +// CheckPostJSON is used to do post request and do check options. +func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.PostJSON(client, url, data) if err != nil { return err @@ -90,8 +105,8 @@ func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...fu return checkResp(resp, checkOpts...) } -// CheckGetJSON is used to do get request and do check options -func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error { +// CheckGetJSON is used to do get request and do check options. +func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.GetJSON(client, url, data) if err != nil { return err @@ -99,8 +114,8 @@ func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...fun return checkResp(resp, checkOpts...) } -// CheckPatchJSON is used to do patch request and do check options -func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error { +// CheckPatchJSON is used to do patch request and do check options. +func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.PatchJSON(client, url, data) if err != nil { return err @@ -108,14 +123,14 @@ func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...f return checkResp(resp, checkOpts...) } -func checkResp(resp *http.Response, checkOpts ...func([]byte, int)) error { +func checkResp(resp *http.Response, checkOpts ...func([]byte, int, http.Header)) error { res, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { return err } for _, opt := range checkOpts { - opt(res, resp.StatusCode) + opt(res, resp.StatusCode, resp.Header) } return nil } diff --git a/server/api/hot_status_test.go b/server/api/hot_status_test.go index a1d1bbc2617..d3d495f86fa 100644 --- a/server/api/hot_status_test.go +++ b/server/api/hot_status_test.go @@ -17,6 +17,7 @@ package api import ( "encoding/json" "fmt" + "net/http" "testing" "time" @@ -92,7 +93,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsTimeRange() { StartTime: now.UnixNano() / int64(time.Millisecond), EndTime: now.Add(10*time.Second).UnixNano() / int64(time.Millisecond), } - check := func(res []byte, statusCode int) { + check := func(res []byte, statusCode int, _ http.Header) { suite.Equal(200, statusCode) historyHotRegions := &storage.HistoryHotRegions{} json.Unmarshal(res, historyHotRegions) @@ -177,7 +178,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsIDAndTypes() { IsLearners: []bool{false}, EndTime: now.Add(10*time.Minute).UnixNano() / int64(time.Millisecond), } - check := func(res []byte, statusCode int) { + check := func(res []byte, statusCode int, _ http.Header) { suite.Equal(200, statusCode) historyHotRegions := &storage.HistoryHotRegions{} json.Unmarshal(res, historyHotRegions) diff --git a/server/api/region_test.go b/server/api/region_test.go index 3f20f5ca29f..08880c7fcad 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -405,7 +405,7 @@ func (suite *regionTestSuite) TestSplitRegions() { hex.EncodeToString([]byte("bbb")), hex.EncodeToString([]byte("ccc")), hex.EncodeToString([]byte("ddd"))) - checkOpt := func(res []byte, code int) { + checkOpt := func(res []byte, code int, _ http.Header) { s := &struct { ProcessedPercentage int `json:"processed-percentage"` NewRegionsID []uint64 `json:"regions-id"` diff --git a/server/api/rule_test.go b/server/api/rule_test.go index d2000eb9562..4cea1523401 100644 --- a/server/api/rule_test.go +++ b/server/api/rule_test.go @@ -829,7 +829,7 @@ func (suite *ruleTestSuite) TestBundleBadRequest() { } for _, testCase := range testCases { err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+testCase.uri, []byte(testCase.data), - func(_ []byte, code int) { + func(_ []byte, code int, _ http.Header) { suite.Equal(testCase.ok, code == http.StatusOK) }) suite.NoError(err) diff --git a/server/api/server.go b/server/api/server.go index 1d881022c04..0094d8eb5dd 100644 --- a/server/api/server.go +++ b/server/api/server.go @@ -19,6 +19,7 @@ import ( "net/http" "github.com/gorilla/mux" + scheapi "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" tsoapi "github.com/tikv/pd/pkg/mcs/tso/server/apis/v1" mcs "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/utils/apiutil" @@ -35,14 +36,37 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP Name: "core", IsCore: true, } - router := mux.NewRouter() + prefix := apiPrefix + "/api/v1" r := createRouter(apiPrefix, svr) + router := mux.NewRouter() router.PathPrefix(apiPrefix).Handler(negroni.New( serverapi.NewRuntimeServiceValidator(svr, group), - serverapi.NewRedirector(svr, serverapi.MicroserviceRedirectRule( - apiPrefix+"/api/v1"+"/admin/reset-ts", - tsoapi.APIPathPrefix+"/admin/reset-ts", - mcs.TSOServiceName)), + serverapi.NewRedirector(svr, + serverapi.MicroserviceRedirectRule( + prefix+"/admin/reset-ts", + tsoapi.APIPathPrefix+"/admin/reset-ts", + mcs.TSOServiceName, + []string{http.MethodPost}), + serverapi.MicroserviceRedirectRule( + prefix+"/operators", + scheapi.APIPathPrefix+"/operators", + mcs.SchedulingServiceName, + []string{http.MethodPost, http.MethodGet, http.MethodDelete}), + // because the writing of all the meta information of the scheduling service is in the API server, + // we only forward read-only requests about checkers and schedulers to the scheduling service. + serverapi.MicroserviceRedirectRule( + prefix+"/checker", // Note: this is a typo in the original code + scheapi.APIPathPrefix+"/checkers", + mcs.SchedulingServiceName, + []string{http.MethodGet}), + serverapi.MicroserviceRedirectRule( + prefix+"/schedulers", + scheapi.APIPathPrefix+"/schedulers", + mcs.SchedulingServiceName, + []string{http.MethodGet}), + // TODO: we need to consider the case that v1 api not support restful api. + // we might change the previous path parameters to query parameters. + ), negroni.Wrap(r)), ) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 48bdf1ab95c..04671d84798 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -2,13 +2,16 @@ package scheduling_test import ( "context" + "encoding/json" "fmt" "net/http" "testing" "time" + "github.com/pingcap/failpoint" "github.com/stretchr/testify/suite" _ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" + "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/tempurl" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/tests" @@ -106,3 +109,58 @@ func (suite *apiTestSuite) TestGetCheckerByName() { suite.False(resp["paused"].(bool)) } } + +func (suite *apiTestSuite) TestAPIForward() { + re := suite.Require() + tc, err := tests.NewTestSchedulingCluster(suite.ctx, 2, suite.backendEndpoints) + re.NoError(err) + defer tc.Destroy() + tc.WaitForPrimaryServing(re) + + failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)") + defer func() { + failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader") + }() + + urlPrefix := fmt.Sprintf("%s/pd/api/v1", suite.backendEndpoints) + var slice []string + var resp map[string]interface{} + + // Test opeartor + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + re.Len(slice, 0) + + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + re.Nil(resp) + + // Test checker: only read-only requests are forwarded + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + suite.False(resp["paused"].(bool)) + + input := make(map[string]interface{}) + input["delay"] = 10 + pauseArgs, err := json.Marshal(input) + suite.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), pauseArgs, + testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.PDRedirectorHeader)) + suite.NoError(err) + + // Test scheduler: only read-only requests are forwarded + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) + re.NoError(err) + re.Contains(slice, "balance-leader-scheduler") + + input["delay"] = 30 + pauseArgs, err = json.Marshal(input) + suite.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/all"), pauseArgs, + testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + suite.NoError(err) +} diff --git a/tests/integrations/mcs/tso/api_test.go b/tests/integrations/mcs/tso/api_test.go index fde6bcb8da0..7e870fbc198 100644 --- a/tests/integrations/mcs/tso/api_test.go +++ b/tests/integrations/mcs/tso/api_test.go @@ -30,6 +30,7 @@ import ( apis "github.com/tikv/pd/pkg/mcs/tso/server/apis/v1" mcsutils "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/storage/endpoint" + "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" ) @@ -47,10 +48,11 @@ var dialClient = &http.Client{ type tsoAPITestSuite struct { suite.Suite - ctx context.Context - cancel context.CancelFunc - pdCluster *tests.TestCluster - tsoCluster *tests.TestTSOCluster + ctx context.Context + cancel context.CancelFunc + pdCluster *tests.TestCluster + tsoCluster *tests.TestTSOCluster + backendEndpoints string } func TestTSOAPI(t *testing.T) { @@ -69,7 +71,8 @@ func (suite *tsoAPITestSuite) SetupTest() { leaderName := suite.pdCluster.WaitLeader() pdLeaderServer := suite.pdCluster.GetServer(leaderName) re.NoError(pdLeaderServer.BootstrapCluster()) - suite.tsoCluster, err = tests.NewTestTSOCluster(suite.ctx, 1, pdLeaderServer.GetAddr()) + suite.backendEndpoints = pdLeaderServer.GetAddr() + suite.tsoCluster, err = tests.NewTestTSOCluster(suite.ctx, 1, suite.backendEndpoints) re.NoError(err) } @@ -95,6 +98,25 @@ func (suite *tsoAPITestSuite) TestGetKeyspaceGroupMembers() { re.Equal(primaryMember.GetLeaderID(), defaultGroupMember.PrimaryID) } +func (suite *tsoAPITestSuite) TestForwardResetTS() { + re := suite.Require() + primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) + re.NotNil(primary) + url := suite.backendEndpoints + "/pd/api/v1/admin/reset-ts" + + // Test reset ts + input := []byte(`{"tso":"121312", "force-use-larger":true}`) + err := testutil.CheckPostJSON(dialClient, url, input, + testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully")) + suite.NoError(err) + + // Test reset ts with invalid tso + input = []byte(`{}`) + err = testutil.CheckPostJSON(dialClient, url, input, + testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value")) + re.NoError(err) +} + func mustGetKeyspaceGroupMembers(re *require.Assertions, server *tso.Server) map[uint32]*apis.KeyspaceGroupMember { httpReq, err := http.NewRequest(http.MethodGet, server.GetAddr()+tsoKeyspaceGroupsPrefix+"/members", nil) re.NoError(err) diff --git a/tests/integrations/mcs/tso/server_test.go b/tests/integrations/mcs/tso/server_test.go index d87f1542179..58006b87eeb 100644 --- a/tests/integrations/mcs/tso/server_test.go +++ b/tests/integrations/mcs/tso/server_test.go @@ -106,23 +106,19 @@ func (suite *tsoServerTestSuite) TestTSOServerStartAndStopNormally() { cc, err := grpc.DialContext(suite.ctx, s.GetAddr(), grpc.WithInsecure()) re.NoError(err) cc.Close() - url := s.GetAddr() + tsoapi.APIPathPrefix - { - resetJSON := `{"tso":"121312", "force-use-larger":true}` - re.NoError(err) - resp, err := http.Post(url+"/admin/reset-ts", "application/json", strings.NewReader(resetJSON)) - re.NoError(err) - defer resp.Body.Close() - re.Equal(http.StatusOK, resp.StatusCode) - } - { - resetJSON := `{}` - re.NoError(err) - resp, err := http.Post(url+"/admin/reset-ts", "application/json", strings.NewReader(resetJSON)) - re.NoError(err) - defer resp.Body.Close() - re.Equal(http.StatusBadRequest, resp.StatusCode) - } + + url := s.GetAddr() + tsoapi.APIPathPrefix + "/admin/reset-ts" + // Test reset ts + input := []byte(`{"tso":"121312", "force-use-larger":true}`) + err = testutil.CheckPostJSON(dialClient, url, input, + testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully")) + suite.NoError(err) + + // Test reset ts with invalid tso + input = []byte(`{}`) + err = testutil.CheckPostJSON(dialClient, suite.backendEndpoints+"/pd/api/v1/admin/reset-ts", input, + testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value")) + re.NoError(err) } func (suite *tsoServerTestSuite) TestParticipantStartWithAdvertiseListenAddr() { diff --git a/tests/pdctl/operator/operator_test.go b/tests/pdctl/operator/operator_test.go index c53c5a42a0f..99974bcd2fb 100644 --- a/tests/pdctl/operator/operator_test.go +++ b/tests/pdctl/operator/operator_test.go @@ -16,6 +16,7 @@ package operator_test import ( "context" + "encoding/json" "strconv" "strings" "testing" @@ -242,3 +243,32 @@ func TestOperator(t *testing.T) { return strings.Contains(string(output1), "Success!") || strings.Contains(string(output2), "Success!") }) } + +func TestForwardOperatorRequest(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestAPICluster(ctx, 1) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) + re.NotEmpty(cluster.WaitLeader()) + server := cluster.GetServer(cluster.GetLeader()) + re.NoError(server.BootstrapCluster()) + backendEndpoints := server.GetAddr() + tc, err := tests.NewTestSchedulingCluster(ctx, 2, backendEndpoints) + re.NoError(err) + defer tc.Destroy() + tc.WaitForPrimaryServing(re) + + cmd := pdctlCmd.GetRootCmd() + args := []string{"-u", backendEndpoints, "operator", "show"} + var slice []string + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.NoError(json.Unmarshal(output, &slice)) + re.Len(slice, 0) + args = []string{"-u", backendEndpoints, "operator", "check", "2"} + output, err = pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.Contains(string(output), "null") +} diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index a8dc7b35f11..ee1a506369f 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -21,7 +21,9 @@ import ( "time" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/spf13/cobra" "github.com/stretchr/testify/require" + "github.com/tikv/pd/pkg/core" sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/pkg/versioninfo" @@ -485,3 +487,116 @@ func TestScheduler(t *testing.T) { re.NoError(err) checkSchedulerWithStatusCommand(nil, "disabled", nil) } + +func TestSchedulerDiagnostic(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) + defer cluster.Destroy() + err = cluster.RunInitialServers() + re.NoError(err) + cluster.WaitLeader() + pdAddr := cluster.GetConfig().GetClientURL() + cmd := pdctlCmd.GetRootCmd() + + checkSchedulerDescribeCommand := func(schedulerName, expectedStatus, expectedSummary string) { + result := make(map[string]interface{}) + testutil.Eventually(re, func() bool { + mightExec(re, cmd, []string{"-u", pdAddr, "scheduler", "describe", schedulerName}, &result) + return len(result) != 0 + }, testutil.WithTickInterval(50*time.Millisecond)) + re.Equal(expectedStatus, result["status"]) + re.Equal(expectedSummary, result["summary"]) + } + + stores := []*metapb.Store{ + { + Id: 1, + State: metapb.StoreState_Up, + LastHeartbeat: time.Now().UnixNano(), + }, + { + Id: 2, + State: metapb.StoreState_Up, + LastHeartbeat: time.Now().UnixNano(), + }, + { + Id: 3, + State: metapb.StoreState_Up, + LastHeartbeat: time.Now().UnixNano(), + }, + { + Id: 4, + State: metapb.StoreState_Up, + LastHeartbeat: time.Now().UnixNano(), + }, + } + leaderServer := cluster.GetServer(cluster.GetLeader()) + re.NoError(leaderServer.BootstrapCluster()) + for _, store := range stores { + pdctl.MustPutStore(re, leaderServer.GetServer(), store) + } + + // note: because pdqsort is a unstable sort algorithm, set ApproximateSize for this region. + pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetApproximateSize(10)) + time.Sleep(3 * time.Second) + + echo := mustExec(re, cmd, []string{"-u", pdAddr, "config", "set", "enable-diagnostic", "true"}, nil) + re.Contains(echo, "Success!") + checkSchedulerDescribeCommand("balance-region-scheduler", "pending", "1 store(s) RegionNotMatchRule; ") + + // scheduler delete command + mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "remove", "balance-region-scheduler"}, nil) + + checkSchedulerDescribeCommand("balance-region-scheduler", "disabled", "") + + mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "pause", "balance-leader-scheduler", "60"}, nil) + mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "resume", "balance-leader-scheduler"}, nil) + checkSchedulerDescribeCommand("balance-leader-scheduler", "normal", "") +} + +func mustExec(re *require.Assertions, cmd *cobra.Command, args []string, v interface{}) string { + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + if v == nil { + return string(output) + } + re.NoError(json.Unmarshal(output, v)) + return "" +} + +func mightExec(re *require.Assertions, cmd *cobra.Command, args []string, v interface{}) { + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + if v == nil { + return + } + json.Unmarshal(output, v) +} + +func TestForwardSchedulerRequest(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestAPICluster(ctx, 1) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) + re.NotEmpty(cluster.WaitLeader()) + server := cluster.GetServer(cluster.GetLeader()) + re.NoError(server.BootstrapCluster()) + backendEndpoints := server.GetAddr() + tc, err := tests.NewTestSchedulingCluster(ctx, 2, backendEndpoints) + re.NoError(err) + defer tc.Destroy() + tc.WaitForPrimaryServing(re) + + cmd := pdctlCmd.GetRootCmd() + args := []string{"-u", backendEndpoints, "scheduler", "show"} + var slice []string + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.NoError(json.Unmarshal(output, &slice)) + re.Contains(slice, "balance-leader-scheduler") +}