From 140a7c2d50720d79c41921a1795519f069da3f87 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 5 Aug 2024 17:32:11 +0800 Subject: [PATCH 01/21] Add implementation of making tsoStream asynchronous Signed-off-by: MyonKeminta --- client/metrics.go | 24 ++-- client/tso_batch_controller.go | 83 +++++++++---- client/tso_dispatcher.go | 145 +++++++++++++++------- client/tso_stream.go | 218 +++++++++++++++++++++++++++++---- 4 files changed, 362 insertions(+), 108 deletions(-) diff --git a/client/metrics.go b/client/metrics.go index a11362669b3..a83b4a36407 100644 --- a/client/metrics.go +++ b/client/metrics.go @@ -39,13 +39,14 @@ func initAndRegisterMetrics(constLabels prometheus.Labels) { } var ( - cmdDuration *prometheus.HistogramVec - cmdFailedDuration *prometheus.HistogramVec - requestDuration *prometheus.HistogramVec - tsoBestBatchSize prometheus.Histogram - tsoBatchSize prometheus.Histogram - tsoBatchSendLatency prometheus.Histogram - requestForwarded *prometheus.GaugeVec + cmdDuration *prometheus.HistogramVec + cmdFailedDuration *prometheus.HistogramVec + requestDuration *prometheus.HistogramVec + tsoBestBatchSize prometheus.Histogram + tsoBatchSize prometheus.Histogram + tsoBatchSendLatency prometheus.Histogram + requestForwarded *prometheus.GaugeVec + ongoingRequestCountGauge *prometheus.GaugeVec ) func initMetrics(constLabels prometheus.Labels) { @@ -117,6 +118,15 @@ func initMetrics(constLabels prometheus.Labels) { Help: "The status to indicate if the request is forwarded", ConstLabels: constLabels, }, []string{"host", "delegate"}) + + ongoingRequestCountGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "pd_client", + Subsystem: "request", + Name: "ongoing_requests_count", + Help: "Current count of ongoing batch tso requests", + ConstLabels: constLabels, + }, []string{"stream"}) } var ( diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index a713b7a187d..bf072b24535 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -30,42 +30,79 @@ type tsoBatchController struct { // bestBatchSize is a dynamic size that changed based on the current batch effect. bestBatchSize int - tsoRequestCh chan *tsoRequest collectedRequests []*tsoRequest collectedRequestCount int - batchStartTime time.Time + // The time after getting the first request and the token, and before performing extra batching. + extraBatchingStartTime time.Time } -func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tsoBatchController { +func newTSOBatchController(maxBatchSize int) *tsoBatchController { return &tsoBatchController{ maxBatchSize: maxBatchSize, bestBatchSize: 8, /* Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4) */ - tsoRequestCh: tsoRequestCh, collectedRequests: make([]*tsoRequest, maxBatchSize+1), collectedRequestCount: 0, } } // fetchPendingRequests will start a new round of the batch collecting from the channel. -// It returns true if everything goes well, otherwise false which means we should stop the service. -func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error { - var firstRequest *tsoRequest - select { - case <-ctx.Done(): - return ctx.Err() - case firstRequest = <-tbc.tsoRequestCh: - } - // Start to batch when the first TSO request arrives. 
- tbc.batchStartTime = time.Now() +// It returns nil error if everything goes well, otherwise a non-nil error which means we should stop the service. +// It's guaranteed that if this function failed after collecting some requests, then these requests will be cancelled +// when the function returns, so the caller don't need to clear them manually. +func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequestCh <-chan *tsoRequest, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) { + var tokenAcquired bool + defer func() { + if errRet != nil { + // Something went wrong when collecting a batch of requests. Release the token and cancel collected requests + // if any. + if tokenAcquired { + tokenCh <- struct{}{} + } + if tbc.collectedRequestCount > 0 { + tbc.finishCollectedRequests(0, 0, 0, errRet) + } + } + }() + + // Wait until BOTH the first request and the token have arrived. + // TODO: `tbc.collectedRequestCount` should never be non-empty here. Consider do assertion here. tbc.collectedRequestCount = 0 - tbc.pushRequest(firstRequest) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-tsoRequestCh: + // Start to batch when the first TSO request arrives. + tbc.pushRequest(req) + // A request arrives but the token is not ready yet. Continue waiting, and also allowing collecting the next + // request if it arrives. + continue + case <-tokenCh: + tokenAcquired = true + } + + // The token is ready. If the first request didn't arrive, wait for it. + if tbc.collectedRequestCount == 0 { + select { + case <-ctx.Done(): + return ctx.Err() + case firstRequest := <-tsoRequestCh: + tbc.pushRequest(firstRequest) + } + } + + // Both token and the first request have arrived. + break + } + + tbc.extraBatchingStartTime = time.Now() // This loop is for trying best to collect more requests, so we use `tbc.maxBatchSize` here. fetchPendingRequestsLoop: for tbc.collectedRequestCount < tbc.maxBatchSize { select { - case tsoReq := <-tbc.tsoRequestCh: + case tsoReq := <-tsoRequestCh: tbc.pushRequest(tsoReq) case <-ctx.Done(): return ctx.Err() @@ -88,7 +125,7 @@ fetchPendingRequestsLoop: defer after.Stop() for tbc.collectedRequestCount < tbc.bestBatchSize { select { - case tsoReq := <-tbc.tsoRequestCh: + case tsoReq := <-tsoRequestCh: tbc.pushRequest(tsoReq) case <-ctx.Done(): return ctx.Err() @@ -103,7 +140,7 @@ fetchPendingRequestsLoop: // we can adjust the `tbc.bestBatchSize` dynamically later. 
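The collection loops above and below implement time-bounded batching: block for the first request, then opportunistically gather more until either the batch is full or the waiting budget expires. A condensed, hypothetical sketch of just that idea (it deliberately ignores the token handling and `bestBatchSize` adjustment that the real fetchPendingRequests performs):

package main

import (
	"fmt"
	"time"
)

// collectBatch blocks for the first item, then keeps collecting until the batch
// is full or the extra waiting budget runs out.
func collectBatch(in <-chan int, maxSize int, extraWait time.Duration) []int {
	batch := []int{<-in}
	timer := time.NewTimer(extraWait)
	defer timer.Stop()
	for len(batch) < maxSize {
		select {
		case v := <-in:
			batch = append(batch, v)
		case <-timer.C:
			return batch
		}
	}
	return batch
}

func main() {
	ch := make(chan int, 16)
	for i := 0; i < 3; i++ {
		ch <- i
	}
	fmt.Println(collectBatch(ch, 8, 10*time.Millisecond)) // [0 1 2]
}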
for tbc.collectedRequestCount < tbc.maxBatchSize { select { - case tsoReq := <-tbc.tsoRequestCh: + case tsoReq := <-tsoRequestCh: tbc.pushRequest(tsoReq) case <-ctx.Done(): return ctx.Err() @@ -149,18 +186,10 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in tbc.collectedRequestCount = 0 } -func (tbc *tsoBatchController) revokePendingRequests(err error) { - for i := 0; i < len(tbc.tsoRequestCh); i++ { - req := <-tbc.tsoRequestCh - req.tryDone(err) - } -} - func (tbc *tsoBatchController) clear() { log.Info("[pd] clear the tso batch controller", zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize), - zap.Int("collected-request-count", tbc.collectedRequestCount), zap.Int("pending-request-count", len(tbc.tsoRequestCh))) + zap.Int("collected-request-count", tbc.collectedRequestCount)) tsoErr := errors.WithStack(errClosing) tbc.finishCollectedRequests(0, 0, 0, tsoErr) - tbc.revokePendingRequests(tsoErr) } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index a7c99057275..2e670c656c9 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -76,10 +76,18 @@ type tsoDispatcher struct { provider tsoServiceProvider // URL -> *connectionContext - connectionCtxs *sync.Map - batchController *tsoBatchController - tsDeadlineCh chan *deadline - lastTSOInfo *tsoInfo + connectionCtxs *sync.Map + tsoRequestCh chan *tsoRequest + tsDeadlineCh chan *deadline + lastTSOInfo *tsoInfo + // For reusing tsoBatchController objects + batchBufferPool *sync.Pool + + // For controlling amount of concurrently processing RPC requests. + // A token must be acquired here before sending an RPC request, and the token must be put back after finishing the + // RPC. This is used like a semaphore, but we don't use semaphore directly here as it cannot be selected with + // other channels. 
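As the comment above notes, a buffered channel stands in for a semaphore precisely because acquiring it can appear in a select together with other channels. A minimal, self-contained sketch of that pattern (all names here are hypothetical, not from this patch):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	tokens := make(chan struct{}, 2) // capacity = max concurrent RPCs
	for i := 0; i < cap(tokens); i++ {
		tokens <- struct{}{} // fill the semaphore
	}

	select {
	case <-tokens: // acquire a token before sending an RPC...
		fmt.Println("token acquired, RPC may start")
		tokens <- struct{}{} // ...and put it back once the RPC finishes
	case <-ctx.Done(): // ...or give up, which a plain semaphore could not express in a select
		fmt.Println("canceled while waiting for a token")
	}
}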
+ tokenCh chan struct{} updateConnectionCtxsCh chan struct{} } @@ -91,24 +99,27 @@ func newTSODispatcher( provider tsoServiceProvider, ) *tsoDispatcher { dispatcherCtx, dispatcherCancel := context.WithCancel(ctx) - tsoBatchController := newTSOBatchController( - make(chan *tsoRequest, maxBatchSize*2), - maxBatchSize, - ) + tsoRequestCh := make(chan *tsoRequest, maxBatchSize*2) failpoint.Inject("shortDispatcherChannel", func() { - tsoBatchController = newTSOBatchController( - make(chan *tsoRequest, 1), - maxBatchSize, - ) + tsoRequestCh = make(chan *tsoRequest, 1) }) + + tokenCh := make(chan struct{}, 64) + td := &tsoDispatcher{ - ctx: dispatcherCtx, - cancel: dispatcherCancel, - dc: dc, - provider: provider, - connectionCtxs: &sync.Map{}, - batchController: tsoBatchController, - tsDeadlineCh: make(chan *deadline, 1), + ctx: dispatcherCtx, + cancel: dispatcherCancel, + dc: dc, + provider: provider, + connectionCtxs: &sync.Map{}, + tsoRequestCh: tsoRequestCh, + tsDeadlineCh: make(chan *deadline, 1), + batchBufferPool: &sync.Pool{ + New: func() any { + return newTSOBatchController(maxBatchSize * 2) + }, + }, + tokenCh: tokenCh, updateConnectionCtxsCh: make(chan struct{}, 1), } go td.watchTSDeadline() @@ -146,13 +157,21 @@ func (td *tsoDispatcher) scheduleUpdateConnectionCtxs() { } } +func (td *tsoDispatcher) revokePendingRequests(err error) { + for i := 0; i < len(td.tsoRequestCh); i++ { + req := <-td.tsoRequestCh + req.tryDone(err) + } +} + func (td *tsoDispatcher) close() { td.cancel() - td.batchController.clear() + tsoErr := errors.WithStack(errClosing) + td.revokePendingRequests(tsoErr) } func (td *tsoDispatcher) push(request *tsoRequest) { - td.batchController.tsoRequestCh <- request + td.tsoRequestCh <- request } func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { @@ -163,8 +182,12 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { svcDiscovery = provider.getServiceDiscovery() option = provider.getOption() connectionCtxs = td.connectionCtxs - batchController = td.batchController + batchController *tsoBatchController ) + + // Currently only 1 concurrency is supported. Put one token in. + td.tokenCh <- struct{}{} + log.Info("[tso] tso dispatcher created", zap.String("dc-location", dc)) // Clean up the connectionCtxs when the dispatcher exits. defer func() { @@ -174,8 +197,11 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { cc.(*tsoConnectionContext).cancel() return true }) - // Clear the tso batch controller. - batchController.clear() + if batchController.collectedRequestCount != 0 { + log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop") + } + tsoErr := errors.WithStack(errClosing) + td.revokePendingRequests(tsoErr) wg.Done() }() // Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event. @@ -203,9 +229,7 @@ tsoBatchLoop: maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval() // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, // otherwise the upper caller may get blocked on waiting for the results. - if err = batchController.fetchPendingRequests(ctx, maxBatchWaitInterval); err != nil { - // Finish the collected requests if the fetch failed. 
- batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) + if err = batchController.fetchPendingRequests(ctx, td.tsoRequestCh, td.tokenCh, maxBatchWaitInterval); err != nil { if err == context.Canceled { log.Info("[tso] stop fetching the pending tso requests due to context canceled", zap.String("dc-location", dc)) @@ -246,7 +270,7 @@ tsoBatchLoop: select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. - batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) timer.Stop() return case <-streamLoopTimer.C: @@ -254,7 +278,7 @@ tsoBatchLoop: log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err)) svcDiscovery.ScheduleCheckMemberChanged() // Finish the collected requests if the stream is failed to be created. - batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) + td.cancelCollectedRequests(batchController, errors.WithStack(err)) timer.Stop() continue tsoBatchLoop case <-timer.C: @@ -279,15 +303,22 @@ tsoBatchLoop: select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. - batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) return case td.tsDeadlineCh <- dl: } // processRequests guarantees that the collected requests could be finished properly. - err = td.processRequests(stream, dc, td.batchController) + err = td.processRequests(stream, dc, batchController) close(done) // If error happens during tso stream handling, reset stream and run the next trial. - if err != nil { + if err == nil { + // If the request is started successfully, the `batchController` will be put back to the pool when the + // request is finished (either successful or not). In this case, set the `batchController` to nil so that + // another one will be fetched from the pool. + // Otherwise, the `batchController` won't be processed in other goroutines concurrently, and it can be + // reused in the next loop safely. + batchController = nil + } else { select { case <-ctx.Done(): return @@ -422,28 +453,48 @@ func (td *tsoDispatcher) processRequests( keyspaceID = svcDiscovery.GetKeyspaceID() reqKeyspaceGroupID = svcDiscovery.GetKeyspaceGroupID() ) - respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests( + + cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + defer td.batchBufferPool.Put(tbc) + if err != nil { + td.cancelCollectedRequests(tbc, err) + return + } + + curTSOInfo := &tsoInfo{ + tsoServer: stream.getServerURL(), + reqKeyspaceGroupID: reqKeyspaceGroupID, + respKeyspaceGroupID: result.respKeyspaceGroupID, + respReceivedAt: time.Now(), + physical: result.physical, + logical: result.logical, + } + // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. 
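A tiny worked example of the subtraction described above and performed on the next line. The local addLogical mirrors what tsoutil.AddLogical is assumed to do (logical + delta<<suffixBits); treat it as an illustration rather than the library code. With suffixBits == 0, a batch of count == 3 whose largest logical part is 10 starts at firstLogical == 8, so the requests receive 8, 9 and 10 in order.

package main

import "fmt"

// addLogical is a stand-in for the assumed behavior of tsoutil.AddLogical.
func addLogical(logical, delta int64, suffixBits uint32) int64 {
	return logical + delta<<suffixBits
}

func main() {
	var (
		largestLogical int64  = 10 // logical part of the last timestamp in the batch
		count          int64  = 3  // number of timestamps returned for the batch
		suffixBits     uint32 = 0
	)
	firstLogical := addLogical(largestLogical, -count+1, suffixBits)
	fmt.Println(firstLogical) // 8, so the requests get logical parts 8, 9 and 10
}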
+ firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits) + td.compareAndSwapTS(curTSOInfo, firstLogical) + td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits) + } + + err := stream.processRequests( clusterID, keyspaceID, reqKeyspaceGroupID, - dcLocation, count, tbc.batchStartTime) + dcLocation, count, tbc.extraBatchingStartTime, cb) if err != nil { - tbc.finishCollectedRequests(0, 0, 0, err) + td.cancelCollectedRequests(tbc, err) return err } - curTSOInfo := &tsoInfo{ - tsoServer: stream.getServerURL(), - reqKeyspaceGroupID: reqKeyspaceGroupID, - respKeyspaceGroupID: respKeyspaceGroupID, - respReceivedAt: time.Now(), - physical: physical, - logical: logical, - } - // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. - firstLogical := tsoutil.AddLogical(logical, -count+1, suffixBits) - td.compareAndSwapTS(curTSOInfo, firstLogical) - tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) return nil } +func (td *tsoDispatcher) cancelCollectedRequests(tbc *tsoBatchController, err error) { + td.tokenCh <- struct{}{} + tbc.finishCollectedRequests(0, 0, 0, err) +} + +func (td *tsoDispatcher) doneCollectedRequests(tbc *tsoBatchController, physical int64, firstLogical int64, suffixBits uint32) { + td.tokenCh <- struct{}{} + tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) +} + func (td *tsoDispatcher) compareAndSwapTS( curTSOInfo *tsoInfo, firstLogical int64, ) { diff --git a/client/tso_stream.go b/client/tso_stream.go index da9cab95ba0..74fbc620cef 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -16,13 +16,19 @@ package pd import ( "context" + "fmt" "io" + "sync" + "sync/atomic" "time" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/kvproto/pkg/tsopb" + "github.com/pingcap/log" + "github.com/prometheus/client_golang/prometheus" "github.com/tikv/pd/client/errs" + "go.uber.org/zap" "google.golang.org/grpc" ) @@ -62,7 +68,7 @@ func (b *pdTSOStreamBuilder) build(ctx context.Context, cancel context.CancelFun stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return &tsoStream{stream: pdTSOStreamAdapter{stream}, serverURL: b.serverURL}, nil + return newTSOStream(b.serverURL, pdTSOStreamAdapter{stream}), nil } return nil, err } @@ -81,7 +87,7 @@ func (b *tsoTSOStreamBuilder) build( stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return &tsoStream{stream: tsoTSOStreamAdapter{stream}, serverURL: b.serverURL}, nil + return newTSOStream(b.serverURL, tsoTSOStreamAdapter{stream}), nil } return nil, err } @@ -172,51 +178,209 @@ func (s tsoTSOStreamAdapter) Recv() (tsoRequestResult, error) { }, nil } +type onFinishedCallback func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) + +type batchedRequests struct { + startTime time.Time + count int64 + reqKeyspaceGroupID uint32 + callback onFinishedCallback +} + +// tsoStream represents an abstracted stream for requesting TSO. +// This type designed decoupled with users of this type, so tsoDispatcher won't be directly accessed here. +// Also in order to avoid potential memory allocations that might happen when passing closures as the callback, +// we instead use the `batchedRequestsNotifier` as the abstraction, and accepts generic type instead of dynamic interface +// type. type tsoStream struct { serverURL string // The internal gRPC stream. 
// - `pdpb.PD_TsoClient` for a leader/follower in the PD cluster. // - `tsopb.TSO_TsoClient` for a primary/secondary in the TSO cluster. stream grpcTSOStreamAdapter + // An identifier of the tsoStream object for metrics reporting and diagnosing. + streamID string + + pendingRequests chan batchedRequests + + estimateLatencyMicros atomic.Uint64 + + cancel context.CancelFunc + wg sync.WaitGroup + + // For syncing between sender and receiver to guarantee all requests are finished when closing. + state atomic.Int32 + + ongoingRequestCountGauge prometheus.Gauge + ongoingRequests atomic.Int32 +} + +const ( + streamStateIdle int32 = iota + streamStateSending + streamStateClosing +) + +var streamIDAlloc atomic.Int32 + +// TODO: Pass a context? +func newTSOStream(serverURL string, stream grpcTSOStreamAdapter) *tsoStream { + streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1)) + ctx, cancel := context.WithCancel(context.Background()) + s := &tsoStream{ + serverURL: serverURL, + stream: stream, + streamID: streamID, + + pendingRequests: make(chan batchedRequests, 64), + + cancel: cancel, + + ongoingRequestCountGauge: ongoingRequestCountGauge.WithLabelValues(streamID), + } + s.wg.Add(1) + go s.recvLoop(ctx) + return s } func (s *tsoStream) getServerURL() string { return s.serverURL } +// processRequests starts an RPC to get a batch of timestamps without waiting for the result. When the result is ready, +// it will be passed th `notifier.finish`. +// +// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines. +// +// It's guaranteed that the `callback` will be called, but when the request is failed to be scheduled, the callback +// will be ignored. func (s *tsoStream) processRequests( - clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time, -) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) { + clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time, callback onFinishedCallback, +) error { start := time.Now() - if err = s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { + + // Check if the stream is closing or closed, in which case no more requests should be put in. 
+ // Note that the prevState should be restored very soon, as the receiver may check + prevState := s.state.Swap(streamStateSending) + switch prevState { + case streamStateIdle: + // Expected case + break + case streamStateClosing: + s.state.Store(prevState) + log.Info("tsoStream closed") + return errs.ErrClientTSOStreamClosed + case streamStateSending: + log.Fatal("unexpected concurrent sending on tsoStream", zap.String("stream", s.streamID)) + default: + log.Fatal("unknown tsoStream state", zap.String("stream", s.streamID), zap.Int32("state", prevState)) + } + + select { + case s.pendingRequests <- batchedRequests{ + startTime: start, + count: count, + reqKeyspaceGroupID: keyspaceGroupID, + callback: callback, + }: + default: + s.state.Store(prevState) + return errors.New("unexpected channel full") + } + s.state.Store(prevState) + + if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { if err == io.EOF { - err = errs.ErrClientTSOStreamClosed - } else { - err = errors.WithStack(err) + return errs.ErrClientTSOStreamClosed } - return + return errors.WithStack(err) } tsoBatchSendLatency.Observe(time.Since(batchStartTime).Seconds()) - res, err := s.stream.Recv() - duration := time.Since(start).Seconds() - if err != nil { - requestFailedDurationTSO.Observe(duration) - if err == io.EOF { - err = errs.ErrClientTSOStreamClosed - } else { - err = errors.WithStack(err) + s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(1))) + return nil +} + +func (s *tsoStream) recvLoop(ctx context.Context) { + var finishWithErr error + + defer func() { + s.cancel() + for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) { + switch state := s.state.Load(); state { + case streamStateIdle, streamStateSending: + // streamStateSending should switch to streamStateIdle very quickly. Spin until successfully setting to + // streamStateClosing. + continue + case streamStateClosing: + log.Warn("unexpected double closing of tsoStream", zap.String("stream", s.streamID)) + default: + log.Fatal("unknown tsoStream state", zap.String("stream", s.streamID), zap.Int32("state", state)) + } } - return - } - requestDurationTSO.Observe(duration) - tsoBatchSize.Observe(float64(count)) - if res.count != uint32(count) { - err = errors.WithStack(errTSOLength) - return + // The loop must end with an error (including context.Canceled). 
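Stepping back from the defer above: the idle/sending/closing states form a small lock-free handshake between the single sender and the receive loop. The sender briefly swaps in the sending state and always restores it; the closer spins on a compare-and-swap from idle to closing, after which no sender can slip in. A stripped-down, hypothetical sketch of the same handshake:

package main

import (
	"fmt"
	"sync/atomic"
)

const (
	stateIdle int32 = iota
	stateSending
	stateClosing
)

func main() {
	var state atomic.Int32

	// Sender side: mark the stream as sending, but back out if it is already closing.
	if prev := state.Swap(stateSending); prev == stateClosing {
		state.Store(prev)
		fmt.Println("stream closing, reject the request")
		return
	}
	// ... enqueue the request here ...
	state.Store(stateIdle) // restore quickly so the closer never spins for long

	// Closer side: spin until the sender has restored idle, then claim the closing state.
	for !state.CompareAndSwap(stateIdle, stateClosing) {
	}
	fmt.Println("closed; no sender can start after this point")
}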
+ if finishWithErr == nil { + log.Fatal("tsoStream recvLoop ended without error", zap.String("stream", s.streamID)) + } + log.Info("tsoStream.recvLoop ended", zap.String("stream", s.streamID), zap.Error(finishWithErr)) + + close(s.pendingRequests) + for req := range s.pendingRequests { + req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, s.serverURL, finishWithErr) + } + + s.wg.Done() + s.ongoingRequests.Store(0) + s.ongoingRequestCountGauge.Set(0) + }() + +recvLoop: + for { + select { + case <-ctx.Done(): + finishWithErr = context.Canceled + break recvLoop + default: + } + + res, err := s.stream.Recv() + + // Load the corresponding batchedRequests + var req batchedRequests + select { + case req = <-s.pendingRequests: + default: + finishWithErr = errors.New("tsoStream timing order broken") + break + } + + durationSeconds := time.Since(req.startTime).Seconds() + + if err != nil { + requestFailedDurationTSO.Observe(durationSeconds) + if err == io.EOF { + finishWithErr = errs.ErrClientTSOStreamClosed + } else { + finishWithErr = errors.WithStack(err) + } + break + } + + latencySeconds := durationSeconds + requestDurationTSO.Observe(latencySeconds) + tsoBatchSize.Observe(float64(res.count)) + + if res.count != uint32(req.count) { + finishWithErr = errors.WithStack(errTSOLength) + break + } + + req.callback(res, req.reqKeyspaceGroupID, s.serverURL, nil) + s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(-1))) } +} - respKeyspaceGroupID = res.respKeyspaceGroupID - physical, logical, suffixBits = res.physical, res.logical, res.suffixBits - return +func (s *tsoStream) Close() { + s.cancel() + s.wg.Wait() } From 88d9f5b0fe8d1777a69ef4fb0961b6d000c868b0 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 7 Aug 2024 17:34:29 +0800 Subject: [PATCH 02/21] Add basic tests for rewritten tsoStream Signed-off-by: MyonKeminta --- client/tso_stream.go | 61 ++++++--- client/tso_stream_test.go | 263 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 307 insertions(+), 17 deletions(-) create mode 100644 client/tso_stream_test.go diff --git a/client/tso_stream.go b/client/tso_stream.go index 74fbc620cef..4e18b5e31d7 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -302,8 +302,25 @@ func (s *tsoStream) processRequests( func (s *tsoStream) recvLoop(ctx context.Context) { var finishWithErr error + var currentReq batchedRequests + var hasReq bool defer func() { + if r := recover(); r != nil { + log.Fatal("tsoStream.recvLoop internal panic", zap.Stack("stacktrace"), zap.Any("panicMessage", r)) + } + + if finishWithErr == nil { + // The loop must exit with a non-nil error (including io.EOF and context.Canceled). This should be + // unreachable code. + log.Fatal("tsoStream.recvLoop exited without error info") + } + + if hasReq { + // There's an unfinished request, cancel it, otherwise it will be blocked forever. + currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, s.serverURL, finishWithErr) + } + s.cancel() for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) { switch state := s.state.Load(); state { @@ -318,13 +335,11 @@ func (s *tsoStream) recvLoop(ctx context.Context) { } } - // The loop must end with an error (including context.Canceled). - if finishWithErr == nil { - log.Fatal("tsoStream recvLoop ended without error", zap.String("stream", s.streamID)) - } log.Info("tsoStream.recvLoop ended", zap.String("stream", s.streamID), zap.Error(finishWithErr)) close(s.pendingRequests) + + // Cancel remaining pending requests. 
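The close-and-drain step below is a common Go idiom: once the state machine guarantees that no sender can enqueue anymore, closing the channel lets a range loop flush every queued request exactly once and fail it with the terminal error. A small, hypothetical sketch of the idiom:

package main

import (
	"errors"
	"fmt"
)

type pendingReq struct{ id int }

func main() {
	queue := make(chan pendingReq, 4)
	queue <- pendingReq{id: 1}
	queue <- pendingReq{id: 2}

	finalErr := errors.New("stream closed")
	close(queue) // safe only after every sender is guaranteed to have stopped
	for req := range queue {
		fmt.Printf("cancel request %d: %v\n", req.id, finalErr)
	}
}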
for req := range s.pendingRequests { req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, s.serverURL, finishWithErr) } @@ -345,42 +360,54 @@ recvLoop: res, err := s.stream.Recv() - // Load the corresponding batchedRequests - var req batchedRequests + // Try to load the corresponding `batchedRequests`. If `Recv` is successful, there must be a request pending + // in the queue. select { - case req = <-s.pendingRequests: + case currentReq = <-s.pendingRequests: + hasReq = true default: - finishWithErr = errors.New("tsoStream timing order broken") - break + hasReq = false } - durationSeconds := time.Since(req.startTime).Seconds() + durationSeconds := time.Since(currentReq.startTime).Seconds() if err != nil { - requestFailedDurationTSO.Observe(durationSeconds) + // If a request is pending and error occurs, observe the duration it has cost. + // Note that it's also possible that the stream is broken due to network without being requested. In this + // case, `Recv` may return an error while no request is pending. + if hasReq { + requestFailedDurationTSO.Observe(durationSeconds) + } if err == io.EOF { finishWithErr = errs.ErrClientTSOStreamClosed } else { finishWithErr = errors.WithStack(err) } - break + break recvLoop + } else if !hasReq { + finishWithErr = errors.New("tsoStream timing order broken") + break recvLoop } latencySeconds := durationSeconds requestDurationTSO.Observe(latencySeconds) tsoBatchSize.Observe(float64(res.count)) - if res.count != uint32(req.count) { + if res.count != uint32(currentReq.count) { finishWithErr = errors.WithStack(errTSOLength) - break + break recvLoop } - req.callback(res, req.reqKeyspaceGroupID, s.serverURL, nil) + currentReq.callback(res, currentReq.reqKeyspaceGroupID, s.serverURL, nil) + // After finishing the requests, unset these variables which will be checked in the defer block. + currentReq = batchedRequests{} + hasReq = false + s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(-1))) } } -func (s *tsoStream) Close() { - s.cancel() +// WaitForClosed blocks until the stream is closed and the inner loop exits. +func (s *tsoStream) WaitForClosed() { s.wg.Wait() } diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go new file mode 100644 index 00000000000..47af677ca1f --- /dev/null +++ b/client/tso_stream_test.go @@ -0,0 +1,263 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package pd + +import ( + "context" + "io" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/stretchr/testify/suite" + "github.com/tikv/pd/client/errs" +) + +type resultMsg struct { + r tsoRequestResult + err error + breakStream bool +} + +type mockTSOStreamImpl struct { + requestCh chan struct{} + resultCh chan resultMsg + keyspaceID uint32 +} + +func newMockTSOStreamImpl() *mockTSOStreamImpl { + return &mockTSOStreamImpl{ + requestCh: make(chan struct{}, 64), + resultCh: make(chan resultMsg, 64), + keyspaceID: 0, + } +} + +func (s *mockTSOStreamImpl) Send(clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64) error { + s.requestCh <- struct{}{} + return nil +} + +func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { + res := <-s.resultCh + if !res.breakStream { + <-s.requestCh + } + return res.r, res.err +} + +func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { + s.resultCh <- resultMsg{ + r: tsoRequestResult{ + physical: physical, + logical: logical, + count: count, + suffixBits: 0, + respKeyspaceGroupID: s.keyspaceID, + }, + } +} + +func (s *mockTSOStreamImpl) returnError(err error) { + s.resultCh <- resultMsg{ + err: err, + } +} + +func (s *mockTSOStreamImpl) breakStream(err error) { + s.resultCh <- resultMsg{ + err: err, + breakStream: true, + } +} + +func (s *mockTSOStreamImpl) stop() { + s.breakStream(io.EOF) +} + +type callbackInvocation struct { + result tsoRequestResult + streamURL string + err error +} + +type testTSOStreamSuite struct { + suite.Suite + + inner *mockTSOStreamImpl + stream *tsoStream +} + +func (s *testTSOStreamSuite) SetupTest() { + s.inner = newMockTSOStreamImpl() + s.stream = newTSOStream("mock:///", s.inner) +} + +func (s *testTSOStreamSuite) TearDownTest() { + s.inner.stop() + s.stream.WaitForClosed() + s.inner = nil + s.stream = nil +} + +func TestTSOStreamTestSuite(t *testing.T) { + suite.Run(t, new(testTSOStreamSuite)) +} + +func (s *testTSOStreamSuite) noResult(ch <-chan callbackInvocation) { + select { + case res := <-ch: + s.FailNowf("result received unexpectedly", "received result: %+v", res) + case <-time.After(time.Millisecond * 20): + } +} + +func (s *testTSOStreamSuite) getResult(ch <-chan callbackInvocation) callbackInvocation { + select { + case res := <-ch: + return res + case <-time.After(time.Second * 10000): + s.FailNow("result not ready in time") + panic("result not ready in time") + } +} + +func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) <-chan callbackInvocation { + ch := make(chan callbackInvocation, 1) + err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + if err == nil { + s.Equal(uint32(3), reqKeyspaceGroupID) + s.Equal(uint32(0), result.suffixBits) + } + ch <- callbackInvocation{ + result: result, + streamURL: streamURL, + err: err, + } + }) + s.NoError(err) + return ch +} + +func (s *testTSOStreamSuite) TestTSOStreamBasic() { + ch := s.processRequestWithResultCh(1) + s.noResult(ch) + s.inner.returnResult(10, 1, 1) + res := s.getResult(ch) + + s.NoError(res.err) + s.Equal("mock:///", res.streamURL) + s.Equal(int64(10), res.result.physical) + s.Equal(int64(1), res.result.logical) + s.Equal(uint32(1), res.result.count) + + ch = s.processRequestWithResultCh(2) + s.noResult(ch) + s.inner.returnResult(20, 3, 2) + res = s.getResult(ch) + + s.NoError(res.err) + s.Equal("mock:///", res.streamURL) + 
s.Equal(int64(20), res.result.physical) + s.Equal(int64(3), res.result.logical) + s.Equal(uint32(2), res.result.count) + + ch = s.processRequestWithResultCh(3) + s.noResult(ch) + s.inner.returnError(errors.New("mock rpc error")) + res = s.getResult(ch) + s.Error(res.err) + s.Equal("mock rpc error", res.err.Error()) + + // After an error from the (simulated) RPC stream, the tsoStream should be in a broken status and can't accept + // new request anymore. + err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + panic("unreachable") + }) + s.Error(err) +} + +func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests int) { + var resultCh []<-chan callbackInvocation + + for i := 0; i < pendingRequests; i++ { + ch := s.processRequestWithResultCh(1) + resultCh = append(resultCh, ch) + s.noResult(ch) + } + + s.inner.breakStream(err) + closedCh := make(chan struct{}) + go func() { + s.stream.WaitForClosed() + closedCh <- struct{}{} + }() + select { + case <-closedCh: + case <-time.After(time.Second): + s.FailNow("stream receiver loop didn't exit") + } + + for _, ch := range resultCh { + res := s.getResult(ch) + s.Error(res.err) + if err == io.EOF { + s.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + } else { + s.ErrorIs(res.err, err) + } + } +} + +func (s *testTSOStreamSuite) TestTSOStreamBrokenWithEOFNoPendingReq() { + s.testTSOStreamBrokenImpl(io.EOF, 0) +} + +func (s *testTSOStreamSuite) TestTSOStreamCanceledNoPendingReq() { + s.testTSOStreamBrokenImpl(context.Canceled, 0) +} + +func (s *testTSOStreamSuite) TestTSOStreamBrokenWithEOFWithPendingReq() { + s.testTSOStreamBrokenImpl(io.EOF, 5) +} + +func (s *testTSOStreamSuite) TestTSOStreamCanceledWithPendingReq() { + s.testTSOStreamBrokenImpl(context.Canceled, 5) +} + +func (s *testTSOStreamSuite) TestTSOStreamFIFO() { + var resultChs []<-chan callbackInvocation + const COUNT = 5 + for i := 0; i < COUNT; i++ { + ch := s.processRequestWithResultCh(int64(i + 1)) + resultChs = append(resultChs, ch) + } + + for _, ch := range resultChs { + s.noResult(ch) + } + + for i := 0; i < COUNT; i++ { + s.inner.returnResult(int64((i+1)*10), int64(i), uint32(i+1)) + } + + for i, ch := range resultChs { + res := s.getResult(ch) + s.NoError(res.err) + s.Equal(int64((i+1)*10), res.result.physical) + s.Equal(int64(i), res.result.logical) + s.Equal(uint32(i+1), res.result.count) + } +} From 623d896f6b8e99dd07cf014e1267e0bf090d3d4a Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 8 Aug 2024 16:35:16 +0800 Subject: [PATCH 03/21] Add a concurrency test Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 89 +++++++++++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 9 deletions(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 47af677ca1f..46b308d1f2c 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -35,6 +35,7 @@ type mockTSOStreamImpl struct { requestCh chan struct{} resultCh chan resultMsg keyspaceID uint32 + errorState error } func newMockTSOStreamImpl() *mockTSOStreamImpl { @@ -51,10 +52,17 @@ func (s *mockTSOStreamImpl) Send(clusterID uint64, keyspaceID, keyspaceGroupID u } func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { + // This stream have ever receive an error, it returns the error forever. 
+ if s.errorState != nil { + return tsoRequestResult{}, s.errorState + } res := <-s.resultCh if !res.breakStream { <-s.requestCh } + if res.err != nil { + s.errorState = res.err + } return res.r, res.err } @@ -134,7 +142,7 @@ func (s *testTSOStreamSuite) getResult(ch <-chan callbackInvocation) callbackInv } } -func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) <-chan callbackInvocation { +func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) (<-chan callbackInvocation, error) { ch := make(chan callbackInvocation, 1) err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { if err == nil { @@ -147,12 +155,20 @@ func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) <-chan call err: err, } }) + if err != nil { + return nil, err + } + return ch, nil +} + +func (s *testTSOStreamSuite) mustProcessRequestWithResultCh(count int64) <-chan callbackInvocation { + ch, err := s.processRequestWithResultCh(count) s.NoError(err) return ch } func (s *testTSOStreamSuite) TestTSOStreamBasic() { - ch := s.processRequestWithResultCh(1) + ch := s.mustProcessRequestWithResultCh(1) s.noResult(ch) s.inner.returnResult(10, 1, 1) res := s.getResult(ch) @@ -163,7 +179,7 @@ func (s *testTSOStreamSuite) TestTSOStreamBasic() { s.Equal(int64(1), res.result.logical) s.Equal(uint32(1), res.result.count) - ch = s.processRequestWithResultCh(2) + ch = s.mustProcessRequestWithResultCh(2) s.noResult(ch) s.inner.returnResult(20, 3, 2) res = s.getResult(ch) @@ -174,7 +190,7 @@ func (s *testTSOStreamSuite) TestTSOStreamBasic() { s.Equal(int64(3), res.result.logical) s.Equal(uint32(2), res.result.count) - ch = s.processRequestWithResultCh(3) + ch = s.mustProcessRequestWithResultCh(3) s.noResult(ch) s.inner.returnError(errors.New("mock rpc error")) res = s.getResult(ch) @@ -193,7 +209,7 @@ func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests var resultCh []<-chan callbackInvocation for i := 0; i < pendingRequests; i++ { - ch := s.processRequestWithResultCh(1) + ch := s.mustProcessRequestWithResultCh(1) resultCh = append(resultCh, ch) s.noResult(ch) } @@ -239,9 +255,9 @@ func (s *testTSOStreamSuite) TestTSOStreamCanceledWithPendingReq() { func (s *testTSOStreamSuite) TestTSOStreamFIFO() { var resultChs []<-chan callbackInvocation - const COUNT = 5 - for i := 0; i < COUNT; i++ { - ch := s.processRequestWithResultCh(int64(i + 1)) + const count = 5 + for i := 0; i < count; i++ { + ch := s.mustProcessRequestWithResultCh(int64(i + 1)) resultChs = append(resultChs, ch) } @@ -249,7 +265,7 @@ func (s *testTSOStreamSuite) TestTSOStreamFIFO() { s.noResult(ch) } - for i := 0; i < COUNT; i++ { + for i := 0; i < count; i++ { s.inner.returnResult(int64((i+1)*10), int64(i), uint32(i+1)) } @@ -261,3 +277,58 @@ func (s *testTSOStreamSuite) TestTSOStreamFIFO() { s.Equal(uint32(i+1), res.result.count) } } + +func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { + resultChCh := make(chan (<-chan callbackInvocation), 10000) + const totalCount = 10000 + + // Continuously start requests + go func() { + for i := 1; i <= totalCount; i++ { + // Retry loop + for { + ch, err := s.processRequestWithResultCh(int64(i)) + if err != nil { + // If the capacity of the request queue is exhausted, it returns this error. As a test, we simply + // spin and retry it until it has enough space, as a coverage of the almost-full case. 
But note that + // this should not happen in production, in which case the caller of tsoStream should have its own + // limit of concurrent RPC requests. + s.Contains(err.Error(), "unexpected channel full") + continue + } + + resultChCh <- ch + break + } + } + }() + + // Continuously send results + go func() { + for i := int64(1); i <= totalCount; i++ { + s.inner.returnResult(i*10, i%(1<<18), uint32(i)) + } + s.inner.breakStream(io.EOF) + }() + + // Check results + for i := int64(1); i <= totalCount; i++ { + ch := <-resultChCh + res := s.getResult(ch) + s.NoError(res.err) + s.Equal(i*10, res.result.physical) + s.Equal(i%(1<<18), res.result.logical) + s.Equal(uint32(i), res.result.count) + } + + // After handling all these requests, the stream is ended by an EOF error. The next request won't succeed. + // So, either the `processRequests` function returns an error or the callback is called with an error. + ch, err := s.processRequestWithResultCh(1) + if err != nil { + s.ErrorIs(err, errs.ErrClientTSOStreamClosed) + } else { + res := s.getResult(ch) + s.Error(res.err) + s.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + } +} From 86d0f42d2fe76542f8c2fefe3efd582f87fd70dd Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 8 Aug 2024 17:22:48 +0800 Subject: [PATCH 04/21] Fix some of CI failures Signed-off-by: MyonKeminta --- client/tso_batch_controller.go | 18 ++++---- client/tso_dispatcher.go | 5 +- client/tso_stream.go | 11 ++--- client/tso_stream_test.go | 83 +++++++++++++++++----------------- 4 files changed, 56 insertions(+), 61 deletions(-) diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index bf072b24535..457ca7c1078 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -19,10 +19,7 @@ import ( "runtime/trace" "time" - "github.com/pingcap/errors" - "github.com/pingcap/log" "github.com/tikv/pd/client/tsoutil" - "go.uber.org/zap" ) type tsoBatchController struct { @@ -186,10 +183,11 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in tbc.collectedRequestCount = 0 } -func (tbc *tsoBatchController) clear() { - log.Info("[pd] clear the tso batch controller", - zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize), - zap.Int("collected-request-count", tbc.collectedRequestCount)) - tsoErr := errors.WithStack(errClosing) - tbc.finishCollectedRequests(0, 0, 0, tsoErr) -} +// +//func (tbc *tsoBatchController) clear() { +// log.Info("[pd] clear the tso batch controller", +// zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize), +// zap.Int("collected-request-count", tbc.collectedRequestCount)) +// tsoErr := errors.WithStack(errClosing) +// tbc.finishCollectedRequests(0, 0, 0, tsoErr) +//} diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 2e670c656c9..c382bb3c699 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -197,7 +197,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { cc.(*tsoConnectionContext).cancel() return true }) - if batchController.collectedRequestCount != 0 { + if batchController != nil && batchController.collectedRequestCount != 0 { log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop") } tsoErr := errors.WithStack(errClosing) @@ -225,6 +225,7 @@ tsoBatchLoop: return default: } + batchController = td.batchBufferPool.Get().(*tsoBatchController) // Start to collect the TSO requests. 
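Fetching the batch controller from `td.batchBufferPool` right above relies on the standard sync.Pool recycle pattern. A minimal stand-alone sketch of that pattern (hypothetical buffer type, not the patch's controller):

package main

import (
	"fmt"
	"sync"
)

type batchBuffer struct{ reqs []int }

func main() {
	pool := &sync.Pool{
		New: func() any { return &batchBuffer{reqs: make([]int, 0, 8)} },
	}

	buf := pool.Get().(*batchBuffer) // allocates only when the pool has nothing to hand out
	buf.reqs = append(buf.reqs, 1, 2, 3)
	fmt.Println(len(buf.reqs)) // 3

	buf.reqs = buf.reqs[:0] // reset before returning it, so the next user starts clean
	pool.Put(buf)
}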
maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval() // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, @@ -454,7 +455,7 @@ func (td *tsoDispatcher) processRequests( reqKeyspaceGroupID = svcDiscovery.GetKeyspaceGroupID() ) - cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { defer td.batchBufferPool.Put(tbc) if err != nil { td.cancelCollectedRequests(tbc, err) diff --git a/client/tso_stream.go b/client/tso_stream.go index 4e18b5e31d7..c58cf3eb841 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -178,7 +178,7 @@ func (s tsoTSOStreamAdapter) Recv() (tsoRequestResult, error) { }, nil } -type onFinishedCallback func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) +type onFinishedCallback func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) type batchedRequests struct { startTime time.Time @@ -203,8 +203,6 @@ type tsoStream struct { pendingRequests chan batchedRequests - estimateLatencyMicros atomic.Uint64 - cancel context.CancelFunc wg sync.WaitGroup @@ -265,7 +263,6 @@ func (s *tsoStream) processRequests( switch prevState { case streamStateIdle: // Expected case - break case streamStateClosing: s.state.Store(prevState) log.Info("tsoStream closed") @@ -318,7 +315,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { if hasReq { // There's an unfinished request, cancel it, otherwise it will be blocked forever. - currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, s.serverURL, finishWithErr) + currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, finishWithErr) } s.cancel() @@ -341,7 +338,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { // Cancel remaining pending requests. for req := range s.pendingRequests { - req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, s.serverURL, finishWithErr) + req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, finishWithErr) } s.wg.Done() @@ -398,7 +395,7 @@ recvLoop: break recvLoop } - currentReq.callback(res, currentReq.reqKeyspaceGroupID, s.serverURL, nil) + currentReq.callback(res, currentReq.reqKeyspaceGroupID, nil) // After finishing the requests, unset these variables which will be checked in the defer block. 
currentReq = batchedRequests{} hasReq = false diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 46b308d1f2c..8bcb0c29a4b 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/client/errs" ) @@ -46,7 +47,7 @@ func newMockTSOStreamImpl() *mockTSOStreamImpl { } } -func (s *mockTSOStreamImpl) Send(clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64) error { +func (s *mockTSOStreamImpl) Send(_clusterID uint64, _keyspaceID, _keyspaceGroupID uint32, _dcLocation string, _count int64) error { s.requestCh <- struct{}{} return nil } @@ -96,19 +97,20 @@ func (s *mockTSOStreamImpl) stop() { } type callbackInvocation struct { - result tsoRequestResult - streamURL string - err error + result tsoRequestResult + err error } type testTSOStreamSuite struct { suite.Suite + re *require.Assertions inner *mockTSOStreamImpl stream *tsoStream } func (s *testTSOStreamSuite) SetupTest() { + s.re = require.New(s.T()) s.inner = newMockTSOStreamImpl() s.stream = newTSOStream("mock:///", s.inner) } @@ -127,7 +129,7 @@ func TestTSOStreamTestSuite(t *testing.T) { func (s *testTSOStreamSuite) noResult(ch <-chan callbackInvocation) { select { case res := <-ch: - s.FailNowf("result received unexpectedly", "received result: %+v", res) + s.re.FailNowf("result received unexpectedly", "received result: %+v", res) case <-time.After(time.Millisecond * 20): } } @@ -137,22 +139,21 @@ func (s *testTSOStreamSuite) getResult(ch <-chan callbackInvocation) callbackInv case res := <-ch: return res case <-time.After(time.Second * 10000): - s.FailNow("result not ready in time") + s.re.FailNow("result not ready in time") panic("result not ready in time") } } func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) (<-chan callbackInvocation, error) { ch := make(chan callbackInvocation, 1) - err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { if err == nil { - s.Equal(uint32(3), reqKeyspaceGroupID) - s.Equal(uint32(0), result.suffixBits) + s.re.Equal(uint32(3), reqKeyspaceGroupID) + s.re.Equal(uint32(0), result.suffixBits) } ch <- callbackInvocation{ - result: result, - streamURL: streamURL, - err: err, + result: result, + err: err, } }) if err != nil { @@ -163,7 +164,7 @@ func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) (<-chan cal func (s *testTSOStreamSuite) mustProcessRequestWithResultCh(count int64) <-chan callbackInvocation { ch, err := s.processRequestWithResultCh(count) - s.NoError(err) + s.re.NoError(err) return ch } @@ -173,36 +174,34 @@ func (s *testTSOStreamSuite) TestTSOStreamBasic() { s.inner.returnResult(10, 1, 1) res := s.getResult(ch) - s.NoError(res.err) - s.Equal("mock:///", res.streamURL) - s.Equal(int64(10), res.result.physical) - s.Equal(int64(1), res.result.logical) - s.Equal(uint32(1), res.result.count) + s.re.NoError(res.err) + s.re.Equal(int64(10), res.result.physical) + s.re.Equal(int64(1), res.result.logical) + s.re.Equal(uint32(1), res.result.count) ch = s.mustProcessRequestWithResultCh(2) s.noResult(ch) s.inner.returnResult(20, 3, 2) res = s.getResult(ch) - s.NoError(res.err) - 
s.Equal("mock:///", res.streamURL) - s.Equal(int64(20), res.result.physical) - s.Equal(int64(3), res.result.logical) - s.Equal(uint32(2), res.result.count) + s.re.NoError(res.err) + s.re.Equal(int64(20), res.result.physical) + s.re.Equal(int64(3), res.result.logical) + s.re.Equal(uint32(2), res.result.count) ch = s.mustProcessRequestWithResultCh(3) s.noResult(ch) s.inner.returnError(errors.New("mock rpc error")) res = s.getResult(ch) - s.Error(res.err) - s.Equal("mock rpc error", res.err.Error()) + s.re.Error(res.err) + s.re.Equal("mock rpc error", res.err.Error()) // After an error from the (simulated) RPC stream, the tsoStream should be in a broken status and can't accept // new request anymore. - err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(_result tsoRequestResult, _reqKeyspaceGroupID uint32, _err error) { panic("unreachable") }) - s.Error(err) + s.re.Error(err) } func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests int) { @@ -223,16 +222,16 @@ func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests select { case <-closedCh: case <-time.After(time.Second): - s.FailNow("stream receiver loop didn't exit") + s.re.FailNow("stream receiver loop didn't exit") } for _, ch := range resultCh { res := s.getResult(ch) - s.Error(res.err) + s.re.Error(res.err) if err == io.EOF { - s.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + s.re.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) } else { - s.ErrorIs(res.err, err) + s.re.ErrorIs(res.err, err) } } } @@ -271,10 +270,10 @@ func (s *testTSOStreamSuite) TestTSOStreamFIFO() { for i, ch := range resultChs { res := s.getResult(ch) - s.NoError(res.err) - s.Equal(int64((i+1)*10), res.result.physical) - s.Equal(int64(i), res.result.logical) - s.Equal(uint32(i+1), res.result.count) + s.re.NoError(res.err) + s.re.Equal(int64((i+1)*10), res.result.physical) + s.re.Equal(int64(i), res.result.logical) + s.re.Equal(uint32(i+1), res.result.count) } } @@ -315,20 +314,20 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { for i := int64(1); i <= totalCount; i++ { ch := <-resultChCh res := s.getResult(ch) - s.NoError(res.err) - s.Equal(i*10, res.result.physical) - s.Equal(i%(1<<18), res.result.logical) - s.Equal(uint32(i), res.result.count) + s.re.NoError(res.err) + s.re.Equal(i*10, res.result.physical) + s.re.Equal(i%(1<<18), res.result.logical) + s.re.Equal(uint32(i), res.result.count) } // After handling all these requests, the stream is ended by an EOF error. The next request won't succeed. // So, either the `processRequests` function returns an error or the callback is called with an error. 
ch, err := s.processRequestWithResultCh(1) if err != nil { - s.ErrorIs(err, errs.ErrClientTSOStreamClosed) + s.re.ErrorIs(err, errs.ErrClientTSOStreamClosed) } else { res := s.getResult(ch) - s.Error(res.err) - s.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + s.re.Error(res.err) + s.re.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) } } From ff0f11a6734fa654d77303e47874b3a893a5aca4 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Fri, 30 Aug 2024 17:21:56 +0800 Subject: [PATCH 05/21] Fix integration tests Signed-off-by: MyonKeminta --- client/tso_stream.go | 18 ++++++++++++------ .../mcs/tso/keyspace_group_manager_test.go | 7 +++++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index c58cf3eb841..5d790e7b843 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -207,7 +207,8 @@ type tsoStream struct { wg sync.WaitGroup // For syncing between sender and receiver to guarantee all requests are finished when closing. - state atomic.Int32 + state atomic.Int32 + stoppedWithErr error ongoingRequestCountGauge prometheus.Gauge ongoingRequests atomic.Int32 @@ -265,8 +266,12 @@ func (s *tsoStream) processRequests( // Expected case case streamStateClosing: s.state.Store(prevState) - log.Info("tsoStream closed") - return errs.ErrClientTSOStreamClosed + err := s.stoppedWithErr + log.Info("sending to closed tsoStream", zap.Error(err)) + if err == nil { + err = errors.WithStack(errs.ErrClientTSOStreamClosed) + } + return err case streamStateSending: log.Fatal("unexpected concurrent sending on tsoStream", zap.String("stream", s.streamID)) default: @@ -288,7 +293,7 @@ func (s *tsoStream) processRequests( if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { if err == io.EOF { - return errs.ErrClientTSOStreamClosed + return errors.WithStack(errs.ErrClientTSOStreamClosed) } return errors.WithStack(err) } @@ -318,6 +323,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, finishWithErr) } + s.stoppedWithErr = errors.WithStack(finishWithErr) s.cancel() for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) { switch state := s.state.Load(); state { @@ -338,7 +344,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { // Cancel remaining pending requests. 
for req := range s.pendingRequests { - req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, finishWithErr) + req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, errors.WithStack(finishWithErr)) } s.wg.Done() @@ -376,7 +382,7 @@ recvLoop: requestFailedDurationTSO.Observe(durationSeconds) } if err == io.EOF { - finishWithErr = errs.ErrClientTSOStreamClosed + finishWithErr = errors.WithStack(errs.ErrClientTSOStreamClosed) } else { finishWithErr = errors.WithStack(err) } diff --git a/tests/integrations/mcs/tso/keyspace_group_manager_test.go b/tests/integrations/mcs/tso/keyspace_group_manager_test.go index 2acbbcc3b42..eae586b5b20 100644 --- a/tests/integrations/mcs/tso/keyspace_group_manager_test.go +++ b/tests/integrations/mcs/tso/keyspace_group_manager_test.go @@ -16,6 +16,8 @@ package tso import ( "context" + "errors" + "fmt" "math/rand" "net/http" "strings" @@ -472,10 +474,11 @@ func (suite *tsoKeyspaceGroupManagerTestSuite) dispatchClient( strings.Contains(errMsg, clierrs.NotLeaderErr) || strings.Contains(errMsg, clierrs.NotServedErr) || strings.Contains(errMsg, "ErrKeyspaceNotAssigned") || - strings.Contains(errMsg, "ErrKeyspaceGroupIsMerging") { + strings.Contains(errMsg, "ErrKeyspaceGroupIsMerging") || + errors.Is(err, clierrs.ErrClientTSOStreamClosed) { continue } - re.FailNow(errMsg) + re.FailNow(fmt.Sprintf("%+v", err)) } if physical == lastPhysical { re.Greater(logical, lastLogical) From ef47b9d9e50c1a7dbdc7a6c2055833d925bc46f8 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 2 Sep 2024 16:38:29 +0800 Subject: [PATCH 06/21] Fix tso deadline Signed-off-by: MyonKeminta --- client/tso_dispatcher.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index c382bb3c699..5fbf47f927b 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -225,7 +225,12 @@ tsoBatchLoop: return default: } - batchController = td.batchBufferPool.Get().(*tsoBatchController) + + // In case error happens, the loop may continue without resetting `batchController` for retrying. + if batchController == nil { + batchController = td.batchBufferPool.Get().(*tsoBatchController) + } + // Start to collect the TSO requests. maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval() // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, @@ -309,8 +314,7 @@ tsoBatchLoop: case td.tsDeadlineCh <- dl: } // processRequests guarantees that the collected requests could be finished properly. - err = td.processRequests(stream, dc, batchController) - close(done) + err = td.processRequests(stream, dc, batchController, done) // If error happens during tso stream handling, reset stream and run the next trial. if err == nil { // If the request is started successfully, the `batchController` will be put back to the pool when the @@ -425,8 +429,9 @@ func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext } func (td *tsoDispatcher) processRequests( - stream *tsoStream, dcLocation string, tbc *tsoBatchController, + stream *tsoStream, dcLocation string, tbc *tsoBatchController, done chan struct{}, ) error { + // `done` must be guaranteed to be eventually called. 
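The `done` channel handled here pairs with the deadline watcher fed through `td.tsDeadlineCh` (the dispatcher's watchTSDeadline goroutine): if the callback never closes `done` within the timeout, the watcher cancels the stream so the RPC fails fast instead of blocking callers. A simplified, hypothetical sketch of that pairing:

package main

import (
	"context"
	"fmt"
	"time"
)

// watchDeadline cancels the stream's context unless done is closed before the timeout.
func watchDeadline(done <-chan struct{}, timeout time.Duration, cancel context.CancelFunc) {
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-done: // the request finished (successfully or not) in time
	case <-timer.C: // deadline exceeded: tear the stream down so pending callers unblock
		cancel()
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go watchDeadline(done, 50*time.Millisecond, cancel)

	close(done) // simulate the RPC callback firing before the deadline
	time.Sleep(100 * time.Millisecond)
	fmt.Println("stream canceled:", ctx.Err() != nil) // false: it finished in time
}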
var ( requests = tbc.getCollectedRequests() traceRegions = make([]*trace.Region, 0, len(requests)) @@ -456,6 +461,8 @@ func (td *tsoDispatcher) processRequests( ) cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + close(done) + defer td.batchBufferPool.Put(tbc) if err != nil { td.cancelCollectedRequests(tbc, err) @@ -480,6 +487,8 @@ func (td *tsoDispatcher) processRequests( clusterID, keyspaceID, reqKeyspaceGroupID, dcLocation, count, tbc.extraBatchingStartTime, cb) if err != nil { + close(done) + td.cancelCollectedRequests(tbc, err) return err } From ff1d50d0377dcc6bc34d08d1bffa0dc3a3897072 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 2 Sep 2024 18:01:45 +0800 Subject: [PATCH 07/21] Fix stream broken handling Signed-off-by: MyonKeminta --- client/tso_dispatcher.go | 82 ++++++++++++++++++++++++++-------------- client/tso_stream.go | 14 +++++-- 2 files changed, 65 insertions(+), 31 deletions(-) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 5fbf47f927b..9b8e66b553c 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -198,7 +198,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { return true }) if batchController != nil && batchController.collectedRequestCount != 0 { - log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop") + log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop", zap.Any("panic", recover())) } tsoErr := errors.WithStack(errClosing) td.revokePendingRequests(tsoErr) @@ -301,8 +301,20 @@ tsoBatchLoop: stream = nil continue default: - break streamChoosingLoop } + + // Check if any error has occurred on this stream when receiving asynchronously. + if err = stream.GetRecvError(); err != nil { + exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + stream = nil + if exit { + td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) + return + } + continue + } + + break streamChoosingLoop } done := make(chan struct{}) dl := newTSDeadline(option.timeout, done, cancel) @@ -324,38 +336,52 @@ tsoBatchLoop: // reused in the next loop safely. batchController = nil } else { + exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + stream = nil + if exit { + return + } + } + } +} + +// handleProcessRequestError handles errors occurs when trying to process a TSO RPC request for the dispatcher loop. +// Returns true if the dispatcher loop is ok to continue. Otherwise, the dispatcher loop should be exited. +func (td *tsoDispatcher) handleProcessRequestError(ctx context.Context, bo *retry.Backoffer, streamURL string, streamCancelFunc context.CancelFunc, err error) bool { + select { + case <-ctx.Done(): + return false + default: + } + + svcDiscovery := td.provider.getServiceDiscovery() + + svcDiscovery.ScheduleCheckMemberChanged() + log.Error("[tso] getTS error after processing requests", + zap.String("dc-location", td.dc), + zap.String("stream-url", streamURL), + zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) + // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. + td.connectionCtxs.Delete(streamURL) + streamCancelFunc() + // Because ScheduleCheckMemberChanged is asynchronous, if the leader changes, we better call `updateMember` ASAP. 
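+	// Only a leader change needs the synchronous member update below; for other errors the asynchronous
+	// check scheduled above is enough and the caller will simply retry with a new stream.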
+ if errs.IsLeaderChange(err) { + if err := bo.Exec(ctx, svcDiscovery.CheckMemberChanged); err != nil { select { case <-ctx.Done(): - return + return false default: } - svcDiscovery.ScheduleCheckMemberChanged() - log.Error("[tso] getTS error after processing requests", - zap.String("dc-location", dc), - zap.String("stream-url", streamURL), - zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) - // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. - connectionCtxs.Delete(streamURL) - cancel() - stream = nil - // Because ScheduleCheckMemberChanged is asynchronous, if the leader changes, we better call `updateMember` ASAP. - if errs.IsLeaderChange(err) { - if err := bo.Exec(ctx, svcDiscovery.CheckMemberChanged); err != nil { - select { - case <-ctx.Done(): - return - default: - } - } - // Because the TSO Follower Proxy could be configured online, - // If we change it from on -> off, background updateConnectionCtxs - // will cancel the current stream, then the EOF error caused by cancel() - // should not trigger the updateConnectionCtxs here. - // So we should only call it when the leader changes. - provider.updateConnectionCtxs(ctx, dc, connectionCtxs) - } } + // Because the TSO Follower Proxy could be configured online, + // If we change it from on -> off, background updateConnectionCtxs + // will cancel the current stream, then the EOF error caused by cancel() + // should not trigger the updateConnectionCtxs here. + // So we should only call it when the leader changes. + td.provider.updateConnectionCtxs(ctx, td.dc, td.connectionCtxs) } + + return true } // updateConnectionCtxs updates the `connectionCtxs` for the specified DC location regularly. diff --git a/client/tso_stream.go b/client/tso_stream.go index 5d790e7b843..7394d8fe4ab 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -208,7 +208,7 @@ type tsoStream struct { // For syncing between sender and receiver to guarantee all requests are finished when closing. state atomic.Int32 - stoppedWithErr error + stoppedWithErr atomic.Pointer[error] ongoingRequestCountGauge prometheus.Gauge ongoingRequests atomic.Int32 @@ -266,7 +266,7 @@ func (s *tsoStream) processRequests( // Expected case case streamStateClosing: s.state.Store(prevState) - err := s.stoppedWithErr + err := s.GetRecvError() log.Info("sending to closed tsoStream", zap.Error(err)) if err == nil { err = errors.WithStack(errs.ErrClientTSOStreamClosed) @@ -323,7 +323,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, finishWithErr) } - s.stoppedWithErr = errors.WithStack(finishWithErr) + s.stoppedWithErr.Store(&finishWithErr) s.cancel() for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) { switch state := s.state.Load(); state { @@ -410,6 +410,14 @@ recvLoop: } } +func (s *tsoStream) GetRecvError() error { + perr := s.stoppedWithErr.Load() + if perr == nil { + return nil + } + return *perr +} + // WaitForClosed blocks until the stream is closed and the inner loop exits. 
func (s *tsoStream) WaitForClosed() { s.wg.Wait() From 9076c9dcf78880eab2e9c3351b0f03f5ab2756bc Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 2 Sep 2024 18:10:40 +0800 Subject: [PATCH 08/21] Fix comments Signed-off-by: MyonKeminta --- client/tso_batch_controller.go | 9 --------- client/tso_dispatcher.go | 4 ++++ client/tso_stream.go | 1 + 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index 457ca7c1078..6dc0b57ddb8 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -182,12 +182,3 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in // Prevent the finished requests from being processed again. tbc.collectedRequestCount = 0 } - -// -//func (tbc *tsoBatchController) clear() { -// log.Info("[pd] clear the tso batch controller", -// zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize), -// zap.Int("collected-request-count", tbc.collectedRequestCount)) -// tsoErr := errors.WithStack(errClosing) -// tbc.finishCollectedRequests(0, 0, 0, tsoErr) -//} diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 9b8e66b553c..e6a5588aa84 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -454,6 +454,10 @@ func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext return connectionCtx } +// processRequests sends the RPC request for the batch. It's guaranteed that after calling this function, requests +// in the batch must be eventually finished (done or canceled), either synchronously or asynchronously. +// `close(done)` will be called at the same time when finishing the requests. +// If this function returns a non-nil error, the requests will always be canceled synchronously. func (td *tsoDispatcher) processRequests( stream *tsoStream, dcLocation string, tbc *tsoBatchController, done chan struct{}, ) error { diff --git a/client/tso_stream.go b/client/tso_stream.go index 7394d8fe4ab..1f9038512ff 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -410,6 +410,7 @@ recvLoop: } } +// GetRecvError returns the error (if any) that has been encountered when receiving response asynchronously. func (s *tsoStream) GetRecvError() error { perr := s.stoppedWithErr.Load() if perr == nil { From 39e4e0b05ee773857e71c40202513c444c4c5d5a Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 4 Sep 2024 17:04:54 +0800 Subject: [PATCH 09/21] try to avoid double invocation to the callback Signed-off-by: MyonKeminta --- client/tso_stream.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index 1f9038512ff..bbd1ea4ec87 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -292,10 +292,14 @@ func (s *tsoStream) processRequests( s.state.Store(prevState) if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { - if err == io.EOF { - return errors.WithStack(errs.ErrClientTSOStreamClosed) - } - return errors.WithStack(err) + // As the request is already put into `pendingRequests`, the request should finally be canceled by the recvLoop. 
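+		// (The recvLoop always exits with a non-nil error and finishes whatever is left in `pendingRequests`
+		// with that error, so the callback for this request is still guaranteed to run.)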
+ // So skip returning error here to avoid + //if err == io.EOF { + // return errors.WithStack(errs.ErrClientTSOStreamClosed) + //} + //return errors.WithStack(err) + log.Warn("failed to send RPC request through tsoStream", zap.String("stream", s.streamID), zap.Error(err)) + return nil } tsoBatchSendLatency.Observe(time.Since(batchStartTime).Seconds()) s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(1))) From 632781b2047f73265d26fdec890ce47cee8fac41 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 4 Sep 2024 17:17:21 +0800 Subject: [PATCH 10/21] Fix lint Signed-off-by: MyonKeminta --- client/tso_stream.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index bbd1ea4ec87..279f11db47c 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -294,10 +294,10 @@ func (s *tsoStream) processRequests( if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { // As the request is already put into `pendingRequests`, the request should finally be canceled by the recvLoop. // So skip returning error here to avoid - //if err == io.EOF { - // return errors.WithStack(errs.ErrClientTSOStreamClosed) - //} - //return errors.WithStack(err) + // if err == io.EOF { + // return errors.WithStack(errs.ErrClientTSOStreamClosed) + // } + // return errors.WithStack(err) log.Warn("failed to send RPC request through tsoStream", zap.String("stream", s.streamID), zap.Error(err)) return nil } From 8809c45a00a40904de478af5d34260d3576909fd Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Fri, 6 Sep 2024 18:16:44 +0800 Subject: [PATCH 11/21] Address comments Signed-off-by: MyonKeminta --- client/tso_batch_controller.go | 7 +++---- client/tso_dispatcher.go | 26 ++++++++++++++------------ client/tso_request.go | 3 +++ client/tso_stream.go | 2 ++ 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index 6dc0b57ddb8..32191889160 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -56,9 +56,7 @@ func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequ if tokenAcquired { tokenCh <- struct{}{} } - if tbc.collectedRequestCount > 0 { - tbc.finishCollectedRequests(0, 0, 0, errRet) - } + tbc.finishCollectedRequests(0, 0, 0, invalidStreamID, errRet) } }() @@ -170,12 +168,13 @@ func (tbc *tsoBatchController) adjustBestBatchSize() { } } -func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, err error) { +func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, streamID string, err error) { for i := 0; i < tbc.collectedRequestCount; i++ { tsoReq := tbc.collectedRequests[i] // Retrieve the request context before the request is done to trace without race. requestCtx := tsoReq.requestCtx tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits) + tsoReq.streamID = streamID tsoReq.tryDone(err) trace.StartRegion(requestCtx, "pdclient.tsoReqDequeue").End() } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index e6a5588aa84..efbe246a82a 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -104,7 +104,9 @@ func newTSODispatcher( tsoRequestCh = make(chan *tsoRequest, 1) }) - tokenCh := make(chan struct{}, 64) + // A large-enough capacity to hold maximum concurrent RPC requests. 
In our design, the concurrency is at most 16. + const tokenChCapacity = 64 + tokenCh := make(chan struct{}, tokenChCapacity) td := &tsoDispatcher{ ctx: dispatcherCtx, @@ -276,7 +278,7 @@ tsoBatchLoop: select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. - td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) timer.Stop() return case <-streamLoopTimer.C: @@ -284,7 +286,7 @@ tsoBatchLoop: log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err)) svcDiscovery.ScheduleCheckMemberChanged() // Finish the collected requests if the stream is failed to be created. - td.cancelCollectedRequests(batchController, errors.WithStack(err)) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(err)) timer.Stop() continue tsoBatchLoop case <-timer.C: @@ -308,7 +310,7 @@ tsoBatchLoop: exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) stream = nil if exit { - td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) return } continue @@ -321,7 +323,7 @@ tsoBatchLoop: select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. - td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) return case td.tsDeadlineCh <- dl: } @@ -495,7 +497,7 @@ func (td *tsoDispatcher) processRequests( defer td.batchBufferPool.Put(tbc) if err != nil { - td.cancelCollectedRequests(tbc, err) + td.cancelCollectedRequests(tbc, stream.streamID, err) return } @@ -510,7 +512,7 @@ func (td *tsoDispatcher) processRequests( // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. 
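+		// For example (ignoring suffix bits), a response with logical=105 and count=3 gives firstLogical=103,
+		// and the three requests in the batch receive logical 103, 104 and 105 in order.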
firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits) td.compareAndSwapTS(curTSOInfo, firstLogical) - td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits) + td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits, stream.streamID) } err := stream.processRequests( @@ -519,20 +521,20 @@ func (td *tsoDispatcher) processRequests( if err != nil { close(done) - td.cancelCollectedRequests(tbc, err) + td.cancelCollectedRequests(tbc, stream.streamID, err) return err } return nil } -func (td *tsoDispatcher) cancelCollectedRequests(tbc *tsoBatchController, err error) { +func (td *tsoDispatcher) cancelCollectedRequests(tbc *tsoBatchController, streamID string, err error) { td.tokenCh <- struct{}{} - tbc.finishCollectedRequests(0, 0, 0, err) + tbc.finishCollectedRequests(0, 0, 0, streamID, err) } -func (td *tsoDispatcher) doneCollectedRequests(tbc *tsoBatchController, physical int64, firstLogical int64, suffixBits uint32) { +func (td *tsoDispatcher) doneCollectedRequests(tbc *tsoBatchController, physical int64, firstLogical int64, suffixBits uint32, streamID string) { td.tokenCh <- struct{}{} - tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) + tbc.finishCollectedRequests(physical, firstLogical, suffixBits, streamID, nil) } func (td *tsoDispatcher) compareAndSwapTS( diff --git a/client/tso_request.go b/client/tso_request.go index b912fa35497..fb2ae2bb92e 100644 --- a/client/tso_request.go +++ b/client/tso_request.go @@ -42,6 +42,9 @@ type tsoRequest struct { logical int64 dcLocation string + // The identifier of the RPC stream in which the request is processed. + streamID string + // Runtime fields. start time.Time pool *sync.Pool diff --git a/client/tso_stream.go b/client/tso_stream.go index 279f11db47c..bff311eb285 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -222,6 +222,8 @@ const ( var streamIDAlloc atomic.Int32 +const invalidStreamID = "" + // TODO: Pass a context? func newTSOStream(serverURL string, stream grpcTSOStreamAdapter) *tsoStream { streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1)) From ca64ed0d192254c0f2f8e5124a49c9f5669556a1 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 9 Sep 2024 18:22:59 +0800 Subject: [PATCH 12/21] Pass context outside the tsoStream (push for running CI) Signed-off-by: MyonKeminta --- client/tso_stream.go | 11 ++++++----- client/tso_stream_test.go | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index bff311eb285..726bf502485 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -68,7 +68,7 @@ func (b *pdTSOStreamBuilder) build(ctx context.Context, cancel context.CancelFun stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return newTSOStream(b.serverURL, pdTSOStreamAdapter{stream}), nil + return newTSOStream(ctx, b.serverURL, pdTSOStreamAdapter{stream}), nil } return nil, err } @@ -87,7 +87,7 @@ func (b *tsoTSOStreamBuilder) build( stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return newTSOStream(b.serverURL, tsoTSOStreamAdapter{stream}), nil + return newTSOStream(ctx, b.serverURL, tsoTSOStreamAdapter{stream}), nil } return nil, err } @@ -224,10 +224,11 @@ var streamIDAlloc atomic.Int32 const invalidStreamID = "" -// TODO: Pass a context? 
-func newTSOStream(serverURL string, stream grpcTSOStreamAdapter) *tsoStream { +func newTSOStream(ctx context.Context, serverURL string, stream grpcTSOStreamAdapter) *tsoStream { streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1)) - ctx, cancel := context.WithCancel(context.Background()) + // To make error handling in `tsoDispatcher` work, the internal `cancel` and external `cancel` is better to be + // distinguished. + ctx, cancel := context.WithCancel(ctx) s := &tsoStream{ serverURL: serverURL, stream: stream, diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 8bcb0c29a4b..e0b3beaa010 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -112,7 +112,7 @@ type testTSOStreamSuite struct { func (s *testTSOStreamSuite) SetupTest() { s.re = require.New(s.T()) s.inner = newMockTSOStreamImpl() - s.stream = newTSOStream("mock:///", s.inner) + s.stream = newTSOStream(context.Background(), "mock:///", s.inner) } func (s *testTSOStreamSuite) TearDownTest() { From 67967e4f5eee7ae5363c86dd41cf7bbb34d9dfcf Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 11 Sep 2024 19:58:15 +0800 Subject: [PATCH 13/21] client: Add benchmark for tsoStream and dispatcher Signed-off-by: MyonKeminta --- client/tso_dispatcher_test.go | 107 +++++++++++++++++++ client/tso_stream_test.go | 188 ++++++++++++++++++++++++++++++++++ 2 files changed, 295 insertions(+) create mode 100644 client/tso_dispatcher_test.go create mode 100644 client/tso_stream_test.go diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go new file mode 100644 index 00000000000..b1a87d0a2cd --- /dev/null +++ b/client/tso_dispatcher_test.go @@ -0,0 +1,107 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package pd + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/pingcap/log" + "go.uber.org/zap/zapcore" +) + +type mockTSOServiceProvider struct { + option *option +} + +func newMockTSOServiceProvider(option *option) *mockTSOServiceProvider { + return &mockTSOServiceProvider{ + option: option, + } +} + +func (m *mockTSOServiceProvider) getOption() *option { + return m.option +} + +func (m *mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { + return NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) +} + +func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs *sync.Map) bool { + _, ok := connectionCtxs.Load(mockStreamURL) + if ok { + return true + } + ctx, cancel := context.WithCancel(ctx) + stream := &tsoStream{ + serverURL: mockStreamURL, + stream: newMockTSOStreamImpl(ctx, true), + } + connectionCtxs.LoadOrStore(mockStreamURL, &tsoConnectionContext{ctx, cancel, mockStreamURL, stream}) + return true +} + +func BenchmarkTSODispatcherHandleRequests(b *testing.B) { + log.SetLevel(zapcore.FatalLevel) + + ctx := context.Background() + + reqPool := &sync.Pool{ + New: func() any { + return &tsoRequest{ + done: make(chan error, 1), + physical: 0, + logical: 0, + dcLocation: globalDCLocation, + } + }, + } + getReq := func() *tsoRequest { + req := reqPool.Get().(*tsoRequest) + req.clientCtx = ctx + req.requestCtx = ctx + req.physical = 0 + req.logical = 0 + req.start = time.Now() + req.pool = reqPool + return req + } + + dispatcher := newTSODispatcher(ctx, globalDCLocation, defaultMaxTSOBatchSize, newMockTSOServiceProvider(newOption())) + var wg sync.WaitGroup + wg.Add(1) + + go dispatcher.handleDispatcher(&wg) + defer func() { + dispatcher.close() + wg.Wait() + }() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := getReq() + dispatcher.push(req) + _, _, err := req.Wait() + if err != nil { + panic(fmt.Sprintf("unexpected error from tsoReq: %+v", err)) + } + } + // Don't count the time cost in `defer` + b.StopTimer() +} diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go new file mode 100644 index 00000000000..9c68a35d55e --- /dev/null +++ b/client/tso_stream_test.go @@ -0,0 +1,188 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package pd + +import ( + "context" + "io" + "testing" + "time" +) + +const mockStreamURL = "mock:///" + +type requestMsg struct { + clusterID uint64 + keyspaceGroupID uint32 + count int64 +} + +type resultMsg struct { + r tsoRequestResult + err error + breakStream bool +} + +type mockTSOStreamImpl struct { + ctx context.Context + requestCh chan requestMsg + resultCh chan resultMsg + keyspaceID uint32 + errorState error + + autoGenerateResult bool + // Current progress of generating TSO results + resGenPhysical, resGenLogical int64 +} + +func newMockTSOStreamImpl(ctx context.Context, autoGenerateResult bool) *mockTSOStreamImpl { + return &mockTSOStreamImpl{ + ctx: ctx, + requestCh: make(chan requestMsg, 64), + resultCh: make(chan resultMsg, 64), + keyspaceID: 0, + + autoGenerateResult: autoGenerateResult, + resGenPhysical: 10000, + resGenLogical: 0, + } +} + +func (s *mockTSOStreamImpl) Send(clusterID uint64, _keyspaceID, keyspaceGroupID uint32, _dcLocation string, count int64) error { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + default: + } + s.requestCh <- requestMsg{ + clusterID: clusterID, + keyspaceGroupID: keyspaceGroupID, + count: count, + } + return nil +} + +func (s *mockTSOStreamImpl) Recv() (ret tsoRequestResult, retErr error) { + // This stream have ever receive an error, it returns the error forever. + if s.errorState != nil { + return tsoRequestResult{}, s.errorState + } + + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + default: + } + + var res resultMsg + needGenRes := false + if s.autoGenerateResult { + select { + case res = <-s.resultCh: + default: + needGenRes = true + } + } else { + select { + case res = <-s.resultCh: + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + } + } + + if !res.breakStream { + var req requestMsg + select { + case req = <-s.requestCh: + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + } + if needGenRes { + + physical := s.resGenPhysical + logical := s.resGenLogical + req.count + if logical >= (1 << 18) { + physical += logical >> 18 + logical &= (1 << 18) - 1 + } + + s.resGenPhysical = physical + s.resGenLogical = logical + + res = resultMsg{ + r: tsoRequestResult{ + physical: s.resGenPhysical, + logical: s.resGenLogical, + count: uint32(req.count), + suffixBits: 0, + respKeyspaceGroupID: 0, + }, + } + } + } + if res.err != nil { + s.errorState = res.err + } + return res.r, res.err +} + +func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { + s.resultCh <- resultMsg{ + r: tsoRequestResult{ + physical: physical, + logical: logical, + count: count, + suffixBits: 0, + respKeyspaceGroupID: s.keyspaceID, + }, + } +} + +func (s *mockTSOStreamImpl) returnError(err error) { + s.resultCh <- resultMsg{ + err: err, + } +} + +func (s *mockTSOStreamImpl) breakStream(err error) { + s.resultCh <- resultMsg{ + err: err, + breakStream: true, + } +} + +func (s *mockTSOStreamImpl) stop() { + s.breakStream(io.EOF) +} + +func BenchmarkTSOStreamSendRecv(b *testing.B) { + streamInner := newMockTSOStreamImpl(context.Background(), true) + stream := tsoStream{ + serverURL: mockStreamURL, + stream: streamInner, + } + defer streamInner.stop() + + now := time.Now() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _, _, _ = stream.processRequests(1, 1, 1, globalDCLocation, 1, now) + } + b.StopTimer() +} From 4699b4112cd73210b24b3f80ce3e2350bbab9767 Mon Sep 17 00:00:00 2001 
From: MyonKeminta Date: Thu, 12 Sep 2024 13:34:20 +0800 Subject: [PATCH 14/21] Fix build Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 39a631543bc..19a90fd0674 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -413,17 +413,28 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { func BenchmarkTSOStreamSendRecv(b *testing.B) { streamInner := newMockTSOStreamImpl(context.Background(), true) - stream := tsoStream{ - serverURL: mockStreamURL, - stream: streamInner, - } - defer streamInner.stop() + stream := newTSOStream(context.Background(), mockStreamURL, streamInner) + defer func() { + streamInner.stop() + stream.WaitForClosed() + }() now := time.Now() + resCh := make(chan tsoRequestResult, 1) b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, _, _, _ = stream.processRequests(1, 1, 1, globalDCLocation, 1, now) + err := stream.processRequests(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + select { + case resCh <- result: + default: + panic("channel not cleared in the last iteration") + } + }) + if err != nil { + panic(err) + } + <-resCh } b.StopTimer() } From 7dbcd770af83d272e769e71ae4f81dd0ed85204a Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 15:02:02 +0800 Subject: [PATCH 15/21] Reimplement mock receiving Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 102 ++++++++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 32 deletions(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 9c68a35d55e..c543d74c827 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -74,7 +74,7 @@ func (s *mockTSOStreamImpl) Send(clusterID uint64, _keyspaceID, keyspaceGroupID return nil } -func (s *mockTSOStreamImpl) Recv() (ret tsoRequestResult, retErr error) { +func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { // This stream have ever receive an error, it returns the error forever. if s.errorState != nil { return tsoRequestResult{}, s.errorState @@ -87,60 +87,98 @@ func (s *mockTSOStreamImpl) Recv() (ret tsoRequestResult, retErr error) { default: } - var res resultMsg - needGenRes := false - if s.autoGenerateResult { + var ( + res resultMsg + hasRes bool + req requestMsg + hasReq bool + ) + + // Try to match a pair of request and result from each channel and allowing breaking the stream at any time. + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + case req = <-s.requestCh: + hasReq = true select { case res = <-s.resultCh: + hasRes = true default: - needGenRes = true } - } else { + case res = <-s.resultCh: + hasRes = true select { - case res = <-s.resultCh: - case <-s.ctx.Done(): + case req = <-s.requestCh: + hasReq = true + default: + } + } + // Either req or res should be ready at this time. + + if hasRes { + if res.breakStream { s.errorState = s.ctx.Err() return tsoRequestResult{}, s.errorState + } else { + // Do not allow manually assigning result. + if s.autoGenerateResult { + panic("trying manually specifying result for mockTSOStreamImpl when it's auto-generating mode") + } } + } else if s.autoGenerateResult { + res = s.autoGenResult(req.count) + hasRes = true } - if !res.breakStream { - var req requestMsg + if !hasReq { + // If req is not ready, the res must be ready. 
So it's certain that it don't need to be canceled by breakStream. select { - case req = <-s.requestCh: case <-s.ctx.Done(): s.errorState = s.ctx.Err() return tsoRequestResult{}, s.errorState + case req = <-s.requestCh: + hasReq = true } - if needGenRes { - - physical := s.resGenPhysical - logical := s.resGenLogical + req.count - if logical >= (1 << 18) { - physical += logical >> 18 - logical &= (1 << 18) - 1 - } - - s.resGenPhysical = physical - s.resGenLogical = logical - - res = resultMsg{ - r: tsoRequestResult{ - physical: s.resGenPhysical, - logical: s.resGenLogical, - count: uint32(req.count), - suffixBits: 0, - respKeyspaceGroupID: 0, - }, - } + } else if !hasRes { + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + case res = <-s.resultCh: + hasRes = true } } + + // Both res and req should be ready here. if res.err != nil { s.errorState = res.err } return res.r, res.err } +func (s *mockTSOStreamImpl) autoGenResult(count int64) resultMsg { + physical := s.resGenPhysical + logical := s.resGenLogical + count + if logical >= (1 << 18) { + physical += logical >> 18 + logical &= (1 << 18) - 1 + } + + s.resGenPhysical = physical + s.resGenLogical = logical + + return resultMsg{ + r: tsoRequestResult{ + physical: s.resGenPhysical, + logical: s.resGenLogical, + count: uint32(count), + suffixBits: 0, + respKeyspaceGroupID: 0, + }, + } +} + func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { s.resultCh <- resultMsg{ r: tsoRequestResult{ From 1f2dcd862f183daf44bd4be68592d51833ab4160 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 15:25:45 +0800 Subject: [PATCH 16/21] Fix disptacher benchmark Signed-off-by: MyonKeminta --- client/tso_dispatcher_test.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go index b1a87d0a2cd..20904d41f0a 100644 --- a/client/tso_dispatcher_test.go +++ b/client/tso_dispatcher_test.go @@ -49,10 +49,7 @@ func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc s return true } ctx, cancel := context.WithCancel(ctx) - stream := &tsoStream{ - serverURL: mockStreamURL, - stream: newMockTSOStreamImpl(ctx, true), - } + stream := newTSOStream(ctx, mockStreamURL, newMockTSOStreamImpl(ctx, true)) connectionCtxs.LoadOrStore(mockStreamURL, &tsoConnectionContext{ctx, cancel, mockStreamURL, stream}) return true } From 123a869fb5349867b80f458517b4c53c7a6d089f Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 15:27:09 +0800 Subject: [PATCH 17/21] Fix incorrect error generating Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index c543d74c827..e47fe75eca9 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -118,7 +118,10 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { if hasRes { if res.breakStream { - s.errorState = s.ctx.Err() + if res.err == nil { + panic("breaking mockTSOStreamImpl without error") + } + s.errorState = res.err return tsoRequestResult{}, s.errorState } else { // Do not allow manually assigning result. 
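+			// In auto-generating mode `autoGenResult` derives results from the incoming requests, so an
+			// explicitly injected result would conflict with the generated sequence.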
From bff9eaedef3c2d0b2629bc6caeb734c18d7b49fb Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 15:31:09 +0800 Subject: [PATCH 18/21] Disable logs in BenchmarkTSOStreamSendRecv Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 25277dcc2aa..90039bd4b55 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -21,9 +21,11 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/log" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/client/errs" + "go.uber.org/zap/zapcore" ) const mockStreamURL = "mock:///" @@ -453,6 +455,8 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { } func BenchmarkTSOStreamSendRecv(b *testing.B) { + log.SetLevel(zapcore.FatalLevel) + streamInner := newMockTSOStreamImpl(context.Background(), true) stream := newTSOStream(context.Background(), mockStreamURL, streamInner) defer func() { From ef6e365c7a9a3061db8012b79bfdab9d1ce806a9 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 16:14:33 +0800 Subject: [PATCH 19/21] Fix lint Signed-off-by: MyonKeminta --- client/tso_dispatcher_test.go | 4 ++-- client/tso_stream_test.go | 15 +++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go index b1a87d0a2cd..7b7ac8f736b 100644 --- a/client/tso_dispatcher_test.go +++ b/client/tso_dispatcher_test.go @@ -39,11 +39,11 @@ func (m *mockTSOServiceProvider) getOption() *option { return m.option } -func (m *mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { +func (*mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { return NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) } -func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs *sync.Map) bool { +func (*mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs *sync.Map) bool { _, ok := connectionCtxs.Load(mockStreamURL) if ok { return true diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index e47fe75eca9..4709fe04f4c 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -123,11 +123,9 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { } s.errorState = res.err return tsoRequestResult{}, s.errorState - } else { + } else if s.autoGenerateResult { // Do not allow manually assigning result. - if s.autoGenerateResult { - panic("trying manually specifying result for mockTSOStreamImpl when it's auto-generating mode") - } + panic("trying manually specifying result for mockTSOStreamImpl when it's auto-generating mode") } } else if s.autoGenerateResult { res = s.autoGenResult(req.count) @@ -141,7 +139,8 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { s.errorState = s.ctx.Err() return tsoRequestResult{}, s.errorState case req = <-s.requestCh: - hasReq = true + // Skip the assignment to make linter happy. + // hasReq = true } } else if !hasRes { select { @@ -149,7 +148,8 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { s.errorState = s.ctx.Err() return tsoRequestResult{}, s.errorState case res = <-s.resultCh: - hasRes = true + // Skip the assignment to make linter happy. 
+ // hasRes = true } } @@ -182,6 +182,7 @@ func (s *mockTSOStreamImpl) autoGenResult(count int64) resultMsg { } } +// nolint:unused func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { s.resultCh <- resultMsg{ r: tsoRequestResult{ @@ -194,12 +195,14 @@ func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count ui } } +// nolint:unused func (s *mockTSOStreamImpl) returnError(err error) { s.resultCh <- resultMsg{ err: err, } } +// nolint:unused func (s *mockTSOStreamImpl) breakStream(err error) { s.resultCh <- resultMsg{ err: err, From 388d705601b33493a26bd3a7191823abf08d84a9 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 17:03:24 +0800 Subject: [PATCH 20/21] Make the comments more clear; unset req.streamID when getting new tsoRequest from the pool Signed-off-by: MyonKeminta --- client/tso_client.go | 1 + client/tso_dispatcher.go | 11 ++++++++--- client/tso_stream.go | 4 ++-- client/tso_stream_test.go | 3 --- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/client/tso_client.go b/client/tso_client.go index 2f3b949f017..f1538a7f164 100644 --- a/client/tso_client.go +++ b/client/tso_client.go @@ -203,6 +203,7 @@ func (c *tsoClient) getTSORequest(ctx context.Context, dcLocation string) *tsoRe req.physical = 0 req.logical = 0 req.dcLocation = dcLocation + req.streamID = "" return req } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index efbe246a82a..a1e0b03a1fa 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -331,9 +331,12 @@ tsoBatchLoop: err = td.processRequests(stream, dc, batchController, done) // If error happens during tso stream handling, reset stream and run the next trial. if err == nil { - // If the request is started successfully, the `batchController` will be put back to the pool when the - // request is finished (either successful or not). In this case, set the `batchController` to nil so that - // another one will be fetched from the pool. + // A nil error returned by `processRequests` indicates that the request batch is started successfully. + // In this case, the `batchController` will be put back to the pool when the request is finished + // asynchronously (either successful or not). This infers that the current `batchController` object will + // be asynchronously accessed after the `processRequests` call. As a result, we need to use another + // `batchController` for collecting the next batch. Do to this, we set the `batchController` to nil so that + // another one will be fetched from the pool at the beginning of the batching loop. // Otherwise, the `batchController` won't be processed in other goroutines concurrently, and it can be // reused in the next loop safely. batchController = nil @@ -493,6 +496,8 @@ func (td *tsoDispatcher) processRequests( ) cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + // As golang doesn't allow double-closing a channel, here is implicitly a check that the callback + // is never called twice or called while it's also being cancelled elsewhere. 
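+		// This `done` is the same channel that the dispatcher loop handed to `newTSDeadline` for this batch.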
close(done) defer td.batchBufferPool.Put(tbc) diff --git a/client/tso_stream.go b/client/tso_stream.go index 726bf502485..a6e23bbf1d3 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -270,7 +270,7 @@ func (s *tsoStream) processRequests( case streamStateClosing: s.state.Store(prevState) err := s.GetRecvError() - log.Info("sending to closed tsoStream", zap.Error(err)) + log.Info("sending to closed tsoStream", zap.String("stream", s.streamID), zap.Error(err)) if err == nil { err = errors.WithStack(errs.ErrClientTSOStreamClosed) } @@ -322,7 +322,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { if finishWithErr == nil { // The loop must exit with a non-nil error (including io.EOF and context.Canceled). This should be // unreachable code. - log.Fatal("tsoStream.recvLoop exited without error info") + log.Fatal("tsoStream.recvLoop exited without error info", zap.String("stream", s.streamID)) } if hasReq { diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 13583ee3839..564de37d6db 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -189,7 +189,6 @@ func (s *mockTSOStreamImpl) autoGenResult(count int64) resultMsg { } } -// nolint:unused func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { s.resultCh <- resultMsg{ r: tsoRequestResult{ @@ -202,14 +201,12 @@ func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count ui } } -// nolint:unused func (s *mockTSOStreamImpl) returnError(err error) { s.resultCh <- resultMsg{ err: err, } } -// nolint:unused func (s *mockTSOStreamImpl) breakStream(err error) { s.resultCh <- resultMsg{ err: err, From a183aca0a2ecb886c885f1a3d584266ee80f96fb Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 17:42:54 +0800 Subject: [PATCH 21/21] Fix lint Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 564de37d6db..b09c54baf3a 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -469,7 +469,10 @@ func BenchmarkTSOStreamSendRecv(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - err := stream.processRequests(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + err := stream.processRequests(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, _ uint32, err error) { + if err != nil { + panic(err) + } select { case resCh <- result: default: