From dbf8a35a8e70eb12bc4345680ca59ae3a7d1f8f5 Mon Sep 17 00:00:00 2001 From: Cabinfever_B Date: Fri, 10 Nov 2023 09:31:02 +0800 Subject: [PATCH] improve api bench Signed-off-by: Cabinfever_B --- client/client.go | 3 + tools/pd-api-bench/cases/cases.go | 52 ++++++++++++- tools/pd-api-bench/go.mod | 3 +- tools/pd-api-bench/main.go | 122 +++++++++++++++++++++++++++--- 4 files changed, 166 insertions(+), 14 deletions(-) diff --git a/client/client.go b/client/client.go index 067872d2d39..65d0be8b864 100644 --- a/client/client.go +++ b/client/client.go @@ -146,6 +146,9 @@ type Client interface { // SetExternalTimestamp sets external timestamp SetExternalTimestamp(ctx context.Context, timestamp uint64) error + // GetServiceDiscovery returns ServiceDiscovery + GetServiceDiscovery() ServiceDiscovery + // TSOClient is the TSO client. TSOClient // MetaStorageClient is the meta storage client. diff --git a/tools/pd-api-bench/cases/cases.go b/tools/pd-api-bench/cases/cases.go index d431b6f325c..04151016bef 100644 --- a/tools/pd-api-bench/cases/cases.go +++ b/tools/pd-api-bench/cases/cases.go @@ -22,6 +22,8 @@ import ( "net/http" "net/url" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pkg/errors" pd "github.com/tikv/pd/client" "github.com/tikv/pd/pkg/statistics" "github.com/tikv/pd/pkg/utils/apiutil" @@ -33,6 +35,8 @@ var ( PDAddress string // Debug is the flag to print the output of api response for debug. Debug bool + // ClusterID is the ID of cluster. + ClusterID uint64 ) var ( @@ -43,6 +47,7 @@ var ( // InitCluster initializes the cluster. func InitCluster(ctx context.Context, cli pd.Client, httpClit *http.Client) error { + ClusterID = cli.GetClusterID(ctx) req, _ := http.NewRequestWithContext(ctx, http.MethodGet, PDAddress+"/pd/api/v1/stats/region?start_key=&end_key=&count", nil) resp, err := httpClit.Do(req) @@ -113,10 +118,11 @@ type GRPCCase interface { // GRPCCaseMap is the map for all gRPC cases. var GRPCCaseMap = map[string]GRPCCase{ - "GetRegion": newGetRegion(), - "GetStore": newGetStore(), - "GetStores": newGetStores(), - "ScanRegions": newScanRegions(), + "StoreHeartbeat": newStoreHeartbeat(), + "GetRegion": newGetRegion(), + "GetStore": newGetStore(), + "GetStores": newGetStores(), + "ScanRegions": newScanRegions(), } // HTTPCase is the interface for all HTTP cases. @@ -254,6 +260,44 @@ func (c *getRegion) Unary(ctx context.Context, cli pd.Client) error { return nil } +type storeHeartbeat struct { + *baseCase +} + +func newStoreHeartbeat() *storeHeartbeat { + return &storeHeartbeat{ + baseCase: &baseCase{ + name: "StoreHeartbeat", + qps: 10000, + burst: 1, + }, + } +} + +func (c *storeHeartbeat) Unary(ctx context.Context, cli pd.Client) error { + sd := cli.GetServiceDiscovery() + conn := sd.GetServingEndpointClientConn() + client := pdpb.NewPDClient(conn) + req := &pdpb.StoreHeartbeatRequest{ + Header: &pdpb.RequestHeader{ + ClusterId: ClusterID, + }, + Stats: &pdpb.StoreStats{ + StoreId: 1, + }, + } + resp, err := client.StoreHeartbeat(ctx, req) + if err != nil { + return err + } + if resp != nil { + if resp.Header.Error != nil { + return errors.Errorf(resp.Header.Error.Message) + } + } + return nil +} + type scanRegions struct { *baseCase regionSample int diff --git a/tools/pd-api-bench/go.mod b/tools/pd-api-bench/go.mod index e6e896a0797..b22fd44a292 100644 --- a/tools/pd-api-bench/go.mod +++ b/tools/pd-api-bench/go.mod @@ -3,6 +3,8 @@ module github.com/tools/pd-api-bench go 1.21 require ( + github.com/pingcap/kvproto v0.0.0-20231018065736-c0689aded40c + github.com/pkg/errors v0.9.1 github.com/tikv/pd v0.0.0-00010101000000-000000000000 github.com/tikv/pd/client v0.0.0-00010101000000-000000000000 go.uber.org/zap v1.24.0 @@ -69,7 +71,6 @@ require ( github.com/pingcap/errcode v0.3.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 // indirect - github.com/pingcap/kvproto v0.0.0-20231018065736-c0689aded40c // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.11.1 // indirect diff --git a/tools/pd-api-bench/main.go b/tools/pd-api-bench/main.go index 722c7eb08fa..25147588578 100644 --- a/tools/pd-api-bench/main.go +++ b/tools/pd-api-bench/main.go @@ -25,6 +25,7 @@ import ( "os/signal" "strconv" "strings" + "sync" "syscall" "time" @@ -48,6 +49,8 @@ var ( qps = flag.Int64("qps", 1000, "qps") burst = flag.Int64("burst", 1, "burst") + wait = flag.Bool("wait", true, "wait for a round") + // http params httpParams = flag.String("params", "", "http params") @@ -141,11 +144,52 @@ func main() { } gcaseStr := strings.Split(*gRPCCases, ",") for _, str := range gcaseStr { - if len(str) == 0 { + caseQPS := int64(0) + caseBurst := int64(0) + cStr := "" + + strs := strings.Split(str, "-") + fmt.Println(strs) + // to get case name + strsa := strings.Split(strs[0], "+") + cStr = strsa[0] + // to get case Burst + if len(strsa) > 1 { + caseBurst, err = strconv.ParseInt(strsa[1], 10, 64) + if err != nil { + log.Printf("parse burst failed for case %s", str) + } + } + // to get case qps + if len(strs) > 1 { + strsb := strings.Split(strs[1], "+") + caseQPS, err = strconv.ParseInt(strsb[0], 10, 64) + if err != nil { + log.Printf("parse qps failed for case %s", str) + } + // to get case Burst + if len(strsb) > 1 { + caseBurst, err = strconv.ParseInt(strsb[1], 10, 64) + if err != nil { + log.Printf("parse burst failed for case %s", str) + } + } + } + if len(cStr) == 0 { continue } - if cas, ok := cases.GRPCCaseMap[str]; ok { + if cas, ok := cases.GRPCCaseMap[cStr]; ok { gcases = append(gcases, cas) + if caseBurst > 0 { + cas.SetBurst(caseBurst) + } else if *burst > 0 { + cas.SetBurst(*burst) + } + if caseQPS > 0 { + cas.SetQPS(caseQPS) + } else if *qps > 0 { + cas.SetQPS(*qps) + } } else { log.Println("no this case", str) } @@ -188,6 +232,9 @@ func main() { } func handleGRPCCase(ctx context.Context, gcase cases.GRPCCase, clients []pd.Client) { + startCnt := 0 + endCnt := 0 + var cntMu sync.Mutex qps := gcase.GetQPS() burst := gcase.GetBurst() tt := time.Duration(base/qps*burst*int64(*client)) * time.Microsecond @@ -200,9 +247,36 @@ func handleGRPCCase(ctx context.Context, gcase cases.GRPCCase, clients []pd.Clie select { case <-ticker.C: for i := int64(0); i < burst; i++ { - err := gcase.Unary(ctx, cli) - if err != nil { - log.Println(err) + cntMu.Lock() + startCnt++ + if startCnt%1000 == 0 { + log.Printf("case grpc %s has sent query %d", gcase.Name(), startCnt) + } + cntMu.Unlock() + if *wait { + err := gcase.Unary(ctx, cli) + if err != nil { + log.Println(err) + } + cntMu.Lock() + endCnt++ + if endCnt%1000 == 0 { + log.Printf("case grpc %s has finished query %d", gcase.Name(), endCnt) + } + cntMu.Unlock() + } else { + go func() { + err := gcase.Unary(ctx, cli) + if err != nil { + log.Println(err) + } + cntMu.Lock() + endCnt++ + if endCnt%1000 == 0 { + log.Printf("case grpc %s has finished query %d", gcase.Name(), endCnt) + } + cntMu.Unlock() + }() } } case <-ctx.Done(): @@ -215,6 +289,9 @@ func handleGRPCCase(ctx context.Context, gcase cases.GRPCCase, clients []pd.Clie } func handleHTTPCase(ctx context.Context, hcase cases.HTTPCase, httpClis []*http.Client) { + startCnt := 0 + endCnt := 0 + var cntMu sync.Mutex qps := hcase.GetQPS() burst := hcase.GetBurst() tt := time.Duration(base/qps*burst*int64(*client)) * time.Microsecond @@ -230,9 +307,36 @@ func handleHTTPCase(ctx context.Context, hcase cases.HTTPCase, httpClis []*http. select { case <-ticker.C: for i := int64(0); i < burst; i++ { - err := hcase.Do(ctx, hCli) - if err != nil { - log.Println(err) + cntMu.Lock() + startCnt++ + if startCnt%1000 == 0 { + log.Printf("case http %s has done query %d", hcase.Name(), startCnt) + } + cntMu.Unlock() + if *wait { + err := hcase.Do(ctx, hCli) + if err != nil { + log.Println(err) + } + cntMu.Lock() + endCnt++ + if endCnt%1000 == 0 { + log.Printf("case http %s has finished query %d", hcase.Name(), endCnt) + } + cntMu.Unlock() + } else { + go func() { + err := hcase.Do(ctx, hCli) + if err != nil { + log.Println(err) + } + cntMu.Lock() + endCnt++ + if endCnt%1000 == 0 { + log.Printf("case http %s has finished query %d", hcase.Name(), endCnt) + } + cntMu.Unlock() + }() } } case <-ctx.Done(): @@ -252,7 +356,7 @@ func exit(code int) { func newHTTPClient() *http.Client { // defaultTimeout for non-context requests. const defaultTimeout = 30 * time.Second - cli := &http.Client{Timeout: defaultTimeout} + cli := &http.Client{Timeout: defaultTimeout, Transport: http.DefaultTransport.(*http.Transport).Clone()} tlsConf := loadTLSConfig() if tlsConf != nil { transport := http.DefaultTransport.(*http.Transport).Clone()