diff --git a/benchmark/worker/benchmark_client.go b/benchmark/worker/benchmark_client.go index c28312dd6aab..af81445ec5bf 100644 --- a/benchmark/worker/benchmark_client.go +++ b/benchmark/worker/benchmark_client.go @@ -73,7 +73,6 @@ func (h *lockingHistogram) mergeInto(merged *stats.Histogram) { type benchmarkClient struct { closeConns func() - stop chan bool lastResetTime time.Time histogramOptions stats.HistogramOptions lockingHistograms []lockingHistogram @@ -168,7 +167,7 @@ func createConns(config *testpb.ClientConfig) ([]*grpc.ClientConn, func(), error }, nil } -func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benchmarkClient) error { +func performRPCs(ctx context.Context, config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benchmarkClient) error { // Read payload size and type from config. var ( payloadReqSize, payloadRespSize int @@ -212,9 +211,9 @@ func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benc switch config.RpcType { case testpb.RpcType_UNARY: - bc.unaryLoop(conns, rpcCountPerConn, payloadReqSize, payloadRespSize, poissonLambda) + bc.unaryLoop(ctx, conns, rpcCountPerConn, payloadReqSize, payloadRespSize, poissonLambda) case testpb.RpcType_STREAMING: - bc.streamingLoop(conns, rpcCountPerConn, payloadReqSize, payloadRespSize, payloadType, poissonLambda) + bc.streamingLoop(ctx, conns, rpcCountPerConn, payloadReqSize, payloadRespSize, payloadType, poissonLambda) default: return status.Errorf(codes.InvalidArgument, "unknown rpc type: %v", config.RpcType) } @@ -222,7 +221,7 @@ func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benc return nil } -func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) { +func startBenchmarkClient(ctx context.Context, config *testpb.ClientConfig) (*benchmarkClient, error) { printClientConfig(config) // Set running environment like how many cores to use. @@ -243,13 +242,12 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) }, lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns)), - stop: make(chan bool), lastResetTime: time.Now(), closeConns: closeConns, rusageLastReset: syscall.GetRusage(), } - if err = performRPCs(config, conns, bc); err != nil { + if err = performRPCs(ctx, config, conns, bc); err != nil { // Close all connections if performRPCs failed. closeConns() return nil, err @@ -258,7 +256,7 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) return bc, nil } -func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, poissonLambda *float64) { +func (bc *benchmarkClient) unaryLoop(ctx context.Context, conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, poissonLambda *float64) { for ic, conn := range conns { client := testgrpc.NewBenchmarkServiceClient(conn) // For each connection, create rpcCountPerConn goroutines to do rpc. @@ -274,10 +272,8 @@ func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn i // before starting benchmark. if poissonLambda == nil { // Closed loop. for { - select { - case <-bc.stop: - return - default: + if ctx.Err() != nil { + break } start := time.Now() if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil { @@ -292,13 +288,12 @@ func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn i bc.poissonUnary(client, idx, reqSize, respSize, *poissonLambda) }) } - }(idx) } } } -func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, payloadType string, poissonLambda *float64) { +func (bc *benchmarkClient) streamingLoop(ctx context.Context, conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, payloadType string, poissonLambda *float64) { var doRPC func(testgrpc.BenchmarkService_StreamingCallClient, int, int) error if payloadType == "bytebuf" { doRPC = benchmark.DoByteBufStreamingRoundTrip @@ -329,10 +324,8 @@ func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerCo } elapse := time.Since(start) bc.lockingHistograms[idx].add(int64(elapse)) - select { - case <-bc.stop: + if ctx.Err() != nil { return - default: } } }(idx) @@ -364,6 +357,7 @@ func (bc *benchmarkClient) poissonUnary(client testgrpc.BenchmarkServiceClient, func (bc *benchmarkClient) poissonStreaming(stream testgrpc.BenchmarkService_StreamingCallClient, idx int, reqSize int, respSize int, lambda float64, doRPC func(testgrpc.BenchmarkService_StreamingCallClient, int, int) error) { go func() { start := time.Now() + if err := doRPC(stream, reqSize, respSize); err != nil { return } @@ -430,6 +424,5 @@ func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats { } func (bc *benchmarkClient) shutdown() { - close(bc.stop) bc.closeConns() } diff --git a/benchmark/worker/main.go b/benchmark/worker/main.go index 45893d7b15a2..0a6dd965e177 100644 --- a/benchmark/worker/main.go +++ b/benchmark/worker/main.go @@ -139,7 +139,9 @@ func (s *workerServer) RunServer(stream testgrpc.WorkerService_RunServerServer) func (s *workerServer) RunClient(stream testgrpc.WorkerService_RunClientServer) error { var bc *benchmarkClient + ctx, cancel := context.WithCancel(stream.Context()) defer func() { + cancel() // Shut down benchmark client when stream ends. logger.Infof("shutting down benchmark client") if bc != nil { @@ -163,7 +165,7 @@ func (s *workerServer) RunClient(stream testgrpc.WorkerService_RunClientServer) logger.Infof("client setup received when client already exists, shutting down the existing client") bc.shutdown() } - bc, err = startBenchmarkClient(t.Setup) + bc, err = startBenchmarkClient(ctx, t.Setup) if err != nil { return err }