From 4e6f3b3197252e07439ffa3ae5653b7a65cb3729 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Mon, 11 Sep 2023 16:05:31 +0800 Subject: [PATCH 01/11] mcs: watch scheduling service primary address Signed-off-by: lhy1024 --- pkg/member/participant.go | 33 ++++++++---------- pkg/storage/endpoint/key_path.go | 6 ++++ server/api/server.go | 5 ++- server/server.go | 34 +++++++++++++------ tests/cluster.go | 5 +++ .../mcs/scheduling/server_test.go | 15 ++++++-- 6 files changed, 67 insertions(+), 31 deletions(-) diff --git a/pkg/member/participant.go b/pkg/member/participant.go index 27dced57791..0ce55b136ef 100644 --- a/pkg/member/participant.go +++ b/pkg/member/participant.go @@ -166,15 +166,7 @@ func (m *Participant) setLeader(member participant) { // unsetLeader unsets the member's leader. func (m *Participant) unsetLeader() { - var leader participant - switch m.serviceName { - case utils.TSOServiceName: - leader = &tsopb.Participant{} - case utils.SchedulingServiceName: - leader = &schedulingpb.Participant{} - case utils.ResourceManagerServiceName: - leader = &resource_manager.Participant{} - } + leader := NewParticipantByService(m.serviceName) m.leader.Store(leader) m.lastLeaderUpdatedTime.Store(time.Now()) } @@ -225,15 +217,7 @@ func (m *Participant) PreCheckLeader() error { // getPersistentLeader gets the corresponding leader from etcd by given leaderPath (as the key). func (m *Participant) getPersistentLeader() (participant, int64, error) { - var leader participant - switch m.serviceName { - case utils.TSOServiceName: - leader = &tsopb.Participant{} - case utils.SchedulingServiceName: - leader = &schedulingpb.Participant{} - case utils.ResourceManagerServiceName: - leader = &resource_manager.Participant{} - } + leader := NewParticipantByService(m.serviceName) ok, rev, err := etcdutil.GetProtoMsgWithModRev(m.client, m.GetLeaderPath(), leader) if err != nil { return nil, 0, err @@ -399,3 +383,16 @@ func (m *Participant) campaignCheck() bool { func (m *Participant) SetCampaignChecker(checker leadershipCheckFunc) { m.campaignChecker.Store(checker) } + +// NewParticipantByService creates a new participant by service name. +func NewParticipantByService(serviceName string) (leader participant) { + switch serviceName { + case utils.TSOServiceName: + leader = &tsopb.Participant{} + case utils.SchedulingServiceName: + leader = &schedulingpb.Participant{} + case utils.ResourceManagerServiceName: + leader = &resource_manager.Participant{} + } + return leader +} diff --git a/pkg/storage/endpoint/key_path.go b/pkg/storage/endpoint/key_path.go index 85af79203a4..4b67441a5ac 100644 --- a/pkg/storage/endpoint/key_path.go +++ b/pkg/storage/endpoint/key_path.go @@ -325,6 +325,12 @@ func KeyspaceGroupPrimaryPath(rootPath string, keyspaceGroupID uint32) string { return path.Join(electionPath, utils.PrimaryKey) } +// SchedulingPrimaryPath returns the path of scheduling primary. +// Path: /ms/{cluster_id}/scheduling/primary +func SchedulingPrimaryPath(clusterID uint64) string { + return path.Join(SchedulingSvcRootPath(clusterID), utils.PrimaryKey) +} + // KeyspaceGroupsElectionPath returns the path of keyspace groups election. // default keyspace group: "/ms/{cluster_id}/tso/00000". // non-default keyspace group: "/ms/{cluster_id}/tso/keyspace_groups/election/{group}". diff --git a/server/api/server.go b/server/api/server.go index 272e76cc60b..1d881022c04 100644 --- a/server/api/server.go +++ b/server/api/server.go @@ -20,6 +20,7 @@ import ( "github.com/gorilla/mux" tsoapi "github.com/tikv/pd/pkg/mcs/tso/server/apis/v1" + mcs "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/apiutil/serverapi" "github.com/tikv/pd/server" @@ -39,7 +40,9 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP router.PathPrefix(apiPrefix).Handler(negroni.New( serverapi.NewRuntimeServiceValidator(svr, group), serverapi.NewRedirector(svr, serverapi.MicroserviceRedirectRule( - apiPrefix+"/api/v1"+"/admin/reset-ts", tsoapi.APIPathPrefix+"/admin/reset-ts", "tso")), + apiPrefix+"/api/v1"+"/admin/reset-ts", + tsoapi.APIPathPrefix+"/admin/reset-ts", + mcs.TSOServiceName)), negroni.Wrap(r)), ) diff --git a/server/server.go b/server/server.go index b74ca5b57b3..846b96ec29e 100644 --- a/server/server.go +++ b/server/server.go @@ -226,10 +226,11 @@ type Server struct { auditBackends []audit.Backend - registry *registry.ServiceRegistry - mode string - servicePrimaryMap sync.Map /* Store as map[string]string */ - tsoPrimaryWatcher *etcdutil.LoopWatcher + registry *registry.ServiceRegistry + mode string + servicePrimaryMap sync.Map /* Store as map[string]string */ + tsoPrimaryWatcher *etcdutil.LoopWatcher + schedulingPrimaryWatcher *etcdutil.LoopWatcher } // HandlerBuilder builds a server HTTP handler. @@ -618,6 +619,8 @@ func (s *Server) startServerLoop(ctx context.Context) { if s.IsAPIServiceMode() { s.initTSOPrimaryWatcher() s.tsoPrimaryWatcher.StartWatchLoop() + s.initSchedulingPrimaryWatcher() + s.schedulingPrimaryWatcher.StartWatchLoop() } } @@ -1962,8 +1965,18 @@ func (s *Server) initTSOPrimaryWatcher() { serviceName := mcs.TSOServiceName tsoRootPath := endpoint.TSOSvcRootPath(s.clusterID) tsoServicePrimaryKey := endpoint.KeyspaceGroupPrimaryPath(tsoRootPath, mcs.DefaultKeyspaceGroupID) + s.tsoPrimaryWatcher = s.initServicePrimaryWatcher(serviceName, tsoServicePrimaryKey) +} + +func (s *Server) initSchedulingPrimaryWatcher() { + serviceName := mcs.SchedulingServiceName + primaryKey := endpoint.SchedulingPrimaryPath(s.clusterID) + s.schedulingPrimaryWatcher = s.initServicePrimaryWatcher(serviceName, primaryKey) +} + +func (s *Server) initServicePrimaryWatcher(serviceName string, primaryKey string) *etcdutil.LoopWatcher { putFn := func(kv *mvccpb.KeyValue) error { - primary := &tsopb.Participant{} // TODO: use Generics + primary := member.NewParticipantByService(serviceName) if err := proto.Unmarshal(kv.Value, primary); err != nil { return err } @@ -1971,7 +1984,7 @@ func (s *Server) initTSOPrimaryWatcher() { if len(listenUrls) > 0 { // listenUrls[0] is the primary service endpoint of the keyspace group s.servicePrimaryMap.Store(serviceName, listenUrls[0]) - log.Info("update tso primary", zap.String("primary", listenUrls[0])) + log.Info("update service primary", zap.String("service-name", serviceName), zap.String("primary", listenUrls[0])) } return nil } @@ -1981,16 +1994,17 @@ func (s *Server) initTSOPrimaryWatcher() { if ok { oldPrimary = v.(string) } - log.Info("delete tso primary", zap.String("old-primary", oldPrimary)) + log.Info("delete service primary", zap.String("service-name", serviceName), zap.String("old-primary", oldPrimary)) s.servicePrimaryMap.Delete(serviceName) return nil } - s.tsoPrimaryWatcher = etcdutil.NewLoopWatcher( + name := fmt.Sprintf("%s-primary-watcher", serviceName) + return etcdutil.NewLoopWatcher( s.serverLoopCtx, &s.serverLoopWg, s.client, - "tso-primary-watcher", - tsoServicePrimaryKey, + name, + primaryKey, putFn, deleteFn, func() error { return nil }, diff --git a/tests/cluster.go b/tests/cluster.go index 607955cc6a9..ce8293531cd 100644 --- a/tests/cluster.go +++ b/tests/cluster.go @@ -433,6 +433,11 @@ func (s *TestServer) GetTSOAllocatorManager() *tso.AllocatorManager { return s.server.GetTSOAllocatorManager() } +// GetServicePrimaryAddr returns the primary address of the service. +func (s *TestServer) GetServicePrimaryAddr(ctx context.Context, serviceName string) (string, bool) { + return s.server.GetServicePrimaryAddr(ctx, serviceName) +} + // TestCluster is only for test. type TestCluster struct { config *clusterConfig diff --git a/tests/integrations/mcs/scheduling/server_test.go b/tests/integrations/mcs/scheduling/server_test.go index 9b5371deb62..54994bbc34b 100644 --- a/tests/integrations/mcs/scheduling/server_test.go +++ b/tests/integrations/mcs/scheduling/server_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/failpoint" "github.com/stretchr/testify/suite" + mcs "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/tests" "go.uber.org/goleak" @@ -107,11 +108,21 @@ func (suite *serverTestSuite) TestPrimaryChange() { defer tc.Destroy() tc.WaitForPrimaryServing(re) primary := tc.GetPrimaryServer() - addr := primary.GetAddr() + oldPrimaryAddr := primary.GetAddr() re.Len(primary.GetCluster().GetCoordinator().GetSchedulersController().GetSchedulerNames(), 5) + testutil.Eventually(re, func() bool { + watchedAddr, ok := suite.pdLeader.GetServicePrimaryAddr(suite.ctx, mcs.SchedulingServiceName) + return ok && oldPrimaryAddr == watchedAddr + }) + // transfer leader primary.Close() tc.WaitForPrimaryServing(re) primary = tc.GetPrimaryServer() - re.NotEqual(addr, primary.GetAddr()) + newPrimaryAddr := primary.GetAddr() + re.NotEqual(oldPrimaryAddr, newPrimaryAddr) re.Len(primary.GetCluster().GetCoordinator().GetSchedulersController().GetSchedulerNames(), 5) + testutil.Eventually(re, func() bool { + watchedAddr, ok := suite.pdLeader.GetServicePrimaryAddr(suite.ctx, mcs.SchedulingServiceName) + return ok && newPrimaryAddr == watchedAddr + }) } From 780736b4dd5c14e6c49759fcbb7c9500880eb032 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Mon, 11 Sep 2023 18:15:27 +0800 Subject: [PATCH 02/11] address comments Signed-off-by: lhy1024 --- pkg/member/participant.go | 10 +++++----- server/server.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/member/participant.go b/pkg/member/participant.go index 0ce55b136ef..b3034a86807 100644 --- a/pkg/member/participant.go +++ b/pkg/member/participant.go @@ -385,14 +385,14 @@ func (m *Participant) SetCampaignChecker(checker leadershipCheckFunc) { } // NewParticipantByService creates a new participant by service name. -func NewParticipantByService(serviceName string) (leader participant) { +func NewParticipantByService(serviceName string) (p participant) { switch serviceName { case utils.TSOServiceName: - leader = &tsopb.Participant{} + p = &tsopb.Participant{} case utils.SchedulingServiceName: - leader = &schedulingpb.Participant{} + p = &schedulingpb.Participant{} case utils.ResourceManagerServiceName: - leader = &resource_manager.Participant{} + p = &resource_manager.Participant{} } - return leader + return p } diff --git a/server/server.go b/server/server.go index 846b96ec29e..2a076923caf 100644 --- a/server/server.go +++ b/server/server.go @@ -618,9 +618,7 @@ func (s *Server) startServerLoop(ctx context.Context) { go s.encryptionKeyManagerLoop() if s.IsAPIServiceMode() { s.initTSOPrimaryWatcher() - s.tsoPrimaryWatcher.StartWatchLoop() s.initSchedulingPrimaryWatcher() - s.schedulingPrimaryWatcher.StartWatchLoop() } } @@ -1966,12 +1964,14 @@ func (s *Server) initTSOPrimaryWatcher() { tsoRootPath := endpoint.TSOSvcRootPath(s.clusterID) tsoServicePrimaryKey := endpoint.KeyspaceGroupPrimaryPath(tsoRootPath, mcs.DefaultKeyspaceGroupID) s.tsoPrimaryWatcher = s.initServicePrimaryWatcher(serviceName, tsoServicePrimaryKey) + s.tsoPrimaryWatcher.StartWatchLoop() } func (s *Server) initSchedulingPrimaryWatcher() { serviceName := mcs.SchedulingServiceName primaryKey := endpoint.SchedulingPrimaryPath(s.clusterID) s.schedulingPrimaryWatcher = s.initServicePrimaryWatcher(serviceName, primaryKey) + s.schedulingPrimaryWatcher.StartWatchLoop() } func (s *Server) initServicePrimaryWatcher(serviceName string, primaryKey string) *etcdutil.LoopWatcher { From 4012981826ffe5541d3c169c1a245def47eb826b Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 12 Sep 2023 01:28:46 +0800 Subject: [PATCH 03/11] mcs: forward current http request to mcs Signed-off-by: lhy1024 --- pkg/mcs/resourcemanager/server/apis/v1/api.go | 2 +- pkg/mcs/scheduling/server/apis/v1/api.go | 2 +- pkg/utils/apiutil/apiutil.go | 2 ++ pkg/utils/apiutil/serverapi/middleware.go | 17 +++++++--- server/api/server.go | 26 ++++++++++++--- tests/integrations/mcs/scheduling/api_test.go | 31 ++++++++++++++++++ tests/integrations/mcs/tso/api_test.go | 32 ++++++++++++++++--- tests/integrations/mcs/tso/server_test.go | 30 ++++++++--------- tests/pdctl/operator/operator_test.go | 30 +++++++++++++++++ tests/pdctl/scheduler/scheduler_test.go | 25 +++++++++++++++ tools/pd-ctl/pdctl/command/global.go | 2 ++ 11 files changed, 165 insertions(+), 34 deletions(-) diff --git a/pkg/mcs/resourcemanager/server/apis/v1/api.go b/pkg/mcs/resourcemanager/server/apis/v1/api.go index 411933e55c3..7c5e3e010dc 100644 --- a/pkg/mcs/resourcemanager/server/apis/v1/api.go +++ b/pkg/mcs/resourcemanager/server/apis/v1/api.go @@ -73,7 +73,7 @@ func NewService(srv *rmserver.Service) *Service { manager := srv.GetManager() apiHandlerEngine.Use(func(c *gin.Context) { // manager implements the interface of basicserver.Service. - c.Set("service", manager.GetBasicServer()) + c.Set(multiservicesapi.ServiceContextKey, manager.GetBasicServer()) c.Next() }) apiHandlerEngine.Use(multiservicesapi.ServiceRedirector()) diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index e8c4faa5d55..3d1c3921470 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -34,7 +34,7 @@ import ( ) // APIPathPrefix is the prefix of the API path. -const APIPathPrefix = "/scheduling/api/v1/" +const APIPathPrefix = "/scheduling/api/v1" var ( once sync.Once diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 6c32640218e..c5d7c247aca 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -196,12 +196,14 @@ func PostJSON(client *http.Client, url string, data []byte) (*http.Response, err return nil, err } req.Header.Set("Content-Type", "application/json") + req.Header.Add("Accept-Encoding", "identity") return client.Do(req) } // GetJSON is used to send GET request to specific url func GetJSON(client *http.Client, url string, data []byte) (*http.Response, error) { req, err := http.NewRequest(http.MethodGet, url, bytes.NewBuffer(data)) + req.Header.Add("Accept-Encoding", "identity") if err != nil { return nil, err } diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 7d403ecef13..d7850029e2e 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -17,6 +17,7 @@ package serverapi import ( "net/http" "net/url" + "strings" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" @@ -108,13 +109,19 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri return false, "" } for _, rule := range h.microserviceRedirectRules { - if rule.matchPath == r.URL.Path { + if strings.HasPrefix(r.URL.Path, rule.matchPath) { addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName) if !ok || addr == "" { log.Warn("failed to get the service primary addr when try match redirect rules", zap.String("path", r.URL.Path)) } - r.URL.Path = rule.targetPath + // Extract parameters from the URL path + pathParams := strings.TrimPrefix(r.URL.Path, rule.matchPath) + if len(pathParams) > 0 && pathParams[0] == '/' { + pathParams = pathParams[1:] // Remove leading '/' + } + r.URL.Path = rule.targetPath + "/" + pathParams + r.URL.Path = strings.TrimRight(r.URL.Path, "/") return true, addr } } @@ -122,10 +129,10 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri } func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - matchedFlag, targetAddr := h.matchMicroServiceRedirectRules(r) + needRedirectToMicroService, targetAddr := h.matchMicroServiceRedirectRules(r) allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0 isLeader := h.s.GetMember().IsLeader() - if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !matchedFlag { + if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !needRedirectToMicroService { next(w, r) return } @@ -150,7 +157,7 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http } var clientUrls []string - if matchedFlag { + if needRedirectToMicroService { if len(targetAddr) == 0 { http.Error(w, apiutil.ErrRedirectFailed, http.StatusInternalServerError) return diff --git a/server/api/server.go b/server/api/server.go index 1d881022c04..5df5d22f2d0 100644 --- a/server/api/server.go +++ b/server/api/server.go @@ -19,6 +19,7 @@ import ( "net/http" "github.com/gorilla/mux" + scheapi "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" tsoapi "github.com/tikv/pd/pkg/mcs/tso/server/apis/v1" mcs "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/utils/apiutil" @@ -35,14 +36,29 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP Name: "core", IsCore: true, } - router := mux.NewRouter() + prefix := apiPrefix + "/api/v1" r := createRouter(apiPrefix, svr) + router := mux.NewRouter() router.PathPrefix(apiPrefix).Handler(negroni.New( serverapi.NewRuntimeServiceValidator(svr, group), - serverapi.NewRedirector(svr, serverapi.MicroserviceRedirectRule( - apiPrefix+"/api/v1"+"/admin/reset-ts", - tsoapi.APIPathPrefix+"/admin/reset-ts", - mcs.TSOServiceName)), + serverapi.NewRedirector(svr, + serverapi.MicroserviceRedirectRule( + prefix+"/admin/reset-ts", + tsoapi.APIPathPrefix+"/admin/reset-ts", + mcs.TSOServiceName), + serverapi.MicroserviceRedirectRule( + prefix+"/operators", + scheapi.APIPathPrefix+"/operators", + mcs.SchedulingServiceName), + serverapi.MicroserviceRedirectRule( + prefix+"/checker", // Note: this is a typo in the original code + scheapi.APIPathPrefix+"/checkers", + mcs.SchedulingServiceName), + serverapi.MicroserviceRedirectRule( + prefix+"/schedulers", + scheapi.APIPathPrefix+"/schedulers", + mcs.SchedulingServiceName), + ), negroni.Wrap(r)), ) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 48bdf1ab95c..03db3433ec9 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -106,3 +106,34 @@ func (suite *apiTestSuite) TestGetCheckerByName() { suite.False(resp["paused"].(bool)) } } + +func (suite *apiTestSuite) TestAPIForward() { + re := suite.Require() + tc, err := tests.NewTestSchedulingCluster(suite.ctx, 2, suite.backendEndpoints) + re.NoError(err) + defer tc.Destroy() + tc.WaitForPrimaryServing(re) + + urlPrefix := fmt.Sprintf("%s/pd/api/v1", suite.backendEndpoints) + var slice []string + var resp map[string]interface{} + + // Test opeartor + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice) + re.NoError(err) + re.Len(slice, 0) + + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp) + re.NoError(err) + re.Nil(resp) + + // Test checker + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp) + re.NoError(err) + suite.False(resp["paused"].(bool)) + + // Test scheduler + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice) + re.NoError(err) + re.Contains(slice, "balance-leader-scheduler") +} diff --git a/tests/integrations/mcs/tso/api_test.go b/tests/integrations/mcs/tso/api_test.go index fde6bcb8da0..9040ab41b36 100644 --- a/tests/integrations/mcs/tso/api_test.go +++ b/tests/integrations/mcs/tso/api_test.go @@ -30,6 +30,7 @@ import ( apis "github.com/tikv/pd/pkg/mcs/tso/server/apis/v1" mcsutils "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/storage/endpoint" + "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" ) @@ -47,10 +48,11 @@ var dialClient = &http.Client{ type tsoAPITestSuite struct { suite.Suite - ctx context.Context - cancel context.CancelFunc - pdCluster *tests.TestCluster - tsoCluster *tests.TestTSOCluster + ctx context.Context + cancel context.CancelFunc + pdCluster *tests.TestCluster + tsoCluster *tests.TestTSOCluster + backendEndpoints string } func TestTSOAPI(t *testing.T) { @@ -69,7 +71,8 @@ func (suite *tsoAPITestSuite) SetupTest() { leaderName := suite.pdCluster.WaitLeader() pdLeaderServer := suite.pdCluster.GetServer(leaderName) re.NoError(pdLeaderServer.BootstrapCluster()) - suite.tsoCluster, err = tests.NewTestTSOCluster(suite.ctx, 1, pdLeaderServer.GetAddr()) + suite.backendEndpoints = pdLeaderServer.GetAddr() + suite.tsoCluster, err = tests.NewTestTSOCluster(suite.ctx, 1, suite.backendEndpoints) re.NoError(err) } @@ -95,6 +98,25 @@ func (suite *tsoAPITestSuite) TestGetKeyspaceGroupMembers() { re.Equal(primaryMember.GetLeaderID(), defaultGroupMember.PrimaryID) } +func (suite *tsoAPITestSuite) TestResetTS() { + re := suite.Require() + primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) + re.NotNil(primary) + url := suite.backendEndpoints + "/pd/api/v1/admin/reset-ts" + + // Test reset ts + input := []byte(`{"tso":"121312", "force-use-larger":true}`) + err := testutil.CheckPostJSON(dialClient, url, input, + testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully")) + suite.NoError(err) + + // Test reset ts with invalid tso + input = []byte(`{}`) + err = testutil.CheckPostJSON(dialClient, url, input, + testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value")) + re.NoError(err) +} + func mustGetKeyspaceGroupMembers(re *require.Assertions, server *tso.Server) map[uint32]*apis.KeyspaceGroupMember { httpReq, err := http.NewRequest(http.MethodGet, server.GetAddr()+tsoKeyspaceGroupsPrefix+"/members", nil) re.NoError(err) diff --git a/tests/integrations/mcs/tso/server_test.go b/tests/integrations/mcs/tso/server_test.go index d87f1542179..58006b87eeb 100644 --- a/tests/integrations/mcs/tso/server_test.go +++ b/tests/integrations/mcs/tso/server_test.go @@ -106,23 +106,19 @@ func (suite *tsoServerTestSuite) TestTSOServerStartAndStopNormally() { cc, err := grpc.DialContext(suite.ctx, s.GetAddr(), grpc.WithInsecure()) re.NoError(err) cc.Close() - url := s.GetAddr() + tsoapi.APIPathPrefix - { - resetJSON := `{"tso":"121312", "force-use-larger":true}` - re.NoError(err) - resp, err := http.Post(url+"/admin/reset-ts", "application/json", strings.NewReader(resetJSON)) - re.NoError(err) - defer resp.Body.Close() - re.Equal(http.StatusOK, resp.StatusCode) - } - { - resetJSON := `{}` - re.NoError(err) - resp, err := http.Post(url+"/admin/reset-ts", "application/json", strings.NewReader(resetJSON)) - re.NoError(err) - defer resp.Body.Close() - re.Equal(http.StatusBadRequest, resp.StatusCode) - } + + url := s.GetAddr() + tsoapi.APIPathPrefix + "/admin/reset-ts" + // Test reset ts + input := []byte(`{"tso":"121312", "force-use-larger":true}`) + err = testutil.CheckPostJSON(dialClient, url, input, + testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully")) + suite.NoError(err) + + // Test reset ts with invalid tso + input = []byte(`{}`) + err = testutil.CheckPostJSON(dialClient, suite.backendEndpoints+"/pd/api/v1/admin/reset-ts", input, + testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value")) + re.NoError(err) } func (suite *tsoServerTestSuite) TestParticipantStartWithAdvertiseListenAddr() { diff --git a/tests/pdctl/operator/operator_test.go b/tests/pdctl/operator/operator_test.go index ab5687cdc04..63d62c8a874 100644 --- a/tests/pdctl/operator/operator_test.go +++ b/tests/pdctl/operator/operator_test.go @@ -17,6 +17,7 @@ package operator_test import ( "context" "encoding/hex" + "encoding/json" "strconv" "strings" "testing" @@ -251,3 +252,32 @@ func TestOperator(t *testing.T) { return strings.Contains(string(output1), "Success!") || strings.Contains(string(output2), "Success!") }) } + +func TestMicroservice(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestAPICluster(ctx, 1) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) + re.NotEmpty(cluster.WaitLeader()) + server := cluster.GetServer(cluster.GetLeader()) + re.NoError(server.BootstrapCluster()) + backendEndpoints := server.GetAddr() + tc, err := tests.NewTestSchedulingCluster(ctx, 2, backendEndpoints) + re.NoError(err) + defer tc.Destroy() + tc.WaitForPrimaryServing(re) + + cmd := pdctlCmd.GetRootCmd() + args := []string{"-u", backendEndpoints, "operator", "show"} + var slice []string + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.NoError(json.Unmarshal(output, &slice)) + re.Len(slice, 0) + args = []string{"-u", backendEndpoints, "operator", "check", "2"} + output, err = pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.Contains(string(output), "null") +} diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index a0447642cb6..b13790991c4 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -530,3 +530,28 @@ func mightExec(re *require.Assertions, cmd *cobra.Command, args []string, v inte } json.Unmarshal(output, v) } + +func TestMicroservice(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestAPICluster(ctx, 1) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) + re.NotEmpty(cluster.WaitLeader()) + server := cluster.GetServer(cluster.GetLeader()) + re.NoError(server.BootstrapCluster()) + backendEndpoints := server.GetAddr() + tc, err := tests.NewTestSchedulingCluster(ctx, 2, backendEndpoints) + re.NoError(err) + defer tc.Destroy() + tc.WaitForPrimaryServing(re) + + cmd := pdctlCmd.GetRootCmd() + args := []string{"-u", backendEndpoints, "scheduler", "show"} + var slice []string + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + re.NoError(json.Unmarshal(output, &slice)) + re.Contains(slice, "balance-leader-scheduler") +} diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index 5d8552da51a..47e8a18ecef 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -185,6 +185,7 @@ func requestJSON(cmd *cobra.Command, method, prefix string, input map[string]int return err } req.Header.Set("Content-Type", "application/json") + req.Header.Add("Accept-Encoding", "identity") resp, err = dialClient.Do(req) default: err := errors.Errorf("method %s not supported", method) @@ -228,6 +229,7 @@ func do(endpoint, prefix, method string, resp *string, customHeader http.Header, var req *http.Request req, err = http.NewRequest(method, url, b.body) + req.Header.Add("Accept-Encoding", "identity") if err != nil { return err } From e8a88b9649e44b8bf416e1b904c0af41d7fc6dfa Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 12 Sep 2023 03:22:50 +0800 Subject: [PATCH 04/11] avoid unexpected EOF error with gzip Signed-off-by: lhy1024 --- pkg/utils/apiutil/apiutil.go | 14 ++++++++++---- tests/integrations/mcs/tso/api_test.go | 2 +- tools/pd-ctl/pdctl/command/global.go | 2 -- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index c5d7c247aca..6af07f11885 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -196,14 +196,12 @@ func PostJSON(client *http.Client, url string, data []byte) (*http.Response, err return nil, err } req.Header.Set("Content-Type", "application/json") - req.Header.Add("Accept-Encoding", "identity") return client.Do(req) } // GetJSON is used to send GET request to specific url func GetJSON(client *http.Client, url string, data []byte) (*http.Response, error) { req, err := http.NewRequest(http.MethodGet, url, bytes.NewBuffer(data)) - req.Header.Add("Accept-Encoding", "identity") if err != nil { return nil, err } @@ -437,13 +435,17 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) reader = resp.Body } - copyHeader(w.Header(), resp.Header) + // If WriteHeader has not yet been called, Write calls WriteHeader(http.StatusOK) before writing the data. + // So we need to call WriteHeader first. w.WriteHeader(resp.StatusCode) + + var contentLength, written int64 for { - if _, err = io.CopyN(w, reader, chunkSize); err != nil { + if written, err = io.CopyN(w, reader, chunkSize); err != nil { if err == io.EOF { err = nil } + contentLength += written break } } @@ -452,6 +454,10 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) // try next url. continue } + + // We need to set the Content-Length header manually to avoid meeting unexpected EOF error. + w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10)) + copyHeader(w.Header(), resp.Header) return } http.Error(w, ErrRedirectFailed, http.StatusInternalServerError) diff --git a/tests/integrations/mcs/tso/api_test.go b/tests/integrations/mcs/tso/api_test.go index 9040ab41b36..7e870fbc198 100644 --- a/tests/integrations/mcs/tso/api_test.go +++ b/tests/integrations/mcs/tso/api_test.go @@ -98,7 +98,7 @@ func (suite *tsoAPITestSuite) TestGetKeyspaceGroupMembers() { re.Equal(primaryMember.GetLeaderID(), defaultGroupMember.PrimaryID) } -func (suite *tsoAPITestSuite) TestResetTS() { +func (suite *tsoAPITestSuite) TestForwardResetTS() { re := suite.Require() primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) re.NotNil(primary) diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index 47e8a18ecef..5d8552da51a 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -185,7 +185,6 @@ func requestJSON(cmd *cobra.Command, method, prefix string, input map[string]int return err } req.Header.Set("Content-Type", "application/json") - req.Header.Add("Accept-Encoding", "identity") resp, err = dialClient.Do(req) default: err := errors.Errorf("method %s not supported", method) @@ -229,7 +228,6 @@ func do(endpoint, prefix, method string, resp *string, customHeader http.Header, var req *http.Request req, err = http.NewRequest(method, url, b.body) - req.Header.Add("Accept-Encoding", "identity") if err != nil { return err } From 0f733dc439f96cc6e970ac904007707ce9febc09 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 12 Sep 2023 04:43:58 +0800 Subject: [PATCH 05/11] fix test Signed-off-by: lhy1024 --- pkg/utils/apiutil/apiutil.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 6af07f11885..8702b750df5 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -435,17 +435,18 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) reader = resp.Body } - // If WriteHeader has not yet been called, Write calls WriteHeader(http.StatusOK) before writing the data. - // So we need to call WriteHeader first. + // We need to copy the response headers before we write the header. + // Otherwise, we cannot set the header. + // And we need to write the header before we copy the response body. + // Otherwise, we cannot set the status code. + copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) - var contentLength, written int64 for { - if written, err = io.CopyN(w, reader, chunkSize); err != nil { + if _, err = io.CopyN(w, reader, chunkSize); err != nil { if err == io.EOF { err = nil } - contentLength += written break } } @@ -454,10 +455,6 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) // try next url. continue } - - // We need to set the Content-Length header manually to avoid meeting unexpected EOF error. - w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10)) - copyHeader(w.Header(), resp.Header) return } http.Error(w, ErrRedirectFailed, http.StatusInternalServerError) @@ -465,6 +462,11 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) func copyHeader(dst, src http.Header) { for k, vv := range src { + // skip Content-Encoding and Content-Length header + // because they need to be set by http.ResponseWriter when gzip is enabled + if k == "Content-Encoding" || k == "Content-Length" { + continue + } values := dst[k] for _, v := range vv { if !slice.Contains(values, v) { From 32a81e3a1f683dc3aff9d7c7062d2363598edd03 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Tue, 12 Sep 2023 19:09:42 +0800 Subject: [PATCH 06/11] address comments Signed-off-by: lhy1024 --- pkg/utils/apiutil/apiutil.go | 13 +++++++++---- pkg/utils/apiutil/serverapi/middleware.go | 21 +++++++++++++-------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 8702b750df5..bcab7c8e9e7 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -436,9 +436,13 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) } // We need to copy the response headers before we write the header. - // Otherwise, we cannot set the header. + // Otherwise, we cannot set the header after w.WriteHeader() is called. // And we need to write the header before we copy the response body. - // Otherwise, we cannot set the status code. + // Otherwise, we cannot set the status code after w.Write() is called. + // In other words, we must perform the following steps strictly in order: + // 1. Set the response headers. + // 2. Write the response header. + // 3. Write the response body. copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) @@ -460,10 +464,11 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request) http.Error(w, ErrRedirectFailed, http.StatusInternalServerError) } +// copyHeader duplicates the HTTP headers from the source `src` to the destination `dst`. +// It skips the "Content-Encoding" and "Content-Length" headers because they should be set by `http.ResponseWriter`. +// These headers may be modified after a redirect when gzip compression is enabled. func copyHeader(dst, src http.Header) { for k, vv := range src { - // skip Content-Encoding and Content-Length header - // because they need to be set by http.ResponseWriter when gzip is enabled if k == "Content-Encoding" || k == "Content-Length" { continue } diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index d7850029e2e..5e9cbd8b245 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -112,16 +112,21 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri if strings.HasPrefix(r.URL.Path, rule.matchPath) { addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName) if !ok || addr == "" { - log.Warn("failed to get the service primary addr when try match redirect rules", + log.Warn("failed to get the service primary addr when trying to match redirect rules", zap.String("path", r.URL.Path)) } // Extract parameters from the URL path + // e.g. r.URL.Path = /pd/api/v1/operators/1 (before redirect) + // matchPath = /pd/api/v1/operators + // targetPath = /scheduling/api/v1/operators + // r.URL.Path = /scheduling/api/v1/operator/1 (after redirect) pathParams := strings.TrimPrefix(r.URL.Path, rule.matchPath) - if len(pathParams) > 0 && pathParams[0] == '/' { - pathParams = pathParams[1:] // Remove leading '/' + pathParams = strings.Trim(pathParams, "/") // Remove leading and trailing '/' + if len(pathParams) > 0 { + r.URL.Path = rule.targetPath + "/" + pathParams + } else { + r.URL.Path = rule.targetPath } - r.URL.Path = rule.targetPath + "/" + pathParams - r.URL.Path = strings.TrimRight(r.URL.Path, "/") return true, addr } } @@ -129,10 +134,10 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri } func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - needRedirectToMicroService, targetAddr := h.matchMicroServiceRedirectRules(r) + redirectToMicroService, targetAddr := h.matchMicroServiceRedirectRules(r) allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0 isLeader := h.s.GetMember().IsLeader() - if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !needRedirectToMicroService { + if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !redirectToMicroService { next(w, r) return } @@ -157,7 +162,7 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http } var clientUrls []string - if needRedirectToMicroService { + if redirectToMicroService { if len(targetAddr) == 0 { http.Error(w, apiutil.ErrRedirectFailed, http.StatusInternalServerError) return From 40f129f816a3067831cf8e5b8766fb8f00e6e851 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 13 Sep 2023 14:18:06 +0800 Subject: [PATCH 07/11] add methods check Signed-off-by: lhy1024 --- pkg/utils/apiutil/serverapi/middleware.go | 7 +++++-- server/api/server.go | 14 ++++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 5e9cbd8b245..6682f4aabbc 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/slice" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/server" "github.com/urfave/negroni" @@ -76,6 +77,7 @@ type microserviceRedirectRule struct { matchPath string targetPath string targetServiceName string + matchMethods []string } // NewRedirector redirects request to the leader if needs to be handled in the leader. @@ -91,12 +93,13 @@ func NewRedirector(s *server.Server, opts ...RedirectorOption) negroni.Handler { type RedirectorOption func(*redirector) // MicroserviceRedirectRule new a microservice redirect rule option -func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string) RedirectorOption { +func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string, methods []string) RedirectorOption { return func(s *redirector) { s.microserviceRedirectRules = append(s.microserviceRedirectRules, µserviceRedirectRule{ matchPath, targetPath, targetServiceName, + methods, }) } } @@ -109,7 +112,7 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri return false, "" } for _, rule := range h.microserviceRedirectRules { - if strings.HasPrefix(r.URL.Path, rule.matchPath) { + if strings.HasPrefix(r.URL.Path, rule.matchPath) && slice.Contains(rule.matchMethods, r.Method) { addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName) if !ok || addr == "" { log.Warn("failed to get the service primary addr when trying to match redirect rules", diff --git a/server/api/server.go b/server/api/server.go index 5df5d22f2d0..fa7d174cac8 100644 --- a/server/api/server.go +++ b/server/api/server.go @@ -45,19 +45,25 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP serverapi.MicroserviceRedirectRule( prefix+"/admin/reset-ts", tsoapi.APIPathPrefix+"/admin/reset-ts", - mcs.TSOServiceName), + mcs.TSOServiceName, + []string{http.MethodPost}), serverapi.MicroserviceRedirectRule( prefix+"/operators", scheapi.APIPathPrefix+"/operators", - mcs.SchedulingServiceName), + mcs.SchedulingServiceName, + []string{http.MethodPost, http.MethodGet, http.MethodDelete}), + // because the writing of all the meta information of the scheduling service is in the API server, + // we only forward read-only requests about checkers and schedulers to the scheduling service. serverapi.MicroserviceRedirectRule( prefix+"/checker", // Note: this is a typo in the original code scheapi.APIPathPrefix+"/checkers", - mcs.SchedulingServiceName), + mcs.SchedulingServiceName, + []string{http.MethodGet}), serverapi.MicroserviceRedirectRule( prefix+"/schedulers", scheapi.APIPathPrefix+"/schedulers", - mcs.SchedulingServiceName), + mcs.SchedulingServiceName, + []string{http.MethodGet}), ), negroni.Wrap(r)), ) From 98a573b1866ab8db2f4810d7b7ee4a709f278d31 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Thu, 14 Sep 2023 11:39:07 +0800 Subject: [PATCH 08/11] address comments Signed-off-by: lhy1024 --- server/api/server.go | 2 ++ tests/pdctl/operator/operator_test.go | 2 +- tests/pdctl/scheduler/scheduler_test.go | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/server/api/server.go b/server/api/server.go index fa7d174cac8..0094d8eb5dd 100644 --- a/server/api/server.go +++ b/server/api/server.go @@ -64,6 +64,8 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP scheapi.APIPathPrefix+"/schedulers", mcs.SchedulingServiceName, []string{http.MethodGet}), + // TODO: we need to consider the case that v1 api not support restful api. + // we might change the previous path parameters to query parameters. ), negroni.Wrap(r)), ) diff --git a/tests/pdctl/operator/operator_test.go b/tests/pdctl/operator/operator_test.go index 63d62c8a874..b1acfcfefe6 100644 --- a/tests/pdctl/operator/operator_test.go +++ b/tests/pdctl/operator/operator_test.go @@ -253,7 +253,7 @@ func TestOperator(t *testing.T) { }) } -func TestMicroservice(t *testing.T) { +func TestOperatorInMicroservice(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index b13790991c4..c5815c35dd2 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -531,7 +531,7 @@ func mightExec(re *require.Assertions, cmd *cobra.Command, args []string, v inte json.Unmarshal(output, v) } -func TestMicroservice(t *testing.T) { +func TestSchedulerInMicroservice(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() From 590d971f6c19e28c474961b00b7e52b545ca92aa Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Fri, 15 Sep 2023 11:51:35 +0800 Subject: [PATCH 09/11] rename Signed-off-by: lhy1024 --- tests/pdctl/operator/operator_test.go | 2 +- tests/pdctl/scheduler/scheduler_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pdctl/operator/operator_test.go b/tests/pdctl/operator/operator_test.go index b1acfcfefe6..148cbc9e081 100644 --- a/tests/pdctl/operator/operator_test.go +++ b/tests/pdctl/operator/operator_test.go @@ -253,7 +253,7 @@ func TestOperator(t *testing.T) { }) } -func TestOperatorInMicroservice(t *testing.T) { +func TestForwardOperatorRequest(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index c5815c35dd2..31e6270aa3b 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -531,7 +531,7 @@ func mightExec(re *require.Assertions, cmd *cobra.Command, args []string, v inte json.Unmarshal(output, v) } -func TestSchedulerInMicroservice(t *testing.T) { +func TestForwardSchedulerRequest(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() From 2193aff21727842949a328d9400bdb1ed1412844 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Fri, 15 Sep 2023 16:23:37 +0800 Subject: [PATCH 10/11] add test Signed-off-by: lhy1024 --- pkg/utils/apiutil/apiutil.go | 2 + pkg/utils/apiutil/serverapi/middleware.go | 1 + pkg/utils/testutil/api_check.go | 73 +++++++++++-------- server/api/hot_status_test.go | 5 +- server/api/region_test.go | 2 +- server/api/rule_test.go | 2 +- tests/integrations/mcs/scheduling/api_test.go | 33 +++++++-- 7 files changed, 79 insertions(+), 39 deletions(-) diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index bcab7c8e9e7..0b72b9af10f 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -57,6 +57,8 @@ const ( XForwardedPortHeader = "X-Forwarded-Port" // XRealIPHeader is used to mark the real client IP. XRealIPHeader = "X-Real-Ip" + // ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service. + ForwardToMicroServiceHeader = "Forward-To-Micro-Service" // ErrRedirectFailed is the error message for redirect failed. ErrRedirectFailed = "redirect failed" diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 6682f4aabbc..1b97ce4d6aa 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -171,6 +171,7 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } clientUrls = append(clientUrls, targetAddr) + w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true") } else { leader := h.s.GetMember().GetLeader() if leader == nil { diff --git a/pkg/utils/testutil/api_check.go b/pkg/utils/testutil/api_check.go index fcc445b7e7a..d11d575967d 100644 --- a/pkg/utils/testutil/api_check.go +++ b/pkg/utils/testutil/api_check.go @@ -23,56 +23,71 @@ import ( "github.com/tikv/pd/pkg/utils/apiutil" ) -// Status is used to check whether http response code is equal given code -func Status(re *require.Assertions, code int) func([]byte, int) { - return func(resp []byte, i int) { +// Status is used to check whether http response code is equal given code. +func Status(re *require.Assertions, code int) func([]byte, int, http.Header) { + return func(resp []byte, i int, _ http.Header) { re.Equal(code, i, "resp: "+string(resp)) } } -// StatusOK is used to check whether http response code is equal http.StatusOK -func StatusOK(re *require.Assertions) func([]byte, int) { +// StatusOK is used to check whether http response code is equal http.StatusOK. +func StatusOK(re *require.Assertions) func([]byte, int, http.Header) { return Status(re, http.StatusOK) } -// StatusNotOK is used to check whether http response code is not equal http.StatusOK -func StatusNotOK(re *require.Assertions) func([]byte, int) { - return func(_ []byte, i int) { +// StatusNotOK is used to check whether http response code is not equal http.StatusOK. +func StatusNotOK(re *require.Assertions) func([]byte, int, http.Header) { + return func(_ []byte, i int, _ http.Header) { re.NotEqual(http.StatusOK, i) } } -// ExtractJSON is used to check whether given data can be extracted successfully -func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int) { - return func(res []byte, _ int) { +// ExtractJSON is used to check whether given data can be extracted successfully. +func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int, http.Header) { + return func(res []byte, _ int, _ http.Header) { re.NoError(json.Unmarshal(res, data)) } } -// StringContain is used to check whether response context contains given string -func StringContain(re *require.Assertions, sub string) func([]byte, int) { - return func(res []byte, _ int) { +// StringContain is used to check whether response context contains given string. +func StringContain(re *require.Assertions, sub string) func([]byte, int, http.Header) { + return func(res []byte, _ int, _ http.Header) { re.Contains(string(res), sub) } } -// StringEqual is used to check whether response context equal given string -func StringEqual(re *require.Assertions, str string) func([]byte, int) { - return func(res []byte, _ int) { +// StringEqual is used to check whether response context equal given string. +func StringEqual(re *require.Assertions, str string) func([]byte, int, http.Header) { + return func(res []byte, _ int, _ http.Header) { re.Contains(string(res), str) } } -// ReadGetJSON is used to do get request and check whether given data can be extracted successfully -func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}) error { +// WithHeader is used to check whether response header contains given key and value. +func WithHeader(re *require.Assertions, key, value string) func([]byte, int, http.Header) { + return func(_ []byte, _ int, header http.Header) { + re.Equal(value, header.Get(key)) + } +} + +// WithoutHeader is used to check whether response header does not contain given key. +func WithoutHeader(re *require.Assertions, key string) func([]byte, int, http.Header) { + return func(_ []byte, _ int, header http.Header) { + re.Empty(header.Get(key)) + } +} + +// ReadGetJSON is used to do get request and check whether given data can be extracted successfully. +func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.GetJSON(client, url, nil) if err != nil { return err } - return checkResp(resp, StatusOK(re), ExtractJSON(re, data)) + checkOpts = append(checkOpts, StatusOK(re), ExtractJSON(re, data)) + return checkResp(resp, checkOpts...) } -// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully +// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully. func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}) error { resp, err := apiutil.GetJSON(client, url, input) if err != nil { @@ -81,8 +96,8 @@ func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string return checkResp(resp, StatusOK(re), ExtractJSON(re, data)) } -// CheckPostJSON is used to do post request and do check options -func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error { +// CheckPostJSON is used to do post request and do check options. +func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.PostJSON(client, url, data) if err != nil { return err @@ -90,8 +105,8 @@ func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...fu return checkResp(resp, checkOpts...) } -// CheckGetJSON is used to do get request and do check options -func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error { +// CheckGetJSON is used to do get request and do check options. +func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.GetJSON(client, url, data) if err != nil { return err @@ -99,8 +114,8 @@ func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...fun return checkResp(resp, checkOpts...) } -// CheckPatchJSON is used to do patch request and do check options -func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error { +// CheckPatchJSON is used to do patch request and do check options. +func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.PatchJSON(client, url, data) if err != nil { return err @@ -108,14 +123,14 @@ func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...f return checkResp(resp, checkOpts...) } -func checkResp(resp *http.Response, checkOpts ...func([]byte, int)) error { +func checkResp(resp *http.Response, checkOpts ...func([]byte, int, http.Header)) error { res, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { return err } for _, opt := range checkOpts { - opt(res, resp.StatusCode) + opt(res, resp.StatusCode, resp.Header) } return nil } diff --git a/server/api/hot_status_test.go b/server/api/hot_status_test.go index a1d1bbc2617..d3d495f86fa 100644 --- a/server/api/hot_status_test.go +++ b/server/api/hot_status_test.go @@ -17,6 +17,7 @@ package api import ( "encoding/json" "fmt" + "net/http" "testing" "time" @@ -92,7 +93,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsTimeRange() { StartTime: now.UnixNano() / int64(time.Millisecond), EndTime: now.Add(10*time.Second).UnixNano() / int64(time.Millisecond), } - check := func(res []byte, statusCode int) { + check := func(res []byte, statusCode int, _ http.Header) { suite.Equal(200, statusCode) historyHotRegions := &storage.HistoryHotRegions{} json.Unmarshal(res, historyHotRegions) @@ -177,7 +178,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsIDAndTypes() { IsLearners: []bool{false}, EndTime: now.Add(10*time.Minute).UnixNano() / int64(time.Millisecond), } - check := func(res []byte, statusCode int) { + check := func(res []byte, statusCode int, _ http.Header) { suite.Equal(200, statusCode) historyHotRegions := &storage.HistoryHotRegions{} json.Unmarshal(res, historyHotRegions) diff --git a/server/api/region_test.go b/server/api/region_test.go index 63da19ab082..acd305884d4 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -408,7 +408,7 @@ func (suite *regionTestSuite) TestSplitRegions() { hex.EncodeToString([]byte("bbb")), hex.EncodeToString([]byte("ccc")), hex.EncodeToString([]byte("ddd"))) - checkOpt := func(res []byte, code int) { + checkOpt := func(res []byte, code int, _ http.Header) { s := &struct { ProcessedPercentage int `json:"processed-percentage"` NewRegionsID []uint64 `json:"regions-id"` diff --git a/server/api/rule_test.go b/server/api/rule_test.go index d2000eb9562..4cea1523401 100644 --- a/server/api/rule_test.go +++ b/server/api/rule_test.go @@ -829,7 +829,7 @@ func (suite *ruleTestSuite) TestBundleBadRequest() { } for _, testCase := range testCases { err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+testCase.uri, []byte(testCase.data), - func(_ []byte, code int) { + func(_ []byte, code int, _ http.Header) { suite.Equal(testCase.ok, code == http.StatusOK) }) suite.NoError(err) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 03db3433ec9..ea9cd2df9c5 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -2,6 +2,7 @@ package scheduling_test import ( "context" + "encoding/json" "fmt" "net/http" "testing" @@ -9,6 +10,7 @@ import ( "github.com/stretchr/testify/suite" _ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" + "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/tempurl" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/tests" @@ -119,21 +121,40 @@ func (suite *apiTestSuite) TestAPIForward() { var resp map[string]interface{} // Test opeartor - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) re.Len(slice, 0) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) re.Nil(resp) - // Test checker - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp) + // Test checker: only read-only requests are forwarded + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) suite.False(resp["paused"].(bool)) - // Test scheduler - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice) + input := make(map[string]interface{}) + input["delay"] = 10 + pauseArgs, err := json.Marshal(input) + suite.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), pauseArgs, + testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.PDRedirectorHeader)) + suite.NoError(err) + + // Test scheduler: only read-only requests are forwarded + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) re.Contains(slice, "balance-leader-scheduler") + + input["delay"] = 30 + pauseArgs, err = json.Marshal(input) + suite.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/all"), pauseArgs, + testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + suite.NoError(err) } From f070884b6057829e34583062609e5464225e2301 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Mon, 18 Sep 2023 15:38:37 +0800 Subject: [PATCH 11/11] add test with failpoint Signed-off-by: lhy1024 --- pkg/utils/apiutil/serverapi/middleware.go | 7 ++++++- tests/integrations/mcs/scheduling/api_test.go | 6 ++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 1b97ce4d6aa..063ad042dbb 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -19,6 +19,7 @@ import ( "net/url" "strings" + "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/slice" @@ -171,7 +172,11 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } clientUrls = append(clientUrls, targetAddr) - w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true") + failpoint.Inject("checkHeader", func() { + // add a header to the response, this is not a failure injection + // it is used for testing, to check whether the request is forwarded to the micro service + w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true") + }) } else { leader := h.s.GetMember().GetLeader() if leader == nil { diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index ea9cd2df9c5..04671d84798 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/pingcap/failpoint" "github.com/stretchr/testify/suite" _ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" "github.com/tikv/pd/pkg/utils/apiutil" @@ -116,6 +117,11 @@ func (suite *apiTestSuite) TestAPIForward() { defer tc.Destroy() tc.WaitForPrimaryServing(re) + failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)") + defer func() { + failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader") + }() + urlPrefix := fmt.Sprintf("%s/pd/api/v1", suite.backendEndpoints) var slice []string var resp map[string]interface{}