diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go old mode 100644 new mode 100755 index eb0f8a5f8eb..c7979dcc038 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -182,14 +182,6 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } - // Prevent more than one redirection. - if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { - log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect)) - http.Error(w, errs.ErrRedirectToNotLeader.FastGenByArgs().Error(), http.StatusInternalServerError) - return - } - - r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name()) forwardedIP, forwardedPort := apiutil.GetIPPortFromHTTPRequest(r) if len(forwardedIP) > 0 { r.Header.Add(apiutil.XForwardedForHeader, forwardedIP) @@ -208,9 +200,9 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } clientUrls = append(clientUrls, targetAddr) + // Add a header to the response, this is not a failure injection + // it is used for testing, to check whether the request is forwarded to the micro service failpoint.Inject("checkHeader", func() { - // add a header to the response, this is not a failure injection - // it is used for testing, to check whether the request is forwarded to the micro service w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true") }) } else { @@ -220,7 +212,15 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } clientUrls = leader.GetClientUrls() + // Prevent more than one redirection among PD/API servers. + if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { + log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect)) + http.Error(w, errs.ErrRedirectToNotLeader.FastGenByArgs().Error(), http.StatusInternalServerError) + return + } + r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name()) } + urls := make([]url.URL, 0, len(clientUrls)) for _, item := range clientUrls { u, err := url.Parse(item) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 8f5d37ee1bb..4c71f8f14a3 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -1,6 +1,7 @@ package scheduling_test import ( + "context" "encoding/hex" "encoding/json" "fmt" @@ -40,10 +41,12 @@ func TestAPI(t *testing.T) { } func (suite *apiTestSuite) SetupSuite() { + suite.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)")) suite.env = tests.NewSchedulingTestEnvironment(suite.T()) } func (suite *apiTestSuite) TearDownSuite() { + suite.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader")) suite.env.Cleanup() } @@ -99,10 +102,6 @@ func (suite *apiTestSuite) TestAPIForward() { func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { re := suite.Require() - re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)")) - defer func() { - re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader")) - }() leader := cluster.GetLeaderServer().GetServer() urlPrefix := fmt.Sprintf("%s/pd/api/v1", leader.GetAddr()) @@ -300,7 +299,7 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { rulesArgs, err := json.Marshal(rules) suite.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "/config/rules"), &rules, + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), rulesArgs, @@ -499,3 +498,49 @@ func (suite *apiTestSuite) checkAdminRegionCacheForward(cluster *tests.TestClust re.Equal(0, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) re.Equal(0, apiServer.GetRaftCluster().GetRegionCount([]byte{}, []byte{}).Count) } + +func (suite *apiTestSuite) TestFollowerForward() { + suite.env.RunTestInTwoModes(suite.checkFollowerForward) +} + +func (suite *apiTestSuite) checkFollowerForward(cluster *tests.TestCluster) { + re := suite.Require() + leaderAddr := cluster.GetLeaderServer().GetAddr() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + follower, err := cluster.JoinAPIServer(ctx) + re.NoError(err) + re.NoError(follower.Run()) + re.NotEmpty(cluster.WaitLeader()) + + followerAddr := follower.GetAddr() + if cluster.GetLeaderServer().GetAddr() != leaderAddr { + followerAddr = leaderAddr + } + + urlPrefix := fmt.Sprintf("%s/pd/api/v1", followerAddr) + rules := []*placement.Rule{} + if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { + // follower will forward to scheduling server directly + re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"), + ) + re.NoError(err) + } else { + // follower will forward to leader server + re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader), + ) + re.NoError(err) + } + + // follower will forward to leader server + re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) + results := make(map[string]interface{}) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config"), &results, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader), + ) + re.NoError(err) +} diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index d8d54a79d13..fb7c239b431 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -15,7 +15,6 @@ package scheduler_test import ( - "context" "encoding/json" "fmt" "reflect" @@ -691,47 +690,3 @@ func mightExec(re *require.Assertions, cmd *cobra.Command, args []string, v inte } json.Unmarshal(output, v) } - -func TestForwardSchedulerRequest(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cluster, err := tests.NewTestAPICluster(ctx, 1) - re.NoError(err) - re.NoError(cluster.RunInitialServers()) - re.NotEmpty(cluster.WaitLeader()) - server := cluster.GetLeaderServer() - re.NoError(server.BootstrapCluster()) - backendEndpoints := server.GetAddr() - tc, err := tests.NewTestSchedulingCluster(ctx, 1, backendEndpoints) - re.NoError(err) - defer tc.Destroy() - tc.WaitForPrimaryServing(re) - - cmd := pdctlCmd.GetRootCmd() - args := []string{"-u", backendEndpoints, "scheduler", "show"} - var sches []string - testutil.Eventually(re, func() bool { - output, err := pdctl.ExecuteCommand(cmd, args...) - re.NoError(err) - re.NoError(json.Unmarshal(output, &sches)) - return slice.Contains(sches, "balance-leader-scheduler") - }) - - mustUsage := func(args []string) { - output, err := pdctl.ExecuteCommand(cmd, args...) - re.NoError(err) - re.Contains(string(output), "Usage") - } - mustUsage([]string{"-u", backendEndpoints, "scheduler", "pause", "balance-leader-scheduler"}) - echo := mustExec(re, cmd, []string{"-u", backendEndpoints, "scheduler", "pause", "balance-leader-scheduler", "60"}, nil) - re.Contains(echo, "Success!") - checkSchedulerWithStatusCommand := func(status string, expected []string) { - var schedulers []string - mustExec(re, cmd, []string{"-u", backendEndpoints, "scheduler", "show", "--status", status}, &schedulers) - re.Equal(expected, schedulers) - } - checkSchedulerWithStatusCommand("paused", []string{ - "balance-leader-scheduler", - }) -}