Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
Signed-off-by: lhy1024 <[email protected]>
  • Loading branch information
lhy1024 committed Sep 15, 2023
1 parent 590d971 commit 2193aff
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 39 deletions.
2 changes: 2 additions & 0 deletions pkg/utils/apiutil/apiutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ const (
XForwardedPortHeader = "X-Forwarded-Port"
// XRealIPHeader is used to mark the real client IP.
XRealIPHeader = "X-Real-Ip"
// ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service.
ForwardToMicroServiceHeader = "Forward-To-Micro-Service"

// ErrRedirectFailed is the error message for redirect failed.
ErrRedirectFailed = "redirect failed"
Expand Down
1 change: 1 addition & 0 deletions pkg/utils/apiutil/serverapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http
return
}
clientUrls = append(clientUrls, targetAddr)
w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true")
} else {
leader := h.s.GetMember().GetLeader()
if leader == nil {
Expand Down
73 changes: 44 additions & 29 deletions pkg/utils/testutil/api_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,56 +23,71 @@ import (
"github.com/tikv/pd/pkg/utils/apiutil"
)

// Status is used to check whether http response code is equal given code
func Status(re *require.Assertions, code int) func([]byte, int) {
return func(resp []byte, i int) {
// Status is used to check whether http response code is equal given code.
func Status(re *require.Assertions, code int) func([]byte, int, http.Header) {
return func(resp []byte, i int, _ http.Header) {
re.Equal(code, i, "resp: "+string(resp))
}
}

// StatusOK is used to check whether http response code is equal http.StatusOK
func StatusOK(re *require.Assertions) func([]byte, int) {
// StatusOK is used to check whether http response code is equal http.StatusOK.
func StatusOK(re *require.Assertions) func([]byte, int, http.Header) {
return Status(re, http.StatusOK)
}

// StatusNotOK is used to check whether http response code is not equal http.StatusOK
func StatusNotOK(re *require.Assertions) func([]byte, int) {
return func(_ []byte, i int) {
// StatusNotOK is used to check whether http response code is not equal http.StatusOK.
func StatusNotOK(re *require.Assertions) func([]byte, int, http.Header) {
return func(_ []byte, i int, _ http.Header) {
re.NotEqual(http.StatusOK, i)
}
}

// ExtractJSON is used to check whether given data can be extracted successfully
func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int) {
return func(res []byte, _ int) {
// ExtractJSON is used to check whether given data can be extracted successfully.
func ExtractJSON(re *require.Assertions, data interface{}) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.NoError(json.Unmarshal(res, data))
}
}

// StringContain is used to check whether response context contains given string
func StringContain(re *require.Assertions, sub string) func([]byte, int) {
return func(res []byte, _ int) {
// StringContain is used to check whether response context contains given string.
func StringContain(re *require.Assertions, sub string) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.Contains(string(res), sub)
}
}

// StringEqual is used to check whether response context equal given string
func StringEqual(re *require.Assertions, str string) func([]byte, int) {
return func(res []byte, _ int) {
// StringEqual is used to check whether response context equal given string.
func StringEqual(re *require.Assertions, str string) func([]byte, int, http.Header) {
return func(res []byte, _ int, _ http.Header) {
re.Contains(string(res), str)
}
}

// ReadGetJSON is used to do get request and check whether given data can be extracted successfully
func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}) error {
// WithHeader is used to check whether response header contains given key and value.
func WithHeader(re *require.Assertions, key, value string) func([]byte, int, http.Header) {
return func(_ []byte, _ int, header http.Header) {
re.Equal(value, header.Get(key))
}
}

// WithoutHeader is used to check whether response header does not contain given key.
func WithoutHeader(re *require.Assertions, key string) func([]byte, int, http.Header) {
return func(_ []byte, _ int, header http.Header) {
re.Empty(header.Get(key))
}
}

// ReadGetJSON is used to do get request and check whether given data can be extracted successfully.
func ReadGetJSON(re *require.Assertions, client *http.Client, url string, data interface{}, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.GetJSON(client, url, nil)
if err != nil {
return err
}
return checkResp(resp, StatusOK(re), ExtractJSON(re, data))
checkOpts = append(checkOpts, StatusOK(re), ExtractJSON(re, data))
return checkResp(resp, checkOpts...)
}

// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully
// ReadGetJSONWithBody is used to do get request with input and check whether given data can be extracted successfully.
func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string, input []byte, data interface{}) error {
resp, err := apiutil.GetJSON(client, url, input)
if err != nil {
Expand All @@ -81,41 +96,41 @@ func ReadGetJSONWithBody(re *require.Assertions, client *http.Client, url string
return checkResp(resp, StatusOK(re), ExtractJSON(re, data))
}

// CheckPostJSON is used to do post request and do check options
func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckPostJSON is used to do post request and do check options.
func CheckPostJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.PostJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

// CheckGetJSON is used to do get request and do check options
func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckGetJSON is used to do get request and do check options.
func CheckGetJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.GetJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

// CheckPatchJSON is used to do patch request and do check options
func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int)) error {
// CheckPatchJSON is used to do patch request and do check options.
func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...func([]byte, int, http.Header)) error {
resp, err := apiutil.PatchJSON(client, url, data)
if err != nil {
return err
}
return checkResp(resp, checkOpts...)
}

func checkResp(resp *http.Response, checkOpts ...func([]byte, int)) error {
func checkResp(resp *http.Response, checkOpts ...func([]byte, int, http.Header)) error {
res, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return err
}
for _, opt := range checkOpts {
opt(res, resp.StatusCode)
opt(res, resp.StatusCode, resp.Header)
}
return nil
}
5 changes: 3 additions & 2 deletions server/api/hot_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package api
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -92,7 +93,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsTimeRange() {
StartTime: now.UnixNano() / int64(time.Millisecond),
EndTime: now.Add(10*time.Second).UnixNano() / int64(time.Millisecond),
}
check := func(res []byte, statusCode int) {
check := func(res []byte, statusCode int, _ http.Header) {
suite.Equal(200, statusCode)
historyHotRegions := &storage.HistoryHotRegions{}
json.Unmarshal(res, historyHotRegions)
Expand Down Expand Up @@ -177,7 +178,7 @@ func (suite *hotStatusTestSuite) TestGetHistoryHotRegionsIDAndTypes() {
IsLearners: []bool{false},
EndTime: now.Add(10*time.Minute).UnixNano() / int64(time.Millisecond),
}
check := func(res []byte, statusCode int) {
check := func(res []byte, statusCode int, _ http.Header) {
suite.Equal(200, statusCode)
historyHotRegions := &storage.HistoryHotRegions{}
json.Unmarshal(res, historyHotRegions)
Expand Down
2 changes: 1 addition & 1 deletion server/api/region_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ func (suite *regionTestSuite) TestSplitRegions() {
hex.EncodeToString([]byte("bbb")),
hex.EncodeToString([]byte("ccc")),
hex.EncodeToString([]byte("ddd")))
checkOpt := func(res []byte, code int) {
checkOpt := func(res []byte, code int, _ http.Header) {
s := &struct {
ProcessedPercentage int `json:"processed-percentage"`
NewRegionsID []uint64 `json:"regions-id"`
Expand Down
2 changes: 1 addition & 1 deletion server/api/rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ func (suite *ruleTestSuite) TestBundleBadRequest() {
}
for _, testCase := range testCases {
err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+testCase.uri, []byte(testCase.data),
func(_ []byte, code int) {
func(_ []byte, code int, _ http.Header) {
suite.Equal(testCase.ok, code == http.StatusOK)
})
suite.NoError(err)
Expand Down
33 changes: 27 additions & 6 deletions tests/integrations/mcs/scheduling/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package scheduling_test

import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/suite"
_ "github.com/tikv/pd/pkg/mcs/scheduling/server/apis/v1"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/tempurl"
"github.com/tikv/pd/pkg/utils/testutil"
"github.com/tikv/pd/tests"
Expand Down Expand Up @@ -119,21 +121,40 @@ func (suite *apiTestSuite) TestAPIForward() {
var resp map[string]interface{}

// Test opeartor
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice)
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice,
testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
re.NoError(err)
re.Len(slice, 0)

err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp)
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), &resp,
testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
re.NoError(err)
re.Nil(resp)

// Test checker
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp)
// Test checker: only read-only requests are forwarded
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp,
testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
re.NoError(err)
suite.False(resp["paused"].(bool))

// Test scheduler
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice)
input := make(map[string]interface{})
input["delay"] = 10
pauseArgs, err := json.Marshal(input)
suite.NoError(err)
err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), pauseArgs,
testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.PDRedirectorHeader))
suite.NoError(err)

// Test scheduler: only read-only requests are forwarded
err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &slice,
testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"))
re.NoError(err)
re.Contains(slice, "balance-leader-scheduler")

input["delay"] = 30
pauseArgs, err = json.Marshal(input)
suite.NoError(err)
err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/all"), pauseArgs,
testutil.StatusOK(re), testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader))
suite.NoError(err)
}

0 comments on commit 2193aff

Please sign in to comment.