Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mcs: forward current http request to mcs #7078

Merged
merged 13 commits into from
Sep 18, 2023
2 changes: 2 additions & 0 deletions pkg/utils/apiutil/apiutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ const (
XForwardedPortHeader = "X-Forwarded-Port"
// XRealIPHeader is used to mark the real client IP.
XRealIPHeader = "X-Real-Ip"
// ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service.
ForwardToMicroServiceHeader = "Forward-To-Micro-Service"

// ErrRedirectFailed is the error message for redirect failed.
ErrRedirectFailed = "redirect failed"
Expand Down
1 change: 1 addition & 0 deletions pkg/utils/apiutil/serverapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
if strings.HasPrefix(r.URL.Path, rule.matchPath) && slice.Contains(rule.matchMethods, r.Method) {
addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName)
if !ok || addr == "" {
log.Warn("failed to get the service primary addr when trying to match redirect rules",

Check warning on line 118 in pkg/utils/apiutil/serverapi/middleware.go

View check run for this annotation

Codecov / codecov/patch

pkg/utils/apiutil/serverapi/middleware.go#L118

Added line #L118 was not covered by tests
zap.String("path", r.URL.Path))
}
// Extract parameters from the URL path
lhy1024 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -171,6 +171,7 @@
return
}
clientUrls = append(clientUrls, targetAddr)
w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it only used for testing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, I think we shouldn't add it in the header, the testing code should not affect the normal situation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is not added to resp, is there any other way to tell whether it is processed by the PD server or the scheduling server?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When debugging on a real cluster, it may also be useful to identify the origin of the server handling the request.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we keep this header filed, what about setting its value to the microservice name rather than a true?

Copy link
Contributor Author

@lhy1024 lhy1024 Sep 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g scheduling?tso?Is it necessary?

} else {
leader := h.s.GetMember().GetLeader()
if leader == nil {
Expand Down
73 changes: 44 additions & 29 deletions pkg/utils/testutil/api_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,56 +23,71 @@ import (
"github.com/tikv/pd/pkg/utils/apiutil"
)

// Status is used to check whether http response code is equal given code
func Status(re *require.Assertions, code int) func([]byte, int) {
return func(resp []byte, i int) {
// Status is used to check whether http response code is equal given code.
func Status(re *require.Assertions, code int) func([]byte, int, http.Header) {
return func(resp []byte, i int, _ http.Header) {
re.Equal(code, i, "resp: "+string(resp))
}
}

// StatusOK is used to check whether http response code is equal http.StatusOK
func StatusOK(re *require.Assertions) func([]byte, int) {
// StatusOK is used to check whether http response code is equal http.StatusOK.
func StatusOK(re *require.Assertions) func([]byte, int, http.Header) {
return Status(re, http.StatusOK)
}

// StatusNotOK is used to check whether http response code is not equal http.StatusOK
func StatusNotOK(re *require.Assertions) func([]byte, int) {
return func(_ []byte, i int) {
// StatusNotOK is used to check whether http response code is not equal http.StatusOK.
func StatusNotOK(re *require.Assertions) func([]byte, int, http.Header) {
return func(_ []byte, i int, _ http.Header) {
re.NotEqual(http.StatusOK, i)
}
}

// ExtractJSON is used to check whether given data can be extracted successfully
func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int) {
return func(res []byte, _ int) {
// ExtractJSON is used to check whether given data can be extracted successfully.
func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.NoError(json.Unmarshal(res, data))
}
}

// StringContain is used to check whether response context contains given string
func StringContain(re *require.Assertions, sub string) func([]byte, int) {
return func(res []byte, _ int) {
// StringContain is used to check whether response context contains given string.
func StringContain(re *require.Assertions, sub string) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.Contains(string(res), sub)
}
}

// StringEqual is used to check whether response context equal given string
func StringEqual(re *require.Assertions, str string) func([]byte, int) {
return func(res []byte, _ int) {
// StringEqual is used to check whether response context equal given string.
func StringEqual(re *require.Assertions, str string) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.Contains(string(res), str)
}
}

// ReadGetJSON is used to do get request and check whether given data can be extracted successfully
func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}) error {
// WithHeader is used to check whether response header contains given key and value.
func WithHeader(re *require.Assertions, key, value string) func([]byte, int, http.Header) {
return func(_ []byte, _ int, header http.Header) {
re.Equal(value, header.Get(key))
}
}

// WithoutHeader is used to check whether response header does not contain given key.
func WithoutHeader(re *require.Assertions, key string) func([]byte, int, http.Header) {
return func(_ []byte, _ int, header http.Header) {
re.Empty(header.Get(key))
}
}

// ReadGetJSON is used to do get request and check whether given data can be extracted successfully.
func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.GetJSON(client, url, nil)
if err != nil {
return err
}
return checkResp(resp, StatusOK(re), ExtractJSON(re, data))
checkOpts = append(checkOpts, StatusOK(re), ExtractJSON(re, data))
return checkResp(resp, checkOpts...)
}

// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully
// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully.
func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}) error {
resp, err := apiutil.GetJSON(client, url, input)
if err != nil {
Expand All @@ -81,41 +96,41 @@ func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string
return checkResp(resp, StatusOK(re), ExtractJSON(re, data))
}

// CheckPostJSON is used to do post request and do check options
func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckPostJSON is used to do post request and do check options.
func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.PostJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

// CheckGetJSON is used to do get request and do check options
func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckGetJSON is used to do get request and do check options.
func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.GetJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

// CheckPatchJSON is used to do patch request and do check options
func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckPatchJSON is used to do patch request and do check options.
func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.PatchJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

func checkResp(resp *http.Response, checkOpts ...func([]byte, int)) error {
func checkResp(resp *http.Response, checkOpts ...func([]byte, int, http.Header)) error {
res, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return err
}
for _, opt := range checkOpts {
opt(res, resp.StatusCode)
opt(res, resp.StatusCode, resp.Header)
}
return nil
}
5 changes: 3 additions & 2 deletions server/api/hot_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package api
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -92,7 +93,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsTimeRange() {
StartTime: now.UnixNano() / int64(time.Millisecond),
EndTime: now.Add(10*time.Second).UnixNano() / int64(time.Millisecond),
}
check := func(res []byte, statusCode int) {
check := func(res []byte, statusCode int, _ http.Header) {
suite.Equal(200, statusCode)
historyHotRegions := &storage.HistoryHotRegions{}
json.Unmarshal(res, historyHotRegions)
Expand Down Expand Up @@ -177,7 +178,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsIDAndTypes() {
IsLearners: []bool{false},
EndTime: now.Add(10*time.Minute).UnixNano() / int64(time.Millisecond),
}
check := func(res []byte, statusCode int) {
check := func(res []byte, statusCode int, _ http.Header) {
suite.Equal(200, statusCode)
historyHotRegions := &storage.HistoryHotRegions{}
json.Unmarshal(res, historyHotRegions)
Expand Down
2 changes: 1 addition & 1 deletion server/api/region_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ func (suite *regionTestSuite) TestSplitRegions() {
hex.EncodeToString([]byte("bbb")),
hex.EncodeToString([]byte("ccc")),
hex.EncodeToString([]byte("ddd")))
checkOpt := func(res []byte, code int) {
checkOpt := func(res []byte, code int, _ http.Header) {
s := &struct {
ProcessedPercentage int `json:"processed-percentage"`
NewRegionsID []uint64 `json:"regions-id"`
Expand Down
2 changes: 1 addition & 1 deletion server/api/rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ func (suite *ruleTestSuite) TestBundleBadRequest() {
}
for _, testCase := range testCases {
err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+testCase.uri, []byte(testCase.data),
func(_ []byte, code int) {
func(_ []byte, code int, _ http.Header) {
suite.Equal(testCase.ok, code == http.StatusOK)
})
suite.NoError(err)
Expand Down
33 changes: 27 additions & 6 deletions tests/integrations/mcs/scheduling/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package scheduling_test

import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/suite"
_ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/tempurl"
"github.com/tikv/pd/pkg/utils/testutil"
"github.com/tikv/pd/tests"
Expand Down Expand Up @@ -119,21 +121,40 @@ func (suite *apiTestSuite) TestAPIForward() {
var resp map[string]interface{}

// Test opeartor
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice)
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice,
testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
re.NoError(err)
re.Len(slice, 0)

err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp)
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp,
testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
re.NoError(err)
re.Nil(resp)

// Test checker
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp)
// Test checker: only read-only requests are forwarded
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp,
testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
re.NoError(err)
suite.False(resp["paused"].(bool))

// Test scheduler
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice)
input := make(map[string]interface{})
input["delay"] = 10
pauseArgs, err := json.Marshal(input)
suite.NoError(err)
err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), pauseArgs,
testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.PDRedirectorHeader))
suite.NoError(err)

// Test scheduler: only read-only requests are forwarded
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice,
testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
re.NoError(err)
re.Contains(slice, "balance-leader-scheduler")

input["delay"] = 30
pauseArgs, err = json.Marshal(input)
suite.NoError(err)
err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/all"), pauseArgs,
testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader))
suite.NoError(err)
}
Loading