diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index bcab7c8e9e7..0b72b9af10f 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -57,6 +57,8 @@ const ( XForwardedPortHeader = "X-Forwarded-Port" // XRealIPHeader is used to mark the real client IP. XRealIPHeader = "X-Real-Ip" + // ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service. + ForwardToMicroServiceHeader = "Forward-To-Micro-Service" // ErrRedirectFailed is the error message for redirect failed. ErrRedirectFailed = "redirect failed" diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 6682f4aabbc..1b97ce4d6aa 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -171,6 +171,7 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } clientUrls = append(clientUrls, targetAddr) + w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true") } else { leader := h.s.GetMember().GetLeader() if leader == nil { diff --git a/pkg/utils/testutil/api_check.go b/pkg/utils/testutil/api_check.go index fcc445b7e7a..d11d575967d 100644 --- a/pkg/utils/testutil/api_check.go +++ b/pkg/utils/testutil/api_check.go @@ -23,56 +23,71 @@ import ( "github.com/tikv/pd/pkg/utils/apiutil" ) -// Status is used to check whether http response code is equal given code -func Status(re *require.Assertions, code int) func([]byte, int) { - return func(resp []byte, i int) { +// Status is used to check whether http response code is equal given code. +func Status(re *require.Assertions, code int) func([]byte, int, http.Header) { + return func(resp []byte, i int, _ http.Header) { re.Equal(code, i, "resp: "+string(resp)) } } -// StatusOK is used to check whether http response code is equal http.StatusOK -func StatusOK(re *require.Assertions) func([]byte, int) { +// StatusOK is used to check whether http response code is equal http.StatusOK. +func StatusOK(re *require.Assertions) func([]byte, int, http.Header) { return Status(re, http.StatusOK) } -// StatusNotOK is used to check whether http response code is not equal http.StatusOK -func StatusNotOK(re *require.Assertions) func([]byte, int) { - return func(_ []byte, i int) { +// StatusNotOK is used to check whether http response code is not equal http.StatusOK. +func StatusNotOK(re *require.Assertions) func([]byte, int, http.Header) { + return func(_ []byte, i int, _ http.Header) { re.NotEqual(http.StatusOK, i) } } -// ExtractJSON is used to check whether given data can be extracted successfully -func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int) { - return func(res []byte, _ int) { +// ExtractJSON is used to check whether given data can be extracted successfully. +func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int, http.Header) { + return func(res []byte, _ int, _ http.Header) { re.NoError(json.Unmarshal(res, data)) } } -// StringContain is used to check whether response context contains given string -func StringContain(re *require.Assertions, sub string) func([]byte, int) { - return func(res []byte, _ int) { +// StringContain is used to check whether response context contains given string. +func StringContain(re *require.Assertions, sub string) func([]byte, int, http.Header) { + return func(res []byte, _ int, _ http.Header) { re.Contains(string(res), sub) } } -// StringEqual is used to check whether response context equal given string -func StringEqual(re *require.Assertions, str string) func([]byte, int) { - return func(res []byte, _ int) { +// StringEqual is used to check whether response context equal given string. +func StringEqual(re *require.Assertions, str string) func([]byte, int, http.Header) { + return func(res []byte, _ int, _ http.Header) { re.Contains(string(res), str) } } -// ReadGetJSON is used to do get request and check whether given data can be extracted successfully -func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}) error { +// WithHeader is used to check whether response header contains given key and value. +func WithHeader(re *require.Assertions, key, value string) func([]byte, int, http.Header) { + return func(_ []byte, _ int, header http.Header) { + re.Equal(value, header.Get(key)) + } +} + +// WithoutHeader is used to check whether response header does not contain given key. +func WithoutHeader(re *require.Assertions, key string) func([]byte, int, http.Header) { + return func(_ []byte, _ int, header http.Header) { + re.Empty(header.Get(key)) + } +} + +// ReadGetJSON is used to do get request and check whether given data can be extracted successfully. +func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.GetJSON(client, url, nil) if err != nil { return err } - return checkResp(resp, StatusOK(re), ExtractJSON(re, data)) + checkOpts = append(checkOpts, StatusOK(re), ExtractJSON(re, data)) + return checkResp(resp, checkOpts...) } -// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully +// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully. func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}) error { resp, err := apiutil.GetJSON(client, url, input) if err != nil { @@ -81,8 +96,8 @@ func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string return checkResp(resp, StatusOK(re), ExtractJSON(re, data)) } -// CheckPostJSON is used to do post request and do check options -func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error { +// CheckPostJSON is used to do post request and do check options. +func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.PostJSON(client, url, data) if err != nil { return err @@ -90,8 +105,8 @@ func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...fu return checkResp(resp, checkOpts...) } -// CheckGetJSON is used to do get request and do check options -func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error { +// CheckGetJSON is used to do get request and do check options. +func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.GetJSON(client, url, data) if err != nil { return err @@ -99,8 +114,8 @@ func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...fun return checkResp(resp, checkOpts...) } -// CheckPatchJSON is used to do patch request and do check options -func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error { +// CheckPatchJSON is used to do patch request and do check options. +func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error { resp, err := apiutil.PatchJSON(client, url, data) if err != nil { return err @@ -108,14 +123,14 @@ func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...f return checkResp(resp, checkOpts...) } -func checkResp(resp *http.Response, checkOpts ...func([]byte, int)) error { +func checkResp(resp *http.Response, checkOpts ...func([]byte, int, http.Header)) error { res, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { return err } for _, opt := range checkOpts { - opt(res, resp.StatusCode) + opt(res, resp.StatusCode, resp.Header) } return nil } diff --git a/server/api/hot_status_test.go b/server/api/hot_status_test.go index a1d1bbc2617..d3d495f86fa 100644 --- a/server/api/hot_status_test.go +++ b/server/api/hot_status_test.go @@ -17,6 +17,7 @@ package api import ( "encoding/json" "fmt" + "net/http" "testing" "time" @@ -92,7 +93,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsTimeRange() { StartTime: now.UnixNano() / int64(time.Millisecond), EndTime: now.Add(10*time.Second).UnixNano() / int64(time.Millisecond), } - check := func(res []byte, statusCode int) { + check := func(res []byte, statusCode int, _ http.Header) { suite.Equal(200, statusCode) historyHotRegions := &storage.HistoryHotRegions{} json.Unmarshal(res, historyHotRegions) @@ -177,7 +178,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsIDAndTypes() { IsLearners: []bool{false}, EndTime: now.Add(10*time.Minute).UnixNano() / int64(time.Millisecond), } - check := func(res []byte, statusCode int) { + check := func(res []byte, statusCode int, _ http.Header) { suite.Equal(200, statusCode) historyHotRegions := &storage.HistoryHotRegions{} json.Unmarshal(res, historyHotRegions) diff --git a/server/api/region_test.go b/server/api/region_test.go index 63da19ab082..acd305884d4 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -408,7 +408,7 @@ func (suite *regionTestSuite) TestSplitRegions() { hex.EncodeToString([]byte("bbb")), hex.EncodeToString([]byte("ccc")), hex.EncodeToString([]byte("ddd"))) - checkOpt := func(res []byte, code int) { + checkOpt := func(res []byte, code int, _ http.Header) { s := &struct { ProcessedPercentage int `json:"processed-percentage"` NewRegionsID []uint64 `json:"regions-id"` diff --git a/server/api/rule_test.go b/server/api/rule_test.go index d2000eb9562..4cea1523401 100644 --- a/server/api/rule_test.go +++ b/server/api/rule_test.go @@ -829,7 +829,7 @@ func (suite *ruleTestSuite) TestBundleBadRequest() { } for _, testCase := range testCases { err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+testCase.uri, []byte(testCase.data), - func(_ []byte, code int) { + func(_ []byte, code int, _ http.Header) { suite.Equal(testCase.ok, code == http.StatusOK) }) suite.NoError(err) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 03db3433ec9..ea9cd2df9c5 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -2,6 +2,7 @@ package scheduling_test import ( "context" + "encoding/json" "fmt" "net/http" "testing" @@ -9,6 +10,7 @@ import ( "github.com/stretchr/testify/suite" _ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1" + "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/tempurl" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/tests" @@ -119,21 +121,40 @@ func (suite *apiTestSuite) TestAPIForward() { var resp map[string]interface{} // Test opeartor - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) re.Len(slice, 0) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) re.Nil(resp) - // Test checker - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp) + // Test checker: only read-only requests are forwarded + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) suite.False(resp["paused"].(bool)) - // Test scheduler - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice) + input := make(map[string]interface{}) + input["delay"] = 10 + pauseArgs, err := json.Marshal(input) + suite.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), pauseArgs, + testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.PDRedirectorHeader)) + suite.NoError(err) + + // Test scheduler: only read-only requests are forwarded + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) re.Contains(slice, "balance-leader-scheduler") + + input["delay"] = 30 + pauseArgs, err = json.Marshal(input) + suite.NoError(err) + err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/all"), pauseArgs, + testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader)) + suite.NoError(err) }