Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mcs: forward current http request to mcs #7078

Merged
merged 13 commits into from
Sep 18, 2023
2 changes: 1 addition & 1 deletion pkg/mcs/resourcemanager/server/apis/v1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func NewService(srv *rmserver.Service) *Service {
manager := srv.GetManager()
apiHandlerEngine.Use(func(c *gin.Context) {
// manager implements the interface of basicserver.Service.
c.Set("service", manager.GetBasicServer())
c.Set(multiservicesapi.ServiceContextKey, manager.GetBasicServer())
c.Next()
})
apiHandlerEngine.Use(multiservicesapi.ServiceRedirector())
Expand Down
2 changes: 1 addition & 1 deletion pkg/mcs/scheduling/server/apis/v1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
)

// APIPathPrefix is the prefix of the API path.
const APIPathPrefix = "/scheduling/api/v1/"
const APIPathPrefix = "/scheduling/api/v1"

var (
once sync.Once
Expand Down
17 changes: 17 additions & 0 deletions pkg/utils/apiutil/apiutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ const (
XForwardedPortHeader = "X-Forwarded-Port"
// XRealIPHeader is used to mark the real client IP.
XRealIPHeader = "X-Real-Ip"
// ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service.
ForwardToMicroServiceHeader = "Forward-To-Micro-Service"

// ErrRedirectFailed is the error message for redirect failed.
ErrRedirectFailed = "redirect failed"
Expand Down Expand Up @@ -435,8 +437,17 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request)
reader = resp.Body
}

// We need to copy the response headers before we write the header.
// Otherwise, we cannot set the header after w.WriteHeader() is called.
// And we need to write the header before we copy the response body.
// Otherwise, we cannot set the status code after w.Write() is called.
// In other words, we must perform the following steps strictly in order:
// 1. Set the response headers.
// 2. Write the response header.
// 3. Write the response body.
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)

for {
if _, err = io.CopyN(w, reader, chunkSize); err != nil {
if err == io.EOF {
Expand All @@ -455,8 +466,14 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request)
http.Error(w, ErrRedirectFailed, http.StatusInternalServerError)
}

// copyHeader duplicates the HTTP headers from the source `src` to the destination `dst`.
// It skips the "Content-Encoding" and "Content-Length" headers because they should be set by `http.ResponseWriter`.
// These headers may be modified after a redirect when gzip compression is enabled.
func copyHeader(dst, src http.Header) {
for k, vv := range src {
if k == "Content-Encoding" || k == "Content-Length" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a concern about it and am not sure if only two keys will affect it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible, but for now gzip only affects these two, and we'll do more testing as we add interfaces.

continue
}
values := dst[k]
for _, v := range vv {
if !slice.Contains(values, v) {
Expand Down
35 changes: 28 additions & 7 deletions pkg/utils/apiutil/serverapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ package serverapi
import (
"net/http"
"net/url"
"strings"

"github.com/pingcap/failpoint"
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/slice"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/server"
"github.com/urfave/negroni"
Expand Down Expand Up @@ -75,6 +78,7 @@ type microserviceRedirectRule struct {
matchPath string
targetPath string
targetServiceName string
matchMethods []string
}

// NewRedirector redirects request to the leader if needs to be handled in the leader.
Expand All @@ -90,12 +94,13 @@ func NewRedirector(s *server.Server, opts ...RedirectorOption) negroni.Handler {
type RedirectorOption func(*redirector)

// MicroserviceRedirectRule new a microservice redirect rule option
func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string) RedirectorOption {
func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string, methods []string) RedirectorOption {
return func(s *redirector) {
s.microserviceRedirectRules = append(s.microserviceRedirectRules, &microserviceRedirectRule{
matchPath,
targetPath,
targetServiceName,
methods,
})
}
}
Expand All @@ -108,24 +113,35 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri
return false, ""
}
for _, rule := range h.microserviceRedirectRules {
if rule.matchPath == r.URL.Path {
if strings.HasPrefix(r.URL.Path, rule.matchPath) && slice.Contains(rule.matchMethods, r.Method) {
addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName)
if !ok || addr == "" {
log.Warn("failed to get the service primary addr when try match redirect rules",
log.Warn("failed to get the service primary addr when trying to match redirect rules",
zap.String("path", r.URL.Path))
}
r.URL.Path = rule.targetPath
// Extract parameters from the URL path
lhy1024 marked this conversation as resolved.
Show resolved Hide resolved
// e.g. r.URL.Path = /pd/api/v1/operators/1 (before redirect)
// matchPath = /pd/api/v1/operators
// targetPath = /scheduling/api/v1/operators
// r.URL.Path = /scheduling/api/v1/operator/1 (after redirect)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using the custom way to do the transfer? Because we might change the previous path parameters to query parameters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add TODO when meet other interfaces, which not support restful

pathParams := strings.TrimPrefix(r.URL.Path, rule.matchPath)
pathParams = strings.Trim(pathParams, "/") // Remove leading and trailing '/'
if len(pathParams) > 0 {
r.URL.Path = rule.targetPath + "/" + pathParams
} else {
r.URL.Path = rule.targetPath
}
return true, addr
}
}
return false, ""
}

func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
matchedFlag, targetAddr := h.matchMicroServiceRedirectRules(r)
redirectToMicroService, targetAddr := h.matchMicroServiceRedirectRules(r)
allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0
isLeader := h.s.GetMember().IsLeader()
if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !matchedFlag {
if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !redirectToMicroService {
next(w, r)
return
}
Expand All @@ -150,12 +166,17 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http
}

var clientUrls []string
if matchedFlag {
if redirectToMicroService {
if len(targetAddr) == 0 {
http.Error(w, apiutil.ErrRedirectFailed, http.StatusInternalServerError)
return
}
clientUrls = append(clientUrls, targetAddr)
failpoint.Inject("checkHeader", func() {
// add a header to the response, this is not a failure injection
// it is used for testing, to check whether the request is forwarded to the micro service
w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true")
})
} else {
leader := h.s.GetMember().GetLeader()
if leader == nil {
Expand Down
73 changes: 44 additions & 29 deletions pkg/utils/testutil/api_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,56 +23,71 @@ import (
"github.com/tikv/pd/pkg/utils/apiutil"
)

// Status is used to check whether http response code is equal given code
func Status(re *require.Assertions, code int) func([]byte, int) {
return func(resp []byte, i int) {
// Status is used to check whether http response code is equal given code.
func Status(re *require.Assertions, code int) func([]byte, int, http.Header) {
return func(resp []byte, i int, _ http.Header) {
re.Equal(code, i, "resp: "+string(resp))
}
}

// StatusOK is used to check whether http response code is equal http.StatusOK
func StatusOK(re *require.Assertions) func([]byte, int) {
// StatusOK is used to check whether http response code is equal http.StatusOK.
func StatusOK(re *require.Assertions) func([]byte, int, http.Header) {
return Status(re, http.StatusOK)
}

// StatusNotOK is used to check whether http response code is not equal http.StatusOK
func StatusNotOK(re *require.Assertions) func([]byte, int) {
return func(_ []byte, i int) {
// StatusNotOK is used to check whether http response code is not equal http.StatusOK.
func StatusNotOK(re *require.Assertions) func([]byte, int, http.Header) {
return func(_ []byte, i int, _ http.Header) {
re.NotEqual(http.StatusOK, i)
}
}

// ExtractJSON is used to check whether given data can be extracted successfully
func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int) {
return func(res []byte, _ int) {
// ExtractJSON is used to check whether given data can be extracted successfully.
func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.NoError(json.Unmarshal(res, data))
}
}

// StringContain is used to check whether response context contains given string
func StringContain(re *require.Assertions, sub string) func([]byte, int) {
return func(res []byte, _ int) {
// StringContain is used to check whether response context contains given string.
func StringContain(re *require.Assertions, sub string) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.Contains(string(res), sub)
}
}

// StringEqual is used to check whether response context equal given string
func StringEqual(re *require.Assertions, str string) func([]byte, int) {
return func(res []byte, _ int) {
// StringEqual is used to check whether response context equal given string.
func StringEqual(re *require.Assertions, str string) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.Contains(string(res), str)
}
}

// ReadGetJSON is used to do get request and check whether given data can be extracted successfully
func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}) error {
// WithHeader is used to check whether response header contains given key and value.
func WithHeader(re *require.Assertions, key, value string) func([]byte, int, http.Header) {
return func(_ []byte, _ int, header http.Header) {
re.Equal(value, header.Get(key))
}
}

// WithoutHeader is used to check whether response header does not contain given key.
func WithoutHeader(re *require.Assertions, key string) func([]byte, int, http.Header) {
return func(_ []byte, _ int, header http.Header) {
re.Empty(header.Get(key))
}
}

// ReadGetJSON is used to do get request and check whether given data can be extracted successfully.
func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.GetJSON(client, url, nil)
if err != nil {
return err
}
return checkResp(resp, StatusOK(re), ExtractJSON(re, data))
checkOpts = append(checkOpts, StatusOK(re), ExtractJSON(re, data))
return checkResp(resp, checkOpts...)
}

// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully
// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully.
func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}) error {
resp, err := apiutil.GetJSON(client, url, input)
if err != nil {
Expand All @@ -81,41 +96,41 @@ func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string
return checkResp(resp, StatusOK(re), ExtractJSON(re, data))
}

// CheckPostJSON is used to do post request and do check options
func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckPostJSON is used to do post request and do check options.
func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.PostJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

// CheckGetJSON is used to do get request and do check options
func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckGetJSON is used to do get request and do check options.
func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.GetJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

// CheckPatchJSON is used to do patch request and do check options
func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckPatchJSON is used to do patch request and do check options.
func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.PatchJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

func checkResp(resp *http.Response, checkOpts ...func([]byte, int)) error {
func checkResp(resp *http.Response, checkOpts ...func([]byte, int, http.Header)) error {
res, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return err
}
for _, opt := range checkOpts {
opt(res, resp.StatusCode)
opt(res, resp.StatusCode, resp.Header)
}
return nil
}
5 changes: 3 additions & 2 deletions server/api/hot_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package api
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -92,7 +93,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsTimeRange() {
StartTime: now.UnixNano() / int64(time.Millisecond),
EndTime: now.Add(10*time.Second).UnixNano() / int64(time.Millisecond),
}
check := func(res []byte, statusCode int) {
check := func(res []byte, statusCode int, _ http.Header) {
suite.Equal(200, statusCode)
historyHotRegions := &storage.HistoryHotRegions{}
json.Unmarshal(res, historyHotRegions)
Expand Down Expand Up @@ -177,7 +178,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsIDAndTypes() {
IsLearners: []bool{false},
EndTime: now.Add(10*time.Minute).UnixNano() / int64(time.Millisecond),
}
check := func(res []byte, statusCode int) {
check := func(res []byte, statusCode int, _ http.Header) {
suite.Equal(200, statusCode)
historyHotRegions := &storage.HistoryHotRegions{}
json.Unmarshal(res, historyHotRegions)
Expand Down
2 changes: 1 addition & 1 deletion server/api/region_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ func (suite *regionTestSuite) TestSplitRegions() {
hex.EncodeToString([]byte("bbb")),
hex.EncodeToString([]byte("ccc")),
hex.EncodeToString([]byte("ddd")))
checkOpt := func(res []byte, code int) {
checkOpt := func(res []byte, code int, _ http.Header) {
s := &struct {
ProcessedPercentage int `json:"processed-percentage"`
NewRegionsID []uint64 `json:"regions-id"`
Expand Down
2 changes: 1 addition & 1 deletion server/api/rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ func (suite *ruleTestSuite) TestBundleBadRequest() {
}
for _, testCase := range testCases {
err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+testCase.uri, []byte(testCase.data),
func(_ []byte, code int) {
func(_ []byte, code int, _ http.Header) {
suite.Equal(testCase.ok, code == http.StatusOK)
})
suite.NoError(err)
Expand Down
34 changes: 29 additions & 5 deletions server/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"net/http"

"github.com/gorilla/mux"
scheapi "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1"
tsoapi "github.com/tikv/pd/pkg/mcs/tso/server/apis/v1"
mcs "github.com/tikv/pd/pkg/mcs/utils"
"github.com/tikv/pd/pkg/utils/apiutil"
Expand All @@ -35,14 +36,37 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP
Name: "core",
IsCore: true,
}
router := mux.NewRouter()
prefix := apiPrefix + "/api/v1"
r := createRouter(apiPrefix, svr)
router := mux.NewRouter()
router.PathPrefix(apiPrefix).Handler(negroni.New(
serverapi.NewRuntimeServiceValidator(svr, group),
serverapi.NewRedirector(svr, serverapi.MicroserviceRedirectRule(
apiPrefix+"/api/v1"+"/admin/reset-ts",
tsoapi.APIPathPrefix+"/admin/reset-ts",
mcs.TSOServiceName)),
serverapi.NewRedirector(svr,
serverapi.MicroserviceRedirectRule(
prefix+"/admin/reset-ts",
tsoapi.APIPathPrefix+"/admin/reset-ts",
mcs.TSOServiceName,
[]string{http.MethodPost}),
serverapi.MicroserviceRedirectRule(
prefix+"/operators",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about the config or other paths?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will resolve them in other PRs, this PR only forwards the current HTTP method in scheduling server.

scheapi.APIPathPrefix+"/operators",
mcs.SchedulingServiceName,
[]string{http.MethodPost, http.MethodGet, http.MethodDelete}),
// because the writing of all the meta information of the scheduling service is in the API server,
// we only forward read-only requests about checkers and schedulers to the scheduling service.
serverapi.MicroserviceRedirectRule(
prefix+"/checker", // Note: this is a typo in the original code
scheapi.APIPathPrefix+"/checkers",
mcs.SchedulingServiceName,
[]string{http.MethodGet}),
serverapi.MicroserviceRedirectRule(
prefix+"/schedulers",
scheapi.APIPathPrefix+"/schedulers",
mcs.SchedulingServiceName,
[]string{http.MethodGet}),
// TODO: we need to consider the case that v1 api not support restful api.
// we might change the previous path parameters to query parameters.
),
negroni.Wrap(r)),
)

Expand Down
Loading
Loading