Skip to content

Commit

Permalink
mcs: forward current http request to mcs (tikv#7078)
Browse files Browse the repository at this point in the history
ref tikv#5839

Signed-off-by: lhy1024 <[email protected]>

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
  • Loading branch information
lhy1024 and ti-chi-bot[bot] committed Sep 18, 2023
1 parent d7d4756 commit 0888ef6
Show file tree
Hide file tree
Showing 14 changed files with 278 additions and 69 deletions.
2 changes: 1 addition & 1 deletion pkg/mcs/resourcemanager/server/apis/v1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func NewService(srv *rmserver.Service) *Service {
manager := srv.GetManager()
apiHandlerEngine.Use(func(c *gin.Context) {
// manager implements the interface of basicserver.Service.
c.Set("service", manager.GetBasicServer())
c.Set(multiservicesapi.ServiceContextKey, manager.GetBasicServer())
c.Next()
})
apiHandlerEngine.Use(multiservicesapi.ServiceRedirector())
Expand Down
2 changes: 1 addition & 1 deletion pkg/mcs/scheduling/server/apis/v1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
)

// APIPathPrefix is the prefix of the API path.
const APIPathPrefix = "/scheduling/api/v1/"
const APIPathPrefix = "/scheduling/api/v1"

var (
once sync.Once
Expand Down
17 changes: 17 additions & 0 deletions pkg/utils/apiutil/apiutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ const (
XForwardedPortHeader = "X-Forwarded-Port"
// XRealIPHeader is used to mark the real client IP.
XRealIPHeader = "X-Real-Ip"
// ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service.
ForwardToMicroServiceHeader = "Forward-To-Micro-Service"

// ErrRedirectFailed is the error message for redirect failed.
ErrRedirectFailed = "redirect failed"
Expand Down Expand Up @@ -435,8 +437,17 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request)
reader = resp.Body
}

// We need to copy the response headers before we write the header.
// Otherwise, we cannot set the header after w.WriteHeader() is called.
// And we need to write the header before we copy the response body.
// Otherwise, we cannot set the status code after w.Write() is called.
// In other words, we must perform the following steps strictly in order:
// 1. Set the response headers.
// 2. Write the response header.
// 3. Write the response body.
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)

for {
if _, err = io.CopyN(w, reader, chunkSize); err != nil {
if err == io.EOF {
Expand All @@ -455,8 +466,14 @@ func (p *customReverseProxies) ServeHTTP(w http.ResponseWriter, r *http.Request)
http.Error(w, ErrRedirectFailed, http.StatusInternalServerError)
}

// copyHeader duplicates the HTTP headers from the source `src` to the destination `dst`.
// It skips the "Content-Encoding" and "Content-Length" headers because they should be set by `http.ResponseWriter`.
// These headers may be modified after a redirect when gzip compression is enabled.
func copyHeader(dst, src http.Header) {
for k, vv := range src {
if k == "Content-Encoding" || k == "Content-Length" {
continue
}
values := dst[k]
for _, v := range vv {
if !slice.Contains(values, v) {
Expand Down
35 changes: 28 additions & 7 deletions pkg/utils/apiutil/serverapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ package serverapi
import (
"net/http"
"net/url"
"strings"

"github.com/pingcap/failpoint"
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/slice"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/server"
"github.com/urfave/negroni"
Expand Down Expand Up @@ -75,6 +78,7 @@ type microserviceRedirectRule struct {
matchPath string
targetPath string
targetServiceName string
matchMethods []string
}

// NewRedirector redirects request to the leader if needs to be handled in the leader.
Expand All @@ -90,12 +94,13 @@ func NewRedirector(s *server.Server, opts ...RedirectorOption) negroni.Handler {
type RedirectorOption func(*redirector)

// MicroserviceRedirectRule new a microservice redirect rule option
func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string) RedirectorOption {
func MicroserviceRedirectRule(matchPath, targetPath, targetServiceName string, methods []string) RedirectorOption {
return func(s *redirector) {
s.microserviceRedirectRules = append(s.microserviceRedirectRules, &microserviceRedirectRule{
matchPath,
targetPath,
targetServiceName,
methods,
})
}
}
Expand All @@ -108,24 +113,35 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri
return false, ""
}
for _, rule := range h.microserviceRedirectRules {
if rule.matchPath == r.URL.Path {
if strings.HasPrefix(r.URL.Path, rule.matchPath) && slice.Contains(rule.matchMethods, r.Method) {
addr, ok := h.s.GetServicePrimaryAddr(r.Context(), rule.targetServiceName)
if !ok || addr == "" {
log.Warn("failed to get the service primary addr when try match redirect rules",
log.Warn("failed to get the service primary addr when trying to match redirect rules",
zap.String("path", r.URL.Path))
}
r.URL.Path = rule.targetPath
// Extract parameters from the URL path
// e.g. r.URL.Path = /pd/api/v1/operators/1 (before redirect)
// matchPath = /pd/api/v1/operators
// targetPath = /scheduling/api/v1/operators
// r.URL.Path = /scheduling/api/v1/operator/1 (after redirect)
pathParams := strings.TrimPrefix(r.URL.Path, rule.matchPath)
pathParams = strings.Trim(pathParams, "/") // Remove leading and trailing '/'
if len(pathParams) > 0 {
r.URL.Path = rule.targetPath + "/" + pathParams
} else {
r.URL.Path = rule.targetPath
}
return true, addr
}
}
return false, ""
}

func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
matchedFlag, targetAddr := h.matchMicroServiceRedirectRules(r)
redirectToMicroService, targetAddr := h.matchMicroServiceRedirectRules(r)
allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0
isLeader := h.s.GetMember().IsLeader()
if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !matchedFlag {
if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !redirectToMicroService {
next(w, r)
return
}
Expand All @@ -150,12 +166,17 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http
}

var clientUrls []string
if matchedFlag {
if redirectToMicroService {
if len(targetAddr) == 0 {
http.Error(w, apiutil.ErrRedirectFailed, http.StatusInternalServerError)
return
}
clientUrls = append(clientUrls, targetAddr)
failpoint.Inject("checkHeader", func() {
// add a header to the response, this is not a failure injection
// it is used for testing, to check whether the request is forwarded to the micro service
w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true")
})
} else {
leader := h.s.GetMember().GetLeader()
if leader == nil {
Expand Down
73 changes: 44 additions & 29 deletions pkg/utils/testutil/api_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,56 +23,71 @@ import (
"github.com/tikv/pd/pkg/utils/apiutil"
)

// Status is used to check whether http response code is equal given code
func Status(re *require.Assertions, code int) func([]byte, int) {
return func(resp []byte, i int) {
// Status is used to check whether http response code is equal given code.
func Status(re *require.Assertions, code int) func([]byte, int, http.Header) {
return func(resp []byte, i int, _ http.Header) {
re.Equal(code, i, "resp: "+string(resp))
}
}

// StatusOK is used to check whether http response code is equal http.StatusOK
func StatusOK(re *require.Assertions) func([]byte, int) {
// StatusOK is used to check whether http response code is equal http.StatusOK.
func StatusOK(re *require.Assertions) func([]byte, int, http.Header) {
return Status(re, http.StatusOK)
}

// StatusNotOK is used to check whether http response code is not equal http.StatusOK
func StatusNotOK(re *require.Assertions) func([]byte, int) {
return func(_ []byte, i int) {
// StatusNotOK is used to check whether http response code is not equal http.StatusOK.
func StatusNotOK(re *require.Assertions) func([]byte, int, http.Header) {
return func(_ []byte, i int, _ http.Header) {
re.NotEqual(http.StatusOK, i)
}
}

// ExtractJSON is used to check whether given data can be extracted successfully
func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int) {
return func(res []byte, _ int) {
// ExtractJSON is used to check whether given data can be extracted successfully.
func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.NoError(json.Unmarshal(res, data))
}
}

// StringContain is used to check whether response context contains given string
func StringContain(re *require.Assertions, sub string) func([]byte, int) {
return func(res []byte, _ int) {
// StringContain is used to check whether response context contains given string.
func StringContain(re *require.Assertions, sub string) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.Contains(string(res), sub)
}
}

// StringEqual is used to check whether response context equal given string
func StringEqual(re *require.Assertions, str string) func([]byte, int) {
return func(res []byte, _ int) {
// StringEqual is used to check whether response context equal given string.
func StringEqual(re *require.Assertions, str string) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.Contains(string(res), str)
}
}

// ReadGetJSON is used to do get request and check whether given data can be extracted successfully
func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}) error {
// WithHeader is used to check whether response header contains given key and value.
func WithHeader(re *require.Assertions, key, value string) func([]byte, int, http.Header) {
return func(_ []byte, _ int, header http.Header) {
re.Equal(value, header.Get(key))
}
}

// WithoutHeader is used to check whether response header does not contain given key.
func WithoutHeader(re *require.Assertions, key string) func([]byte, int, http.Header) {
return func(_ []byte, _ int, header http.Header) {
re.Empty(header.Get(key))
}
}

// ReadGetJSON is used to do get request and check whether given data can be extracted successfully.
func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.GetJSON(client, url, nil)
if err != nil {
return err
}
return checkResp(resp, StatusOK(re), ExtractJSON(re, data))
checkOpts = append(checkOpts, StatusOK(re), ExtractJSON(re, data))
return checkResp(resp, checkOpts...)
}

// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully
// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully.
func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}) error {
resp, err := apiutil.GetJSON(client, url, input)
if err != nil {
Expand All @@ -81,41 +96,41 @@ func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string
return checkResp(resp, StatusOK(re), ExtractJSON(re, data))
}

// CheckPostJSON is used to do post request and do check options
func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckPostJSON is used to do post request and do check options.
func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.PostJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

// CheckGetJSON is used to do get request and do check options
func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckGetJSON is used to do get request and do check options.
func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.GetJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

// CheckPatchJSON is used to do patch request and do check options
func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckPatchJSON is used to do patch request and do check options.
func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.PatchJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

func checkResp(resp *http.Response, checkOpts ...func([]byte, int)) error {
func checkResp(resp *http.Response, checkOpts ...func([]byte, int, http.Header)) error {
res, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return err
}
for _, opt := range checkOpts {
opt(res, resp.StatusCode)
opt(res, resp.StatusCode, resp.Header)
}
return nil
}
5 changes: 3 additions & 2 deletions server/api/hot_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package api
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -92,7 +93,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsTimeRange() {
StartTime: now.UnixNano() / int64(time.Millisecond),
EndTime: now.Add(10*time.Second).UnixNano() / int64(time.Millisecond),
}
check := func(res []byte, statusCode int) {
check := func(res []byte, statusCode int, _ http.Header) {
suite.Equal(200, statusCode)
historyHotRegions := &storage.HistoryHotRegions{}
json.Unmarshal(res, historyHotRegions)
Expand Down Expand Up @@ -177,7 +178,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsIDAndTypes() {
IsLearners: []bool{false},
EndTime: now.Add(10*time.Minute).UnixNano() / int64(time.Millisecond),
}
check := func(res []byte, statusCode int) {
check := func(res []byte, statusCode int, _ http.Header) {
suite.Equal(200, statusCode)
historyHotRegions := &storage.HistoryHotRegions{}
json.Unmarshal(res, historyHotRegions)
Expand Down
2 changes: 1 addition & 1 deletion server/api/region_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ func (suite *regionTestSuite) TestSplitRegions() {
hex.EncodeToString([]byte("bbb")),
hex.EncodeToString([]byte("ccc")),
hex.EncodeToString([]byte("ddd")))
checkOpt := func(res []byte, code int) {
checkOpt := func(res []byte, code int, _ http.Header) {
s := &struct {
ProcessedPercentage int `json:"processed-percentage"`
NewRegionsID []uint64 `json:"regions-id"`
Expand Down
2 changes: 1 addition & 1 deletion server/api/rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ func (suite *ruleTestSuite) TestBundleBadRequest() {
}
for _, testCase := range testCases {
err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+testCase.uri, []byte(testCase.data),
func(_ []byte, code int) {
func(_ []byte, code int, _ http.Header) {
suite.Equal(testCase.ok, code == http.StatusOK)
})
suite.NoError(err)
Expand Down
34 changes: 29 additions & 5 deletions server/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"net/http"

"github.com/gorilla/mux"
scheapi "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1"
tsoapi "github.com/tikv/pd/pkg/mcs/tso/server/apis/v1"
mcs "github.com/tikv/pd/pkg/mcs/utils"
"github.com/tikv/pd/pkg/utils/apiutil"
Expand All @@ -35,14 +36,37 @@ func NewHandler(_ context.Context, svr *server.Server) (http.Handler, apiutil.AP
Name: "core",
IsCore: true,
}
router := mux.NewRouter()
prefix := apiPrefix + "/api/v1"
r := createRouter(apiPrefix, svr)
router := mux.NewRouter()
router.PathPrefix(apiPrefix).Handler(negroni.New(
serverapi.NewRuntimeServiceValidator(svr, group),
serverapi.NewRedirector(svr, serverapi.MicroserviceRedirectRule(
apiPrefix+"/api/v1"+"/admin/reset-ts",
tsoapi.APIPathPrefix+"/admin/reset-ts",
mcs.TSOServiceName)),
serverapi.NewRedirector(svr,
serverapi.MicroserviceRedirectRule(
prefix+"/admin/reset-ts",
tsoapi.APIPathPrefix+"/admin/reset-ts",
mcs.TSOServiceName,
[]string{http.MethodPost}),
serverapi.MicroserviceRedirectRule(
prefix+"/operators",
scheapi.APIPathPrefix+"/operators",
mcs.SchedulingServiceName,
[]string{http.MethodPost, http.MethodGet, http.MethodDelete}),
// because the writing of all the meta information of the scheduling service is in the API server,
// we only forward read-only requests about checkers and schedulers to the scheduling service.
serverapi.MicroserviceRedirectRule(
prefix+"/checker", // Note: this is a typo in the original code
scheapi.APIPathPrefix+"/checkers",
mcs.SchedulingServiceName,
[]string{http.MethodGet}),
serverapi.MicroserviceRedirectRule(
prefix+"/schedulers",
scheapi.APIPathPrefix+"/schedulers",
mcs.SchedulingServiceName,
[]string{http.MethodGet}),
// TODO: we need to consider the case that v1 api not support restful api.
// we might change the previous path parameters to query parameters.
),
negroni.Wrap(r)),
)

Expand Down
Loading

0 comments on commit 0888ef6

Please sign in to comment.