From 140a7c2d50720d79c41921a1795519f069da3f87 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 5 Aug 2024 17:32:11 +0800 Subject: [PATCH 01/34] Add implementation of making tsoStream asynchronous Signed-off-by: MyonKeminta --- client/metrics.go | 24 ++-- client/tso_batch_controller.go | 83 +++++++++---- client/tso_dispatcher.go | 145 +++++++++++++++------- client/tso_stream.go | 218 +++++++++++++++++++++++++++++---- 4 files changed, 362 insertions(+), 108 deletions(-) diff --git a/client/metrics.go b/client/metrics.go index a11362669b3..a83b4a36407 100644 --- a/client/metrics.go +++ b/client/metrics.go @@ -39,13 +39,14 @@ func initAndRegisterMetrics(constLabels prometheus.Labels) { } var ( - cmdDuration *prometheus.HistogramVec - cmdFailedDuration *prometheus.HistogramVec - requestDuration *prometheus.HistogramVec - tsoBestBatchSize prometheus.Histogram - tsoBatchSize prometheus.Histogram - tsoBatchSendLatency prometheus.Histogram - requestForwarded *prometheus.GaugeVec + cmdDuration *prometheus.HistogramVec + cmdFailedDuration *prometheus.HistogramVec + requestDuration *prometheus.HistogramVec + tsoBestBatchSize prometheus.Histogram + tsoBatchSize prometheus.Histogram + tsoBatchSendLatency prometheus.Histogram + requestForwarded *prometheus.GaugeVec + ongoingRequestCountGauge *prometheus.GaugeVec ) func initMetrics(constLabels prometheus.Labels) { @@ -117,6 +118,15 @@ func initMetrics(constLabels prometheus.Labels) { Help: "The status to indicate if the request is forwarded", ConstLabels: constLabels, }, []string{"host", "delegate"}) + + ongoingRequestCountGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "pd_client", + Subsystem: "request", + Name: "ongoing_requests_count", + Help: "Current count of ongoing batch tso requests", + ConstLabels: constLabels, + }, []string{"stream"}) } var ( diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index a713b7a187d..bf072b24535 100644 --- a/client/tso_batch_controller.go 
+++ b/client/tso_batch_controller.go @@ -30,42 +30,79 @@ type tsoBatchController struct { // bestBatchSize is a dynamic size that changed based on the current batch effect. bestBatchSize int - tsoRequestCh chan *tsoRequest collectedRequests []*tsoRequest collectedRequestCount int - batchStartTime time.Time + // The time after getting the first request and the token, and before performing extra batching. + extraBatchingStartTime time.Time } -func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tsoBatchController { +func newTSOBatchController(maxBatchSize int) *tsoBatchController { return &tsoBatchController{ maxBatchSize: maxBatchSize, bestBatchSize: 8, /* Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4) */ - tsoRequestCh: tsoRequestCh, collectedRequests: make([]*tsoRequest, maxBatchSize+1), collectedRequestCount: 0, } } // fetchPendingRequests will start a new round of the batch collecting from the channel. -// It returns true if everything goes well, otherwise false which means we should stop the service. -func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error { - var firstRequest *tsoRequest - select { - case <-ctx.Done(): - return ctx.Err() - case firstRequest = <-tbc.tsoRequestCh: - } - // Start to batch when the first TSO request arrives. - tbc.batchStartTime = time.Now() +// It returns nil error if everything goes well, otherwise a non-nil error which means we should stop the service. +// It's guaranteed that if this function failed after collecting some requests, then these requests will be cancelled +// when the function returns, so the caller don't need to clear them manually. 
+func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequestCh <-chan *tsoRequest, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) { + var tokenAcquired bool + defer func() { + if errRet != nil { + // Something went wrong when collecting a batch of requests. Release the token and cancel collected requests + // if any. + if tokenAcquired { + tokenCh <- struct{}{} + } + if tbc.collectedRequestCount > 0 { + tbc.finishCollectedRequests(0, 0, 0, errRet) + } + } + }() + + // Wait until BOTH the first request and the token have arrived. + // TODO: `tbc.collectedRequestCount` should never be non-empty here. Consider do assertion here. tbc.collectedRequestCount = 0 - tbc.pushRequest(firstRequest) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-tsoRequestCh: + // Start to batch when the first TSO request arrives. + tbc.pushRequest(req) + // A request arrives but the token is not ready yet. Continue waiting, and also allowing collecting the next + // request if it arrives. + continue + case <-tokenCh: + tokenAcquired = true + } + + // The token is ready. If the first request didn't arrive, wait for it. + if tbc.collectedRequestCount == 0 { + select { + case <-ctx.Done(): + return ctx.Err() + case firstRequest := <-tsoRequestCh: + tbc.pushRequest(firstRequest) + } + } + + // Both token and the first request have arrived. + break + } + + tbc.extraBatchingStartTime = time.Now() // This loop is for trying best to collect more requests, so we use `tbc.maxBatchSize` here. 
fetchPendingRequestsLoop: for tbc.collectedRequestCount < tbc.maxBatchSize { select { - case tsoReq := <-tbc.tsoRequestCh: + case tsoReq := <-tsoRequestCh: tbc.pushRequest(tsoReq) case <-ctx.Done(): return ctx.Err() @@ -88,7 +125,7 @@ fetchPendingRequestsLoop: defer after.Stop() for tbc.collectedRequestCount < tbc.bestBatchSize { select { - case tsoReq := <-tbc.tsoRequestCh: + case tsoReq := <-tsoRequestCh: tbc.pushRequest(tsoReq) case <-ctx.Done(): return ctx.Err() @@ -103,7 +140,7 @@ fetchPendingRequestsLoop: // we can adjust the `tbc.bestBatchSize` dynamically later. for tbc.collectedRequestCount < tbc.maxBatchSize { select { - case tsoReq := <-tbc.tsoRequestCh: + case tsoReq := <-tsoRequestCh: tbc.pushRequest(tsoReq) case <-ctx.Done(): return ctx.Err() @@ -149,18 +186,10 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in tbc.collectedRequestCount = 0 } -func (tbc *tsoBatchController) revokePendingRequests(err error) { - for i := 0; i < len(tbc.tsoRequestCh); i++ { - req := <-tbc.tsoRequestCh - req.tryDone(err) - } -} - func (tbc *tsoBatchController) clear() { log.Info("[pd] clear the tso batch controller", zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize), - zap.Int("collected-request-count", tbc.collectedRequestCount), zap.Int("pending-request-count", len(tbc.tsoRequestCh))) + zap.Int("collected-request-count", tbc.collectedRequestCount)) tsoErr := errors.WithStack(errClosing) tbc.finishCollectedRequests(0, 0, 0, tsoErr) - tbc.revokePendingRequests(tsoErr) } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index a7c99057275..2e670c656c9 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -76,10 +76,18 @@ type tsoDispatcher struct { provider tsoServiceProvider // URL -> *connectionContext - connectionCtxs *sync.Map - batchController *tsoBatchController - tsDeadlineCh chan *deadline - lastTSOInfo *tsoInfo + connectionCtxs *sync.Map + tsoRequestCh chan 
*tsoRequest + tsDeadlineCh chan *deadline + lastTSOInfo *tsoInfo + // For reusing tsoBatchController objects + batchBufferPool *sync.Pool + + // For controlling amount of concurrently processing RPC requests. + // A token must be acquired here before sending an RPC request, and the token must be put back after finishing the + // RPC. This is used like a semaphore, but we don't use semaphore directly here as it cannot be selected with + // other channels. + tokenCh chan struct{} updateConnectionCtxsCh chan struct{} } @@ -91,24 +99,27 @@ func newTSODispatcher( provider tsoServiceProvider, ) *tsoDispatcher { dispatcherCtx, dispatcherCancel := context.WithCancel(ctx) - tsoBatchController := newTSOBatchController( - make(chan *tsoRequest, maxBatchSize*2), - maxBatchSize, - ) + tsoRequestCh := make(chan *tsoRequest, maxBatchSize*2) failpoint.Inject("shortDispatcherChannel", func() { - tsoBatchController = newTSOBatchController( - make(chan *tsoRequest, 1), - maxBatchSize, - ) + tsoRequestCh = make(chan *tsoRequest, 1) }) + + tokenCh := make(chan struct{}, 64) + td := &tsoDispatcher{ - ctx: dispatcherCtx, - cancel: dispatcherCancel, - dc: dc, - provider: provider, - connectionCtxs: &sync.Map{}, - batchController: tsoBatchController, - tsDeadlineCh: make(chan *deadline, 1), + ctx: dispatcherCtx, + cancel: dispatcherCancel, + dc: dc, + provider: provider, + connectionCtxs: &sync.Map{}, + tsoRequestCh: tsoRequestCh, + tsDeadlineCh: make(chan *deadline, 1), + batchBufferPool: &sync.Pool{ + New: func() any { + return newTSOBatchController(maxBatchSize * 2) + }, + }, + tokenCh: tokenCh, updateConnectionCtxsCh: make(chan struct{}, 1), } go td.watchTSDeadline() @@ -146,13 +157,21 @@ func (td *tsoDispatcher) scheduleUpdateConnectionCtxs() { } } +func (td *tsoDispatcher) revokePendingRequests(err error) { + for i := 0; i < len(td.tsoRequestCh); i++ { + req := <-td.tsoRequestCh + req.tryDone(err) + } +} + func (td *tsoDispatcher) close() { td.cancel() - td.batchController.clear() + 
tsoErr := errors.WithStack(errClosing) + td.revokePendingRequests(tsoErr) } func (td *tsoDispatcher) push(request *tsoRequest) { - td.batchController.tsoRequestCh <- request + td.tsoRequestCh <- request } func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { @@ -163,8 +182,12 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { svcDiscovery = provider.getServiceDiscovery() option = provider.getOption() connectionCtxs = td.connectionCtxs - batchController = td.batchController + batchController *tsoBatchController ) + + // Currently only 1 concurrency is supported. Put one token in. + td.tokenCh <- struct{}{} + log.Info("[tso] tso dispatcher created", zap.String("dc-location", dc)) // Clean up the connectionCtxs when the dispatcher exits. defer func() { @@ -174,8 +197,11 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { cc.(*tsoConnectionContext).cancel() return true }) - // Clear the tso batch controller. - batchController.clear() + if batchController.collectedRequestCount != 0 { + log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop") + } + tsoErr := errors.WithStack(errClosing) + td.revokePendingRequests(tsoErr) wg.Done() }() // Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event. @@ -203,9 +229,7 @@ tsoBatchLoop: maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval() // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, // otherwise the upper caller may get blocked on waiting for the results. - if err = batchController.fetchPendingRequests(ctx, maxBatchWaitInterval); err != nil { - // Finish the collected requests if the fetch failed. 
- batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) + if err = batchController.fetchPendingRequests(ctx, td.tsoRequestCh, td.tokenCh, maxBatchWaitInterval); err != nil { if err == context.Canceled { log.Info("[tso] stop fetching the pending tso requests due to context canceled", zap.String("dc-location", dc)) @@ -246,7 +270,7 @@ tsoBatchLoop: select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. - batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) timer.Stop() return case <-streamLoopTimer.C: @@ -254,7 +278,7 @@ tsoBatchLoop: log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err)) svcDiscovery.ScheduleCheckMemberChanged() // Finish the collected requests if the stream is failed to be created. - batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) + td.cancelCollectedRequests(batchController, errors.WithStack(err)) timer.Stop() continue tsoBatchLoop case <-timer.C: @@ -279,15 +303,22 @@ tsoBatchLoop: select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. - batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) return case td.tsDeadlineCh <- dl: } // processRequests guarantees that the collected requests could be finished properly. - err = td.processRequests(stream, dc, td.batchController) + err = td.processRequests(stream, dc, batchController) close(done) // If error happens during tso stream handling, reset stream and run the next trial. - if err != nil { + if err == nil { + // If the request is started successfully, the `batchController` will be put back to the pool when the + // request is finished (either successful or not). 
In this case, set the `batchController` to nil so that + // another one will be fetched from the pool. + // Otherwise, the `batchController` won't be processed in other goroutines concurrently, and it can be + // reused in the next loop safely. + batchController = nil + } else { select { case <-ctx.Done(): return @@ -422,28 +453,48 @@ func (td *tsoDispatcher) processRequests( keyspaceID = svcDiscovery.GetKeyspaceID() reqKeyspaceGroupID = svcDiscovery.GetKeyspaceGroupID() ) - respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests( + + cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + defer td.batchBufferPool.Put(tbc) + if err != nil { + td.cancelCollectedRequests(tbc, err) + return + } + + curTSOInfo := &tsoInfo{ + tsoServer: stream.getServerURL(), + reqKeyspaceGroupID: reqKeyspaceGroupID, + respKeyspaceGroupID: result.respKeyspaceGroupID, + respReceivedAt: time.Now(), + physical: result.physical, + logical: result.logical, + } + // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. + firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits) + td.compareAndSwapTS(curTSOInfo, firstLogical) + td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits) + } + + err := stream.processRequests( clusterID, keyspaceID, reqKeyspaceGroupID, - dcLocation, count, tbc.batchStartTime) + dcLocation, count, tbc.extraBatchingStartTime, cb) if err != nil { - tbc.finishCollectedRequests(0, 0, 0, err) + td.cancelCollectedRequests(tbc, err) return err } - curTSOInfo := &tsoInfo{ - tsoServer: stream.getServerURL(), - reqKeyspaceGroupID: reqKeyspaceGroupID, - respKeyspaceGroupID: respKeyspaceGroupID, - respReceivedAt: time.Now(), - physical: physical, - logical: logical, - } - // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. 
- firstLogical := tsoutil.AddLogical(logical, -count+1, suffixBits) - td.compareAndSwapTS(curTSOInfo, firstLogical) - tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) return nil } +func (td *tsoDispatcher) cancelCollectedRequests(tbc *tsoBatchController, err error) { + td.tokenCh <- struct{}{} + tbc.finishCollectedRequests(0, 0, 0, err) +} + +func (td *tsoDispatcher) doneCollectedRequests(tbc *tsoBatchController, physical int64, firstLogical int64, suffixBits uint32) { + td.tokenCh <- struct{}{} + tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) +} + func (td *tsoDispatcher) compareAndSwapTS( curTSOInfo *tsoInfo, firstLogical int64, ) { diff --git a/client/tso_stream.go b/client/tso_stream.go index da9cab95ba0..74fbc620cef 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -16,13 +16,19 @@ package pd import ( "context" + "fmt" "io" + "sync" + "sync/atomic" "time" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/kvproto/pkg/tsopb" + "github.com/pingcap/log" + "github.com/prometheus/client_golang/prometheus" "github.com/tikv/pd/client/errs" + "go.uber.org/zap" "google.golang.org/grpc" ) @@ -62,7 +68,7 @@ func (b *pdTSOStreamBuilder) build(ctx context.Context, cancel context.CancelFun stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return &tsoStream{stream: pdTSOStreamAdapter{stream}, serverURL: b.serverURL}, nil + return newTSOStream(b.serverURL, pdTSOStreamAdapter{stream}), nil } return nil, err } @@ -81,7 +87,7 @@ func (b *tsoTSOStreamBuilder) build( stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return &tsoStream{stream: tsoTSOStreamAdapter{stream}, serverURL: b.serverURL}, nil + return newTSOStream(b.serverURL, tsoTSOStreamAdapter{stream}), nil } return nil, err } @@ -172,51 +178,209 @@ func (s tsoTSOStreamAdapter) Recv() (tsoRequestResult, error) { }, nil } +type onFinishedCallback func(result tsoRequestResult, reqKeyspaceGroupID 
uint32, streamURL string, err error) + +type batchedRequests struct { + startTime time.Time + count int64 + reqKeyspaceGroupID uint32 + callback onFinishedCallback +} + +// tsoStream represents an abstracted stream for requesting TSO. +// This type is designed to be decoupled from its users, so tsoDispatcher won't be directly accessed here. +// Also in order to avoid potential memory allocations that might happen when passing closures as the callback, +// we instead use the `batchedRequestsNotifier` as the abstraction, and accepts generic type instead of dynamic interface +// type. type tsoStream struct { serverURL string // The internal gRPC stream. // - `pdpb.PD_TsoClient` for a leader/follower in the PD cluster. // - `tsopb.TSO_TsoClient` for a primary/secondary in the TSO cluster. stream grpcTSOStreamAdapter + // An identifier of the tsoStream object for metrics reporting and diagnosing. + streamID string + + pendingRequests chan batchedRequests + + estimateLatencyMicros atomic.Uint64 + + cancel context.CancelFunc + wg sync.WaitGroup + + // For syncing between sender and receiver to guarantee all requests are finished when closing. + state atomic.Int32 + + ongoingRequestCountGauge prometheus.Gauge + ongoingRequests atomic.Int32 +} + +const ( + streamStateIdle int32 = iota + streamStateSending + streamStateClosing +) + +var streamIDAlloc atomic.Int32 + +// TODO: Pass a context? 
+func newTSOStream(serverURL string, stream grpcTSOStreamAdapter) *tsoStream { + streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1)) + ctx, cancel := context.WithCancel(context.Background()) + s := &tsoStream{ + serverURL: serverURL, + stream: stream, + streamID: streamID, + + pendingRequests: make(chan batchedRequests, 64), + + cancel: cancel, + + ongoingRequestCountGauge: ongoingRequestCountGauge.WithLabelValues(streamID), + } + s.wg.Add(1) + go s.recvLoop(ctx) + return s } func (s *tsoStream) getServerURL() string { return s.serverURL } +// processRequests starts an RPC to get a batch of timestamps without waiting for the result. When the result is ready, +// it will be passed to the `callback`. +// +// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines. +// +// It's guaranteed that the `callback` will be called, but when the request fails to be scheduled, the callback +// will be ignored. func (s *tsoStream) processRequests( - clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time, -) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) { + clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time, callback onFinishedCallback, +) error { start := time.Now() - if err = s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { + + // Check if the stream is closing or closed, in which case no more requests should be put in. 
+ // Note that the prevState should be restored very soon, as the receiver may spin on this state while closing. + prevState := s.state.Swap(streamStateSending) + switch prevState { + case streamStateIdle: + // Expected case + break + case streamStateClosing: + s.state.Store(prevState) + log.Info("tsoStream closed") + return errs.ErrClientTSOStreamClosed + case streamStateSending: + log.Fatal("unexpected concurrent sending on tsoStream", zap.String("stream", s.streamID)) + default: + log.Fatal("unknown tsoStream state", zap.String("stream", s.streamID), zap.Int32("state", prevState)) + } + + select { + case s.pendingRequests <- batchedRequests{ + startTime: start, + count: count, + reqKeyspaceGroupID: keyspaceGroupID, + callback: callback, + }: + default: + s.state.Store(prevState) + return errors.New("unexpected channel full") + } + s.state.Store(prevState) + + if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { if err == io.EOF { - err = errs.ErrClientTSOStreamClosed - } else { - err = errors.WithStack(err) + return errs.ErrClientTSOStreamClosed } - return + return errors.WithStack(err) } tsoBatchSendLatency.Observe(time.Since(batchStartTime).Seconds()) - res, err := s.stream.Recv() - duration := time.Since(start).Seconds() - if err != nil { - requestFailedDurationTSO.Observe(duration) - if err == io.EOF { - err = errs.ErrClientTSOStreamClosed - } else { - err = errors.WithStack(err) + s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(1))) + return nil +} + +func (s *tsoStream) recvLoop(ctx context.Context) { + var finishWithErr error + + defer func() { + s.cancel() + for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) { + switch state := s.state.Load(); state { + case streamStateIdle, streamStateSending: + // streamStateSending should switch to streamStateIdle very quickly. Spin until successfully setting to + // streamStateClosing. 
+ continue + case streamStateClosing: + log.Warn("unexpected double closing of tsoStream", zap.String("stream", s.streamID)) + default: + log.Fatal("unknown tsoStream state", zap.String("stream", s.streamID), zap.Int32("state", state)) + } } - return - } - requestDurationTSO.Observe(duration) - tsoBatchSize.Observe(float64(count)) - if res.count != uint32(count) { - err = errors.WithStack(errTSOLength) - return + // The loop must end with an error (including context.Canceled). + if finishWithErr == nil { + log.Fatal("tsoStream recvLoop ended without error", zap.String("stream", s.streamID)) + } + log.Info("tsoStream.recvLoop ended", zap.String("stream", s.streamID), zap.Error(finishWithErr)) + + close(s.pendingRequests) + for req := range s.pendingRequests { + req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, s.serverURL, finishWithErr) + } + + s.wg.Done() + s.ongoingRequests.Store(0) + s.ongoingRequestCountGauge.Set(0) + }() + +recvLoop: + for { + select { + case <-ctx.Done(): + finishWithErr = context.Canceled + break recvLoop + default: + } + + res, err := s.stream.Recv() + + // Load the corresponding batchedRequests + var req batchedRequests + select { + case req = <-s.pendingRequests: + default: + finishWithErr = errors.New("tsoStream timing order broken") + break + } + + durationSeconds := time.Since(req.startTime).Seconds() + + if err != nil { + requestFailedDurationTSO.Observe(durationSeconds) + if err == io.EOF { + finishWithErr = errs.ErrClientTSOStreamClosed + } else { + finishWithErr = errors.WithStack(err) + } + break + } + + latencySeconds := durationSeconds + requestDurationTSO.Observe(latencySeconds) + tsoBatchSize.Observe(float64(res.count)) + + if res.count != uint32(req.count) { + finishWithErr = errors.WithStack(errTSOLength) + break + } + + req.callback(res, req.reqKeyspaceGroupID, s.serverURL, nil) + s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(-1))) } +} - respKeyspaceGroupID = res.respKeyspaceGroupID - physical, 
logical, suffixBits = res.physical, res.logical, res.suffixBits - return +func (s *tsoStream) Close() { + s.cancel() + s.wg.Wait() } From 88d9f5b0fe8d1777a69ef4fb0961b6d000c868b0 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 7 Aug 2024 17:34:29 +0800 Subject: [PATCH 02/34] Add basic tests for rewritten tsoStream Signed-off-by: MyonKeminta --- client/tso_stream.go | 61 ++++++--- client/tso_stream_test.go | 263 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 307 insertions(+), 17 deletions(-) create mode 100644 client/tso_stream_test.go diff --git a/client/tso_stream.go b/client/tso_stream.go index 74fbc620cef..4e18b5e31d7 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -302,8 +302,25 @@ func (s *tsoStream) processRequests( func (s *tsoStream) recvLoop(ctx context.Context) { var finishWithErr error + var currentReq batchedRequests + var hasReq bool defer func() { + if r := recover(); r != nil { + log.Fatal("tsoStream.recvLoop internal panic", zap.Stack("stacktrace"), zap.Any("panicMessage", r)) + } + + if finishWithErr == nil { + // The loop must exit with a non-nil error (including io.EOF and context.Canceled). This should be + // unreachable code. + log.Fatal("tsoStream.recvLoop exited without error info") + } + + if hasReq { + // There's an unfinished request, cancel it, otherwise it will be blocked forever. + currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, s.serverURL, finishWithErr) + } + s.cancel() for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) { switch state := s.state.Load(); state { @@ -318,13 +335,11 @@ func (s *tsoStream) recvLoop(ctx context.Context) { } } - // The loop must end with an error (including context.Canceled). 
- if finishWithErr == nil { - log.Fatal("tsoStream recvLoop ended without error", zap.String("stream", s.streamID)) - } log.Info("tsoStream.recvLoop ended", zap.String("stream", s.streamID), zap.Error(finishWithErr)) close(s.pendingRequests) + + // Cancel remaining pending requests. for req := range s.pendingRequests { req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, s.serverURL, finishWithErr) } @@ -345,42 +360,54 @@ recvLoop: res, err := s.stream.Recv() - // Load the corresponding batchedRequests - var req batchedRequests + // Try to load the corresponding `batchedRequests`. If `Recv` is successful, there must be a request pending + // in the queue. select { - case req = <-s.pendingRequests: + case currentReq = <-s.pendingRequests: + hasReq = true default: - finishWithErr = errors.New("tsoStream timing order broken") - break + hasReq = false } - durationSeconds := time.Since(req.startTime).Seconds() + durationSeconds := time.Since(currentReq.startTime).Seconds() if err != nil { - requestFailedDurationTSO.Observe(durationSeconds) + // If a request is pending and error occurs, observe the duration it has cost. + // Note that it's also possible that the stream is broken due to network without being requested. In this + // case, `Recv` may return an error while no request is pending. 
+ if hasReq { + requestFailedDurationTSO.Observe(durationSeconds) + } if err == io.EOF { finishWithErr = errs.ErrClientTSOStreamClosed } else { finishWithErr = errors.WithStack(err) } - break + break recvLoop + } else if !hasReq { + finishWithErr = errors.New("tsoStream timing order broken") + break recvLoop } latencySeconds := durationSeconds requestDurationTSO.Observe(latencySeconds) tsoBatchSize.Observe(float64(res.count)) - if res.count != uint32(req.count) { + if res.count != uint32(currentReq.count) { finishWithErr = errors.WithStack(errTSOLength) - break + break recvLoop } - req.callback(res, req.reqKeyspaceGroupID, s.serverURL, nil) + currentReq.callback(res, currentReq.reqKeyspaceGroupID, s.serverURL, nil) + // After finishing the requests, unset these variables which will be checked in the defer block. + currentReq = batchedRequests{} + hasReq = false + s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(-1))) } } -func (s *tsoStream) Close() { - s.cancel() +// WaitForClosed blocks until the stream is closed and the inner loop exits. +func (s *tsoStream) WaitForClosed() { s.wg.Wait() } diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go new file mode 100644 index 00000000000..47af677ca1f --- /dev/null +++ b/client/tso_stream_test.go @@ -0,0 +1,263 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package pd + +import ( + "context" + "io" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/stretchr/testify/suite" + "github.com/tikv/pd/client/errs" +) + +type resultMsg struct { + r tsoRequestResult + err error + breakStream bool +} + +type mockTSOStreamImpl struct { + requestCh chan struct{} + resultCh chan resultMsg + keyspaceID uint32 +} + +func newMockTSOStreamImpl() *mockTSOStreamImpl { + return &mockTSOStreamImpl{ + requestCh: make(chan struct{}, 64), + resultCh: make(chan resultMsg, 64), + keyspaceID: 0, + } +} + +func (s *mockTSOStreamImpl) Send(clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64) error { + s.requestCh <- struct{}{} + return nil +} + +func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { + res := <-s.resultCh + if !res.breakStream { + <-s.requestCh + } + return res.r, res.err +} + +func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { + s.resultCh <- resultMsg{ + r: tsoRequestResult{ + physical: physical, + logical: logical, + count: count, + suffixBits: 0, + respKeyspaceGroupID: s.keyspaceID, + }, + } +} + +func (s *mockTSOStreamImpl) returnError(err error) { + s.resultCh <- resultMsg{ + err: err, + } +} + +func (s *mockTSOStreamImpl) breakStream(err error) { + s.resultCh <- resultMsg{ + err: err, + breakStream: true, + } +} + +func (s *mockTSOStreamImpl) stop() { + s.breakStream(io.EOF) +} + +type callbackInvocation struct { + result tsoRequestResult + streamURL string + err error +} + +type testTSOStreamSuite struct { + suite.Suite + + inner *mockTSOStreamImpl + stream *tsoStream +} + +func (s *testTSOStreamSuite) SetupTest() { + s.inner = newMockTSOStreamImpl() + s.stream = newTSOStream("mock:///", s.inner) +} + +func (s *testTSOStreamSuite) TearDownTest() { + s.inner.stop() + s.stream.WaitForClosed() + s.inner = nil + s.stream = nil +} + +func TestTSOStreamTestSuite(t *testing.T) { + suite.Run(t, new(testTSOStreamSuite)) +} + +func (s 
*testTSOStreamSuite) noResult(ch <-chan callbackInvocation) { + select { + case res := <-ch: + s.FailNowf("result received unexpectedly", "received result: %+v", res) + case <-time.After(time.Millisecond * 20): + } +} + +func (s *testTSOStreamSuite) getResult(ch <-chan callbackInvocation) callbackInvocation { + select { + case res := <-ch: + return res + case <-time.After(time.Second * 10000): + s.FailNow("result not ready in time") + panic("result not ready in time") + } +} + +func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) <-chan callbackInvocation { + ch := make(chan callbackInvocation, 1) + err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + if err == nil { + s.Equal(uint32(3), reqKeyspaceGroupID) + s.Equal(uint32(0), result.suffixBits) + } + ch <- callbackInvocation{ + result: result, + streamURL: streamURL, + err: err, + } + }) + s.NoError(err) + return ch +} + +func (s *testTSOStreamSuite) TestTSOStreamBasic() { + ch := s.processRequestWithResultCh(1) + s.noResult(ch) + s.inner.returnResult(10, 1, 1) + res := s.getResult(ch) + + s.NoError(res.err) + s.Equal("mock:///", res.streamURL) + s.Equal(int64(10), res.result.physical) + s.Equal(int64(1), res.result.logical) + s.Equal(uint32(1), res.result.count) + + ch = s.processRequestWithResultCh(2) + s.noResult(ch) + s.inner.returnResult(20, 3, 2) + res = s.getResult(ch) + + s.NoError(res.err) + s.Equal("mock:///", res.streamURL) + s.Equal(int64(20), res.result.physical) + s.Equal(int64(3), res.result.logical) + s.Equal(uint32(2), res.result.count) + + ch = s.processRequestWithResultCh(3) + s.noResult(ch) + s.inner.returnError(errors.New("mock rpc error")) + res = s.getResult(ch) + s.Error(res.err) + s.Equal("mock rpc error", res.err.Error()) + + // After an error from the (simulated) RPC stream, the tsoStream should be in a broken status and can't accept + // new request 
anymore. + err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + panic("unreachable") + }) + s.Error(err) +} + +func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests int) { + var resultCh []<-chan callbackInvocation + + for i := 0; i < pendingRequests; i++ { + ch := s.processRequestWithResultCh(1) + resultCh = append(resultCh, ch) + s.noResult(ch) + } + + s.inner.breakStream(err) + closedCh := make(chan struct{}) + go func() { + s.stream.WaitForClosed() + closedCh <- struct{}{} + }() + select { + case <-closedCh: + case <-time.After(time.Second): + s.FailNow("stream receiver loop didn't exit") + } + + for _, ch := range resultCh { + res := s.getResult(ch) + s.Error(res.err) + if err == io.EOF { + s.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + } else { + s.ErrorIs(res.err, err) + } + } +} + +func (s *testTSOStreamSuite) TestTSOStreamBrokenWithEOFNoPendingReq() { + s.testTSOStreamBrokenImpl(io.EOF, 0) +} + +func (s *testTSOStreamSuite) TestTSOStreamCanceledNoPendingReq() { + s.testTSOStreamBrokenImpl(context.Canceled, 0) +} + +func (s *testTSOStreamSuite) TestTSOStreamBrokenWithEOFWithPendingReq() { + s.testTSOStreamBrokenImpl(io.EOF, 5) +} + +func (s *testTSOStreamSuite) TestTSOStreamCanceledWithPendingReq() { + s.testTSOStreamBrokenImpl(context.Canceled, 5) +} + +func (s *testTSOStreamSuite) TestTSOStreamFIFO() { + var resultChs []<-chan callbackInvocation + const COUNT = 5 + for i := 0; i < COUNT; i++ { + ch := s.processRequestWithResultCh(int64(i + 1)) + resultChs = append(resultChs, ch) + } + + for _, ch := range resultChs { + s.noResult(ch) + } + + for i := 0; i < COUNT; i++ { + s.inner.returnResult(int64((i+1)*10), int64(i), uint32(i+1)) + } + + for i, ch := range resultChs { + res := s.getResult(ch) + s.NoError(res.err) + s.Equal(int64((i+1)*10), res.result.physical) + s.Equal(int64(i), res.result.logical) + 
s.Equal(uint32(i+1), res.result.count) + } +} From 623d896f6b8e99dd07cf014e1267e0bf090d3d4a Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 8 Aug 2024 16:35:16 +0800 Subject: [PATCH 03/34] Add a concurrency test Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 89 +++++++++++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 9 deletions(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 47af677ca1f..46b308d1f2c 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -35,6 +35,7 @@ type mockTSOStreamImpl struct { requestCh chan struct{} resultCh chan resultMsg keyspaceID uint32 + errorState error } func newMockTSOStreamImpl() *mockTSOStreamImpl { @@ -51,10 +52,17 @@ func (s *mockTSOStreamImpl) Send(clusterID uint64, keyspaceID, keyspaceGroupID u } func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { + // This stream have ever receive an error, it returns the error forever. + if s.errorState != nil { + return tsoRequestResult{}, s.errorState + } res := <-s.resultCh if !res.breakStream { <-s.requestCh } + if res.err != nil { + s.errorState = res.err + } return res.r, res.err } @@ -134,7 +142,7 @@ func (s *testTSOStreamSuite) getResult(ch <-chan callbackInvocation) callbackInv } } -func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) <-chan callbackInvocation { +func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) (<-chan callbackInvocation, error) { ch := make(chan callbackInvocation, 1) err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { if err == nil { @@ -147,12 +155,20 @@ func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) <-chan call err: err, } }) + if err != nil { + return nil, err + } + return ch, nil +} + +func (s *testTSOStreamSuite) mustProcessRequestWithResultCh(count int64) <-chan callbackInvocation { + ch, err := 
s.processRequestWithResultCh(count) s.NoError(err) return ch } func (s *testTSOStreamSuite) TestTSOStreamBasic() { - ch := s.processRequestWithResultCh(1) + ch := s.mustProcessRequestWithResultCh(1) s.noResult(ch) s.inner.returnResult(10, 1, 1) res := s.getResult(ch) @@ -163,7 +179,7 @@ func (s *testTSOStreamSuite) TestTSOStreamBasic() { s.Equal(int64(1), res.result.logical) s.Equal(uint32(1), res.result.count) - ch = s.processRequestWithResultCh(2) + ch = s.mustProcessRequestWithResultCh(2) s.noResult(ch) s.inner.returnResult(20, 3, 2) res = s.getResult(ch) @@ -174,7 +190,7 @@ func (s *testTSOStreamSuite) TestTSOStreamBasic() { s.Equal(int64(3), res.result.logical) s.Equal(uint32(2), res.result.count) - ch = s.processRequestWithResultCh(3) + ch = s.mustProcessRequestWithResultCh(3) s.noResult(ch) s.inner.returnError(errors.New("mock rpc error")) res = s.getResult(ch) @@ -193,7 +209,7 @@ func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests var resultCh []<-chan callbackInvocation for i := 0; i < pendingRequests; i++ { - ch := s.processRequestWithResultCh(1) + ch := s.mustProcessRequestWithResultCh(1) resultCh = append(resultCh, ch) s.noResult(ch) } @@ -239,9 +255,9 @@ func (s *testTSOStreamSuite) TestTSOStreamCanceledWithPendingReq() { func (s *testTSOStreamSuite) TestTSOStreamFIFO() { var resultChs []<-chan callbackInvocation - const COUNT = 5 - for i := 0; i < COUNT; i++ { - ch := s.processRequestWithResultCh(int64(i + 1)) + const count = 5 + for i := 0; i < count; i++ { + ch := s.mustProcessRequestWithResultCh(int64(i + 1)) resultChs = append(resultChs, ch) } @@ -249,7 +265,7 @@ func (s *testTSOStreamSuite) TestTSOStreamFIFO() { s.noResult(ch) } - for i := 0; i < COUNT; i++ { + for i := 0; i < count; i++ { s.inner.returnResult(int64((i+1)*10), int64(i), uint32(i+1)) } @@ -261,3 +277,58 @@ func (s *testTSOStreamSuite) TestTSOStreamFIFO() { s.Equal(uint32(i+1), res.result.count) } } + +func (s *testTSOStreamSuite) 
TestTSOStreamConcurrentRunning() { + resultChCh := make(chan (<-chan callbackInvocation), 10000) + const totalCount = 10000 + + // Continuously start requests + go func() { + for i := 1; i <= totalCount; i++ { + // Retry loop + for { + ch, err := s.processRequestWithResultCh(int64(i)) + if err != nil { + // If the capacity of the request queue is exhausted, it returns this error. As a test, we simply + // spin and retry it until it has enough space, as a coverage of the almost-full case. But note that + // this should not happen in production, in which case the caller of tsoStream should have its own + // limit of concurrent RPC requests. + s.Contains(err.Error(), "unexpected channel full") + continue + } + + resultChCh <- ch + break + } + } + }() + + // Continuously send results + go func() { + for i := int64(1); i <= totalCount; i++ { + s.inner.returnResult(i*10, i%(1<<18), uint32(i)) + } + s.inner.breakStream(io.EOF) + }() + + // Check results + for i := int64(1); i <= totalCount; i++ { + ch := <-resultChCh + res := s.getResult(ch) + s.NoError(res.err) + s.Equal(i*10, res.result.physical) + s.Equal(i%(1<<18), res.result.logical) + s.Equal(uint32(i), res.result.count) + } + + // After handling all these requests, the stream is ended by an EOF error. The next request won't succeed. + // So, either the `processRequests` function returns an error or the callback is called with an error. 
+ ch, err := s.processRequestWithResultCh(1) + if err != nil { + s.ErrorIs(err, errs.ErrClientTSOStreamClosed) + } else { + res := s.getResult(ch) + s.Error(res.err) + s.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + } +} From 86d0f42d2fe76542f8c2fefe3efd582f87fd70dd Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 8 Aug 2024 17:22:48 +0800 Subject: [PATCH 04/34] Fix some of CI failures Signed-off-by: MyonKeminta --- client/tso_batch_controller.go | 18 ++++---- client/tso_dispatcher.go | 5 +- client/tso_stream.go | 11 ++--- client/tso_stream_test.go | 83 +++++++++++++++++----------------- 4 files changed, 56 insertions(+), 61 deletions(-) diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index bf072b24535..457ca7c1078 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -19,10 +19,7 @@ import ( "runtime/trace" "time" - "github.com/pingcap/errors" - "github.com/pingcap/log" "github.com/tikv/pd/client/tsoutil" - "go.uber.org/zap" ) type tsoBatchController struct { @@ -186,10 +183,11 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in tbc.collectedRequestCount = 0 } -func (tbc *tsoBatchController) clear() { - log.Info("[pd] clear the tso batch controller", - zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize), - zap.Int("collected-request-count", tbc.collectedRequestCount)) - tsoErr := errors.WithStack(errClosing) - tbc.finishCollectedRequests(0, 0, 0, tsoErr) -} +// +//func (tbc *tsoBatchController) clear() { +// log.Info("[pd] clear the tso batch controller", +// zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize), +// zap.Int("collected-request-count", tbc.collectedRequestCount)) +// tsoErr := errors.WithStack(errClosing) +// tbc.finishCollectedRequests(0, 0, 0, tsoErr) +//} diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 2e670c656c9..c382bb3c699 100644 --- 
a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -197,7 +197,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { cc.(*tsoConnectionContext).cancel() return true }) - if batchController.collectedRequestCount != 0 { + if batchController != nil && batchController.collectedRequestCount != 0 { log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop") } tsoErr := errors.WithStack(errClosing) @@ -225,6 +225,7 @@ tsoBatchLoop: return default: } + batchController = td.batchBufferPool.Get().(*tsoBatchController) // Start to collect the TSO requests. maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval() // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, @@ -454,7 +455,7 @@ func (td *tsoDispatcher) processRequests( reqKeyspaceGroupID = svcDiscovery.GetKeyspaceGroupID() ) - cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { defer td.batchBufferPool.Put(tbc) if err != nil { td.cancelCollectedRequests(tbc, err) diff --git a/client/tso_stream.go b/client/tso_stream.go index 4e18b5e31d7..c58cf3eb841 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -178,7 +178,7 @@ func (s tsoTSOStreamAdapter) Recv() (tsoRequestResult, error) { }, nil } -type onFinishedCallback func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) +type onFinishedCallback func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) type batchedRequests struct { startTime time.Time @@ -203,8 +203,6 @@ type tsoStream struct { pendingRequests chan batchedRequests - estimateLatencyMicros atomic.Uint64 - cancel context.CancelFunc wg sync.WaitGroup @@ -265,7 +263,6 @@ func (s *tsoStream) processRequests( switch prevState { case streamStateIdle: // Expected case - break case streamStateClosing: s.state.Store(prevState) log.Info("tsoStream 
closed") @@ -318,7 +315,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { if hasReq { // There's an unfinished request, cancel it, otherwise it will be blocked forever. - currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, s.serverURL, finishWithErr) + currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, finishWithErr) } s.cancel() @@ -341,7 +338,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { // Cancel remaining pending requests. for req := range s.pendingRequests { - req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, s.serverURL, finishWithErr) + req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, finishWithErr) } s.wg.Done() @@ -398,7 +395,7 @@ recvLoop: break recvLoop } - currentReq.callback(res, currentReq.reqKeyspaceGroupID, s.serverURL, nil) + currentReq.callback(res, currentReq.reqKeyspaceGroupID, nil) // After finishing the requests, unset these variables which will be checked in the defer block. currentReq = batchedRequests{} hasReq = false diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 46b308d1f2c..8bcb0c29a4b 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/client/errs" ) @@ -46,7 +47,7 @@ func newMockTSOStreamImpl() *mockTSOStreamImpl { } } -func (s *mockTSOStreamImpl) Send(clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64) error { +func (s *mockTSOStreamImpl) Send(_clusterID uint64, _keyspaceID, _keyspaceGroupID uint32, _dcLocation string, _count int64) error { s.requestCh <- struct{}{} return nil } @@ -96,19 +97,20 @@ func (s *mockTSOStreamImpl) stop() { } type callbackInvocation struct { - result tsoRequestResult - streamURL string - err error + result tsoRequestResult + err error } type testTSOStreamSuite struct { suite.Suite + re 
*require.Assertions inner *mockTSOStreamImpl stream *tsoStream } func (s *testTSOStreamSuite) SetupTest() { + s.re = require.New(s.T()) s.inner = newMockTSOStreamImpl() s.stream = newTSOStream("mock:///", s.inner) } @@ -127,7 +129,7 @@ func TestTSOStreamTestSuite(t *testing.T) { func (s *testTSOStreamSuite) noResult(ch <-chan callbackInvocation) { select { case res := <-ch: - s.FailNowf("result received unexpectedly", "received result: %+v", res) + s.re.FailNowf("result received unexpectedly", "received result: %+v", res) case <-time.After(time.Millisecond * 20): } } @@ -137,22 +139,21 @@ func (s *testTSOStreamSuite) getResult(ch <-chan callbackInvocation) callbackInv case res := <-ch: return res case <-time.After(time.Second * 10000): - s.FailNow("result not ready in time") + s.re.FailNow("result not ready in time") panic("result not ready in time") } } func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) (<-chan callbackInvocation, error) { ch := make(chan callbackInvocation, 1) - err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { if err == nil { - s.Equal(uint32(3), reqKeyspaceGroupID) - s.Equal(uint32(0), result.suffixBits) + s.re.Equal(uint32(3), reqKeyspaceGroupID) + s.re.Equal(uint32(0), result.suffixBits) } ch <- callbackInvocation{ - result: result, - streamURL: streamURL, - err: err, + result: result, + err: err, } }) if err != nil { @@ -163,7 +164,7 @@ func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) (<-chan cal func (s *testTSOStreamSuite) mustProcessRequestWithResultCh(count int64) <-chan callbackInvocation { ch, err := s.processRequestWithResultCh(count) - s.NoError(err) + s.re.NoError(err) return ch } @@ -173,36 +174,34 @@ func (s *testTSOStreamSuite) 
TestTSOStreamBasic() { s.inner.returnResult(10, 1, 1) res := s.getResult(ch) - s.NoError(res.err) - s.Equal("mock:///", res.streamURL) - s.Equal(int64(10), res.result.physical) - s.Equal(int64(1), res.result.logical) - s.Equal(uint32(1), res.result.count) + s.re.NoError(res.err) + s.re.Equal(int64(10), res.result.physical) + s.re.Equal(int64(1), res.result.logical) + s.re.Equal(uint32(1), res.result.count) ch = s.mustProcessRequestWithResultCh(2) s.noResult(ch) s.inner.returnResult(20, 3, 2) res = s.getResult(ch) - s.NoError(res.err) - s.Equal("mock:///", res.streamURL) - s.Equal(int64(20), res.result.physical) - s.Equal(int64(3), res.result.logical) - s.Equal(uint32(2), res.result.count) + s.re.NoError(res.err) + s.re.Equal(int64(20), res.result.physical) + s.re.Equal(int64(3), res.result.logical) + s.re.Equal(uint32(2), res.result.count) ch = s.mustProcessRequestWithResultCh(3) s.noResult(ch) s.inner.returnError(errors.New("mock rpc error")) res = s.getResult(ch) - s.Error(res.err) - s.Equal("mock rpc error", res.err.Error()) + s.re.Error(res.err) + s.re.Equal("mock rpc error", res.err.Error()) // After an error from the (simulated) RPC stream, the tsoStream should be in a broken status and can't accept // new request anymore. 
- err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, streamURL string, err error) { + err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(_result tsoRequestResult, _reqKeyspaceGroupID uint32, _err error) { panic("unreachable") }) - s.Error(err) + s.re.Error(err) } func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests int) { @@ -223,16 +222,16 @@ func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests select { case <-closedCh: case <-time.After(time.Second): - s.FailNow("stream receiver loop didn't exit") + s.re.FailNow("stream receiver loop didn't exit") } for _, ch := range resultCh { res := s.getResult(ch) - s.Error(res.err) + s.re.Error(res.err) if err == io.EOF { - s.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + s.re.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) } else { - s.ErrorIs(res.err, err) + s.re.ErrorIs(res.err, err) } } } @@ -271,10 +270,10 @@ func (s *testTSOStreamSuite) TestTSOStreamFIFO() { for i, ch := range resultChs { res := s.getResult(ch) - s.NoError(res.err) - s.Equal(int64((i+1)*10), res.result.physical) - s.Equal(int64(i), res.result.logical) - s.Equal(uint32(i+1), res.result.count) + s.re.NoError(res.err) + s.re.Equal(int64((i+1)*10), res.result.physical) + s.re.Equal(int64(i), res.result.logical) + s.re.Equal(uint32(i+1), res.result.count) } } @@ -315,20 +314,20 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { for i := int64(1); i <= totalCount; i++ { ch := <-resultChCh res := s.getResult(ch) - s.NoError(res.err) - s.Equal(i*10, res.result.physical) - s.Equal(i%(1<<18), res.result.logical) - s.Equal(uint32(i), res.result.count) + s.re.NoError(res.err) + s.re.Equal(i*10, res.result.physical) + s.re.Equal(i%(1<<18), res.result.logical) + s.re.Equal(uint32(i), res.result.count) } // After handling all these requests, the stream is ended by an EOF error. 
The next request won't succeed. // So, either the `processRequests` function returns an error or the callback is called with an error. ch, err := s.processRequestWithResultCh(1) if err != nil { - s.ErrorIs(err, errs.ErrClientTSOStreamClosed) + s.re.ErrorIs(err, errs.ErrClientTSOStreamClosed) } else { res := s.getResult(ch) - s.Error(res.err) - s.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + s.re.Error(res.err) + s.re.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) } } From ff0f11a6734fa654d77303e47874b3a893a5aca4 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Fri, 30 Aug 2024 17:21:56 +0800 Subject: [PATCH 05/34] Fix integration tests Signed-off-by: MyonKeminta --- client/tso_stream.go | 18 ++++++++++++------ .../mcs/tso/keyspace_group_manager_test.go | 7 +++++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index c58cf3eb841..5d790e7b843 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -207,7 +207,8 @@ type tsoStream struct { wg sync.WaitGroup // For syncing between sender and receiver to guarantee all requests are finished when closing. 
- state atomic.Int32 + state atomic.Int32 + stoppedWithErr error ongoingRequestCountGauge prometheus.Gauge ongoingRequests atomic.Int32 @@ -265,8 +266,12 @@ func (s *tsoStream) processRequests( // Expected case case streamStateClosing: s.state.Store(prevState) - log.Info("tsoStream closed") - return errs.ErrClientTSOStreamClosed + err := s.stoppedWithErr + log.Info("sending to closed tsoStream", zap.Error(err)) + if err == nil { + err = errors.WithStack(errs.ErrClientTSOStreamClosed) + } + return err case streamStateSending: log.Fatal("unexpected concurrent sending on tsoStream", zap.String("stream", s.streamID)) default: @@ -288,7 +293,7 @@ func (s *tsoStream) processRequests( if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { if err == io.EOF { - return errs.ErrClientTSOStreamClosed + return errors.WithStack(errs.ErrClientTSOStreamClosed) } return errors.WithStack(err) } @@ -318,6 +323,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, finishWithErr) } + s.stoppedWithErr = errors.WithStack(finishWithErr) s.cancel() for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) { switch state := s.state.Load(); state { @@ -338,7 +344,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { // Cancel remaining pending requests. 
for req := range s.pendingRequests { - req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, finishWithErr) + req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, errors.WithStack(finishWithErr)) } s.wg.Done() @@ -376,7 +382,7 @@ recvLoop: requestFailedDurationTSO.Observe(durationSeconds) } if err == io.EOF { - finishWithErr = errs.ErrClientTSOStreamClosed + finishWithErr = errors.WithStack(errs.ErrClientTSOStreamClosed) } else { finishWithErr = errors.WithStack(err) } diff --git a/tests/integrations/mcs/tso/keyspace_group_manager_test.go b/tests/integrations/mcs/tso/keyspace_group_manager_test.go index 2acbbcc3b42..eae586b5b20 100644 --- a/tests/integrations/mcs/tso/keyspace_group_manager_test.go +++ b/tests/integrations/mcs/tso/keyspace_group_manager_test.go @@ -16,6 +16,8 @@ package tso import ( "context" + "errors" + "fmt" "math/rand" "net/http" "strings" @@ -472,10 +474,11 @@ func (suite *tsoKeyspaceGroupManagerTestSuite) dispatchClient( strings.Contains(errMsg, clierrs.NotLeaderErr) || strings.Contains(errMsg, clierrs.NotServedErr) || strings.Contains(errMsg, "ErrKeyspaceNotAssigned") || - strings.Contains(errMsg, "ErrKeyspaceGroupIsMerging") { + strings.Contains(errMsg, "ErrKeyspaceGroupIsMerging") || + errors.Is(err, clierrs.ErrClientTSOStreamClosed) { continue } - re.FailNow(errMsg) + re.FailNow(fmt.Sprintf("%+v", err)) } if physical == lastPhysical { re.Greater(logical, lastLogical) From ef47b9d9e50c1a7dbdc7a6c2055833d925bc46f8 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 2 Sep 2024 16:38:29 +0800 Subject: [PATCH 06/34] Fix tso deadline Signed-off-by: MyonKeminta --- client/tso_dispatcher.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index c382bb3c699..5fbf47f927b 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -225,7 +225,12 @@ tsoBatchLoop: return default: } - batchController = 
td.batchBufferPool.Get().(*tsoBatchController) + + // In case error happens, the loop may continue without resetting `batchController` for retrying. + if batchController == nil { + batchController = td.batchBufferPool.Get().(*tsoBatchController) + } + // Start to collect the TSO requests. maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval() // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, @@ -309,8 +314,7 @@ tsoBatchLoop: case td.tsDeadlineCh <- dl: } // processRequests guarantees that the collected requests could be finished properly. - err = td.processRequests(stream, dc, batchController) - close(done) + err = td.processRequests(stream, dc, batchController, done) // If error happens during tso stream handling, reset stream and run the next trial. if err == nil { // If the request is started successfully, the `batchController` will be put back to the pool when the @@ -425,8 +429,9 @@ func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext } func (td *tsoDispatcher) processRequests( - stream *tsoStream, dcLocation string, tbc *tsoBatchController, + stream *tsoStream, dcLocation string, tbc *tsoBatchController, done chan struct{}, ) error { + // `done` must be guaranteed to be eventually called. 
var ( requests = tbc.getCollectedRequests() traceRegions = make([]*trace.Region, 0, len(requests)) @@ -456,6 +461,8 @@ func (td *tsoDispatcher) processRequests( ) cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + close(done) + defer td.batchBufferPool.Put(tbc) if err != nil { td.cancelCollectedRequests(tbc, err) @@ -480,6 +487,8 @@ func (td *tsoDispatcher) processRequests( clusterID, keyspaceID, reqKeyspaceGroupID, dcLocation, count, tbc.extraBatchingStartTime, cb) if err != nil { + close(done) + td.cancelCollectedRequests(tbc, err) return err } From ff1d50d0377dcc6bc34d08d1bffa0dc3a3897072 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 2 Sep 2024 18:01:45 +0800 Subject: [PATCH 07/34] Fix stream broken handling Signed-off-by: MyonKeminta --- client/tso_dispatcher.go | 82 ++++++++++++++++++++++++++-------------- client/tso_stream.go | 14 +++++-- 2 files changed, 65 insertions(+), 31 deletions(-) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 5fbf47f927b..9b8e66b553c 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -198,7 +198,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { return true }) if batchController != nil && batchController.collectedRequestCount != 0 { - log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop") + log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop", zap.Any("panic", recover())) } tsoErr := errors.WithStack(errClosing) td.revokePendingRequests(tsoErr) @@ -301,8 +301,20 @@ tsoBatchLoop: stream = nil continue default: - break streamChoosingLoop } + + // Check if any error has occurred on this stream when receiving asynchronously. 
+ if err = stream.GetRecvError(); err != nil { + exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + stream = nil + if exit { + td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) + return + } + continue + } + + break streamChoosingLoop } done := make(chan struct{}) dl := newTSDeadline(option.timeout, done, cancel) @@ -324,38 +336,52 @@ tsoBatchLoop: // reused in the next loop safely. batchController = nil } else { + exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + stream = nil + if exit { + return + } + } + } +} + +// handleProcessRequestError handles errors occurs when trying to process a TSO RPC request for the dispatcher loop. +// Returns true if the dispatcher loop is ok to continue. Otherwise, the dispatcher loop should be exited. +func (td *tsoDispatcher) handleProcessRequestError(ctx context.Context, bo *retry.Backoffer, streamURL string, streamCancelFunc context.CancelFunc, err error) bool { + select { + case <-ctx.Done(): + return false + default: + } + + svcDiscovery := td.provider.getServiceDiscovery() + + svcDiscovery.ScheduleCheckMemberChanged() + log.Error("[tso] getTS error after processing requests", + zap.String("dc-location", td.dc), + zap.String("stream-url", streamURL), + zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) + // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. + td.connectionCtxs.Delete(streamURL) + streamCancelFunc() + // Because ScheduleCheckMemberChanged is asynchronous, if the leader changes, we better call `updateMember` ASAP. 
+ if errs.IsLeaderChange(err) { + if err := bo.Exec(ctx, svcDiscovery.CheckMemberChanged); err != nil { select { case <-ctx.Done(): - return + return false default: } - svcDiscovery.ScheduleCheckMemberChanged() - log.Error("[tso] getTS error after processing requests", - zap.String("dc-location", dc), - zap.String("stream-url", streamURL), - zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) - // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. - connectionCtxs.Delete(streamURL) - cancel() - stream = nil - // Because ScheduleCheckMemberChanged is asynchronous, if the leader changes, we better call `updateMember` ASAP. - if errs.IsLeaderChange(err) { - if err := bo.Exec(ctx, svcDiscovery.CheckMemberChanged); err != nil { - select { - case <-ctx.Done(): - return - default: - } - } - // Because the TSO Follower Proxy could be configured online, - // If we change it from on -> off, background updateConnectionCtxs - // will cancel the current stream, then the EOF error caused by cancel() - // should not trigger the updateConnectionCtxs here. - // So we should only call it when the leader changes. - provider.updateConnectionCtxs(ctx, dc, connectionCtxs) - } } + // Because the TSO Follower Proxy could be configured online, + // If we change it from on -> off, background updateConnectionCtxs + // will cancel the current stream, then the EOF error caused by cancel() + // should not trigger the updateConnectionCtxs here. + // So we should only call it when the leader changes. + td.provider.updateConnectionCtxs(ctx, td.dc, td.connectionCtxs) } + + return true } // updateConnectionCtxs updates the `connectionCtxs` for the specified DC location regularly. 
diff --git a/client/tso_stream.go b/client/tso_stream.go index 5d790e7b843..7394d8fe4ab 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -208,7 +208,7 @@ type tsoStream struct { // For syncing between sender and receiver to guarantee all requests are finished when closing. state atomic.Int32 - stoppedWithErr error + stoppedWithErr atomic.Pointer[error] ongoingRequestCountGauge prometheus.Gauge ongoingRequests atomic.Int32 @@ -266,7 +266,7 @@ func (s *tsoStream) processRequests( // Expected case case streamStateClosing: s.state.Store(prevState) - err := s.stoppedWithErr + err := s.GetRecvError() log.Info("sending to closed tsoStream", zap.Error(err)) if err == nil { err = errors.WithStack(errs.ErrClientTSOStreamClosed) @@ -323,7 +323,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, finishWithErr) } - s.stoppedWithErr = errors.WithStack(finishWithErr) + s.stoppedWithErr.Store(&finishWithErr) s.cancel() for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) { switch state := s.state.Load(); state { @@ -410,6 +410,14 @@ recvLoop: } } +func (s *tsoStream) GetRecvError() error { + perr := s.stoppedWithErr.Load() + if perr == nil { + return nil + } + return *perr +} + // WaitForClosed blocks until the stream is closed and the inner loop exits. 
func (s *tsoStream) WaitForClosed() { s.wg.Wait() From 9076c9dcf78880eab2e9c3351b0f03f5ab2756bc Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 2 Sep 2024 18:10:40 +0800 Subject: [PATCH 08/34] Fix comments Signed-off-by: MyonKeminta --- client/tso_batch_controller.go | 9 --------- client/tso_dispatcher.go | 4 ++++ client/tso_stream.go | 1 + 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index 457ca7c1078..6dc0b57ddb8 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -182,12 +182,3 @@ func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical in // Prevent the finished requests from being processed again. tbc.collectedRequestCount = 0 } - -// -//func (tbc *tsoBatchController) clear() { -// log.Info("[pd] clear the tso batch controller", -// zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize), -// zap.Int("collected-request-count", tbc.collectedRequestCount)) -// tsoErr := errors.WithStack(errClosing) -// tbc.finishCollectedRequests(0, 0, 0, tsoErr) -//} diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 9b8e66b553c..e6a5588aa84 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -454,6 +454,10 @@ func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext return connectionCtx } +// processRequests sends the RPC request for the batch. It's guaranteed that after calling this function, requests +// in the batch must be eventually finished (done or canceled), either synchronously or asynchronously. +// `close(done)` will be called at the same time when finishing the requests. +// If this function returns a non-nil error, the requests will always be canceled synchronously. 
func (td *tsoDispatcher) processRequests( stream *tsoStream, dcLocation string, tbc *tsoBatchController, done chan struct{}, ) error { diff --git a/client/tso_stream.go b/client/tso_stream.go index 7394d8fe4ab..1f9038512ff 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -410,6 +410,7 @@ recvLoop: } } +// GetRecvError returns the error (if any) that has been encountered when receiving response asynchronously. func (s *tsoStream) GetRecvError() error { perr := s.stoppedWithErr.Load() if perr == nil { From 39e4e0b05ee773857e71c40202513c444c4c5d5a Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 4 Sep 2024 17:04:54 +0800 Subject: [PATCH 09/34] try to avoid double invocation to the callback Signed-off-by: MyonKeminta --- client/tso_stream.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index 1f9038512ff..bbd1ea4ec87 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -292,10 +292,14 @@ func (s *tsoStream) processRequests( s.state.Store(prevState) if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { - if err == io.EOF { - return errors.WithStack(errs.ErrClientTSOStreamClosed) - } - return errors.WithStack(err) + // As the request is already put into `pendingRequests`, the request should finally be canceled by the recvLoop. 
+ // So skip returning error here to avoid + //if err == io.EOF { + // return errors.WithStack(errs.ErrClientTSOStreamClosed) + //} + //return errors.WithStack(err) + log.Warn("failed to send RPC request through tsoStream", zap.String("stream", s.streamID), zap.Error(err)) + return nil } tsoBatchSendLatency.Observe(time.Since(batchStartTime).Seconds()) s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(1))) From 632781b2047f73265d26fdec890ce47cee8fac41 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 4 Sep 2024 17:17:21 +0800 Subject: [PATCH 10/34] Fix lint Signed-off-by: MyonKeminta --- client/tso_stream.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index bbd1ea4ec87..279f11db47c 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -294,10 +294,10 @@ func (s *tsoStream) processRequests( if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { // As the request is already put into `pendingRequests`, the request should finally be canceled by the recvLoop. 
// So skip returning error here to avoid - //if err == io.EOF { - // return errors.WithStack(errs.ErrClientTSOStreamClosed) - //} - //return errors.WithStack(err) + // if err == io.EOF { + // return errors.WithStack(errs.ErrClientTSOStreamClosed) + // } + // return errors.WithStack(err) log.Warn("failed to send RPC request through tsoStream", zap.String("stream", s.streamID), zap.Error(err)) return nil } From 8809c45a00a40904de478af5d34260d3576909fd Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Fri, 6 Sep 2024 18:16:44 +0800 Subject: [PATCH 11/34] Address comments Signed-off-by: MyonKeminta --- client/tso_batch_controller.go | 7 +++---- client/tso_dispatcher.go | 26 ++++++++++++++------------ client/tso_request.go | 3 +++ client/tso_stream.go | 2 ++ 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index 6dc0b57ddb8..32191889160 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -56,9 +56,7 @@ func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequ if tokenAcquired { tokenCh <- struct{}{} } - if tbc.collectedRequestCount > 0 { - tbc.finishCollectedRequests(0, 0, 0, errRet) - } + tbc.finishCollectedRequests(0, 0, 0, invalidStreamID, errRet) } }() @@ -170,12 +168,13 @@ func (tbc *tsoBatchController) adjustBestBatchSize() { } } -func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, err error) { +func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, streamID string, err error) { for i := 0; i < tbc.collectedRequestCount; i++ { tsoReq := tbc.collectedRequests[i] // Retrieve the request context before the request is done to trace without race. 
requestCtx := tsoReq.requestCtx tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits) + tsoReq.streamID = streamID tsoReq.tryDone(err) trace.StartRegion(requestCtx, "pdclient.tsoReqDequeue").End() } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index e6a5588aa84..efbe246a82a 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -104,7 +104,9 @@ func newTSODispatcher( tsoRequestCh = make(chan *tsoRequest, 1) }) - tokenCh := make(chan struct{}, 64) + // A large-enough capacity to hold maximum concurrent RPC requests. In our design, the concurrency is at most 16. + const tokenChCapacity = 64 + tokenCh := make(chan struct{}, tokenChCapacity) td := &tsoDispatcher{ ctx: dispatcherCtx, @@ -276,7 +278,7 @@ tsoBatchLoop: select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. - td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) timer.Stop() return case <-streamLoopTimer.C: @@ -284,7 +286,7 @@ tsoBatchLoop: log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err)) svcDiscovery.ScheduleCheckMemberChanged() // Finish the collected requests if the stream is failed to be created. - td.cancelCollectedRequests(batchController, errors.WithStack(err)) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(err)) timer.Stop() continue tsoBatchLoop case <-timer.C: @@ -308,7 +310,7 @@ tsoBatchLoop: exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) stream = nil if exit { - td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) return } continue @@ -321,7 +323,7 @@ tsoBatchLoop: select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. 
- td.cancelCollectedRequests(batchController, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) return case td.tsDeadlineCh <- dl: } @@ -495,7 +497,7 @@ func (td *tsoDispatcher) processRequests( defer td.batchBufferPool.Put(tbc) if err != nil { - td.cancelCollectedRequests(tbc, err) + td.cancelCollectedRequests(tbc, stream.streamID, err) return } @@ -510,7 +512,7 @@ func (td *tsoDispatcher) processRequests( // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits) td.compareAndSwapTS(curTSOInfo, firstLogical) - td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits) + td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits, stream.streamID) } err := stream.processRequests( @@ -519,20 +521,20 @@ func (td *tsoDispatcher) processRequests( if err != nil { close(done) - td.cancelCollectedRequests(tbc, err) + td.cancelCollectedRequests(tbc, stream.streamID, err) return err } return nil } -func (td *tsoDispatcher) cancelCollectedRequests(tbc *tsoBatchController, err error) { +func (td *tsoDispatcher) cancelCollectedRequests(tbc *tsoBatchController, streamID string, err error) { td.tokenCh <- struct{}{} - tbc.finishCollectedRequests(0, 0, 0, err) + tbc.finishCollectedRequests(0, 0, 0, streamID, err) } -func (td *tsoDispatcher) doneCollectedRequests(tbc *tsoBatchController, physical int64, firstLogical int64, suffixBits uint32) { +func (td *tsoDispatcher) doneCollectedRequests(tbc *tsoBatchController, physical int64, firstLogical int64, suffixBits uint32, streamID string) { td.tokenCh <- struct{}{} - tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) + tbc.finishCollectedRequests(physical, firstLogical, suffixBits, streamID, nil) } func (td *tsoDispatcher) compareAndSwapTS( diff --git 
a/client/tso_request.go b/client/tso_request.go index b912fa35497..fb2ae2bb92e 100644 --- a/client/tso_request.go +++ b/client/tso_request.go @@ -42,6 +42,9 @@ type tsoRequest struct { logical int64 dcLocation string + // The identifier of the RPC stream in which the request is processed. + streamID string + // Runtime fields. start time.Time pool *sync.Pool diff --git a/client/tso_stream.go b/client/tso_stream.go index 279f11db47c..bff311eb285 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -222,6 +222,8 @@ const ( var streamIDAlloc atomic.Int32 +const invalidStreamID = "" + // TODO: Pass a context? func newTSOStream(serverURL string, stream grpcTSOStreamAdapter) *tsoStream { streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1)) From ca64ed0d192254c0f2f8e5124a49c9f5669556a1 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 9 Sep 2024 18:22:59 +0800 Subject: [PATCH 12/34] Pass context outside the tsoStream (push for running CI) Signed-off-by: MyonKeminta --- client/tso_stream.go | 11 ++++++----- client/tso_stream_test.go | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index bff311eb285..726bf502485 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -68,7 +68,7 @@ func (b *pdTSOStreamBuilder) build(ctx context.Context, cancel context.CancelFun stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return newTSOStream(b.serverURL, pdTSOStreamAdapter{stream}), nil + return newTSOStream(ctx, b.serverURL, pdTSOStreamAdapter{stream}), nil } return nil, err } @@ -87,7 +87,7 @@ func (b *tsoTSOStreamBuilder) build( stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return newTSOStream(b.serverURL, tsoTSOStreamAdapter{stream}), nil + return newTSOStream(ctx, b.serverURL, tsoTSOStreamAdapter{stream}), nil } return nil, err } @@ -224,10 +224,11 @@ var streamIDAlloc atomic.Int32 const invalidStreamID = "" -// TODO: Pass a context? 
-func newTSOStream(serverURL string, stream grpcTSOStreamAdapter) *tsoStream { +func newTSOStream(ctx context.Context, serverURL string, stream grpcTSOStreamAdapter) *tsoStream { streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1)) - ctx, cancel := context.WithCancel(context.Background()) + // To make error handling in `tsoDispatcher` work, the internal `cancel` and external `cancel` is better to be + // distinguished. + ctx, cancel := context.WithCancel(ctx) s := &tsoStream{ serverURL: serverURL, stream: stream, diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 8bcb0c29a4b..e0b3beaa010 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -112,7 +112,7 @@ type testTSOStreamSuite struct { func (s *testTSOStreamSuite) SetupTest() { s.re = require.New(s.T()) s.inner = newMockTSOStreamImpl() - s.stream = newTSOStream("mock:///", s.inner) + s.stream = newTSOStream(context.Background(), "mock:///", s.inner) } func (s *testTSOStreamSuite) TearDownTest() { From 67967e4f5eee7ae5363c86dd41cf7bbb34d9dfcf Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 11 Sep 2024 19:58:15 +0800 Subject: [PATCH 13/34] client: Add benchmark for tsoStream and dispatcher Signed-off-by: MyonKeminta --- client/tso_dispatcher_test.go | 107 +++++++++++++++++++ client/tso_stream_test.go | 188 ++++++++++++++++++++++++++++++++++ 2 files changed, 295 insertions(+) create mode 100644 client/tso_dispatcher_test.go create mode 100644 client/tso_stream_test.go diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go new file mode 100644 index 00000000000..b1a87d0a2cd --- /dev/null +++ b/client/tso_dispatcher_test.go @@ -0,0 +1,107 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/pingcap/log" + "go.uber.org/zap/zapcore" +) + +type mockTSOServiceProvider struct { + option *option +} + +func newMockTSOServiceProvider(option *option) *mockTSOServiceProvider { + return &mockTSOServiceProvider{ + option: option, + } +} + +func (m *mockTSOServiceProvider) getOption() *option { + return m.option +} + +func (m *mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { + return NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) +} + +func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs *sync.Map) bool { + _, ok := connectionCtxs.Load(mockStreamURL) + if ok { + return true + } + ctx, cancel := context.WithCancel(ctx) + stream := &tsoStream{ + serverURL: mockStreamURL, + stream: newMockTSOStreamImpl(ctx, true), + } + connectionCtxs.LoadOrStore(mockStreamURL, &tsoConnectionContext{ctx, cancel, mockStreamURL, stream}) + return true +} + +func BenchmarkTSODispatcherHandleRequests(b *testing.B) { + log.SetLevel(zapcore.FatalLevel) + + ctx := context.Background() + + reqPool := &sync.Pool{ + New: func() any { + return &tsoRequest{ + done: make(chan error, 1), + physical: 0, + logical: 0, + dcLocation: globalDCLocation, + } + }, + } + getReq := func() *tsoRequest { + req := reqPool.Get().(*tsoRequest) + req.clientCtx = ctx + req.requestCtx = ctx + req.physical = 0 + req.logical = 0 + req.start = time.Now() + req.pool = reqPool + return req + } + + dispatcher := newTSODispatcher(ctx, 
globalDCLocation, defaultMaxTSOBatchSize, newMockTSOServiceProvider(newOption())) + var wg sync.WaitGroup + wg.Add(1) + + go dispatcher.handleDispatcher(&wg) + defer func() { + dispatcher.close() + wg.Wait() + }() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := getReq() + dispatcher.push(req) + _, _, err := req.Wait() + if err != nil { + panic(fmt.Sprintf("unexpected error from tsoReq: %+v", err)) + } + } + // Don't count the time cost in `defer` + b.StopTimer() +} diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go new file mode 100644 index 00000000000..9c68a35d55e --- /dev/null +++ b/client/tso_stream_test.go @@ -0,0 +1,188 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package pd + +import ( + "context" + "io" + "testing" + "time" +) + +const mockStreamURL = "mock:///" + +type requestMsg struct { + clusterID uint64 + keyspaceGroupID uint32 + count int64 +} + +type resultMsg struct { + r tsoRequestResult + err error + breakStream bool +} + +type mockTSOStreamImpl struct { + ctx context.Context + requestCh chan requestMsg + resultCh chan resultMsg + keyspaceID uint32 + errorState error + + autoGenerateResult bool + // Current progress of generating TSO results + resGenPhysical, resGenLogical int64 +} + +func newMockTSOStreamImpl(ctx context.Context, autoGenerateResult bool) *mockTSOStreamImpl { + return &mockTSOStreamImpl{ + ctx: ctx, + requestCh: make(chan requestMsg, 64), + resultCh: make(chan resultMsg, 64), + keyspaceID: 0, + + autoGenerateResult: autoGenerateResult, + resGenPhysical: 10000, + resGenLogical: 0, + } +} + +func (s *mockTSOStreamImpl) Send(clusterID uint64, _keyspaceID, keyspaceGroupID uint32, _dcLocation string, count int64) error { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + default: + } + s.requestCh <- requestMsg{ + clusterID: clusterID, + keyspaceGroupID: keyspaceGroupID, + count: count, + } + return nil +} + +func (s *mockTSOStreamImpl) Recv() (ret tsoRequestResult, retErr error) { + // This stream have ever receive an error, it returns the error forever. 
+ if s.errorState != nil { + return tsoRequestResult{}, s.errorState + } + + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + default: + } + + var res resultMsg + needGenRes := false + if s.autoGenerateResult { + select { + case res = <-s.resultCh: + default: + needGenRes = true + } + } else { + select { + case res = <-s.resultCh: + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + } + } + + if !res.breakStream { + var req requestMsg + select { + case req = <-s.requestCh: + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + } + if needGenRes { + + physical := s.resGenPhysical + logical := s.resGenLogical + req.count + if logical >= (1 << 18) { + physical += logical >> 18 + logical &= (1 << 18) - 1 + } + + s.resGenPhysical = physical + s.resGenLogical = logical + + res = resultMsg{ + r: tsoRequestResult{ + physical: s.resGenPhysical, + logical: s.resGenLogical, + count: uint32(req.count), + suffixBits: 0, + respKeyspaceGroupID: 0, + }, + } + } + } + if res.err != nil { + s.errorState = res.err + } + return res.r, res.err +} + +func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { + s.resultCh <- resultMsg{ + r: tsoRequestResult{ + physical: physical, + logical: logical, + count: count, + suffixBits: 0, + respKeyspaceGroupID: s.keyspaceID, + }, + } +} + +func (s *mockTSOStreamImpl) returnError(err error) { + s.resultCh <- resultMsg{ + err: err, + } +} + +func (s *mockTSOStreamImpl) breakStream(err error) { + s.resultCh <- resultMsg{ + err: err, + breakStream: true, + } +} + +func (s *mockTSOStreamImpl) stop() { + s.breakStream(io.EOF) +} + +func BenchmarkTSOStreamSendRecv(b *testing.B) { + streamInner := newMockTSOStreamImpl(context.Background(), true) + stream := tsoStream{ + serverURL: mockStreamURL, + stream: streamInner, + } + defer streamInner.stop() + + now := time.Now() + + 
b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _, _, _ = stream.processRequests(1, 1, 1, globalDCLocation, 1, now) + } + b.StopTimer() +} From 4699b4112cd73210b24b3f80ce3e2350bbab9767 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 13:34:20 +0800 Subject: [PATCH 14/34] Fix build Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 39a631543bc..19a90fd0674 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -413,17 +413,28 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { func BenchmarkTSOStreamSendRecv(b *testing.B) { streamInner := newMockTSOStreamImpl(context.Background(), true) - stream := tsoStream{ - serverURL: mockStreamURL, - stream: streamInner, - } - defer streamInner.stop() + stream := newTSOStream(context.Background(), mockStreamURL, streamInner) + defer func() { + streamInner.stop() + stream.WaitForClosed() + }() now := time.Now() + resCh := make(chan tsoRequestResult, 1) b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, _, _, _ = stream.processRequests(1, 1, 1, globalDCLocation, 1, now) + err := stream.processRequests(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + select { + case resCh <- result: + default: + panic("channel not cleared in the last iteration") + } + }) + if err != nil { + panic(err) + } + <-resCh } b.StopTimer() } From 7dbcd770af83d272e769e71ae4f81dd0ed85204a Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 15:02:02 +0800 Subject: [PATCH 15/34] Reimplement mock receiving Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 102 ++++++++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 32 deletions(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 9c68a35d55e..c543d74c827 100644 --- 
a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -74,7 +74,7 @@ func (s *mockTSOStreamImpl) Send(clusterID uint64, _keyspaceID, keyspaceGroupID return nil } -func (s *mockTSOStreamImpl) Recv() (ret tsoRequestResult, retErr error) { +func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { // This stream have ever receive an error, it returns the error forever. if s.errorState != nil { return tsoRequestResult{}, s.errorState @@ -87,60 +87,98 @@ func (s *mockTSOStreamImpl) Recv() (ret tsoRequestResult, retErr error) { default: } - var res resultMsg - needGenRes := false - if s.autoGenerateResult { + var ( + res resultMsg + hasRes bool + req requestMsg + hasReq bool + ) + + // Try to match a pair of request and result from each channel and allowing breaking the stream at any time. + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + case req = <-s.requestCh: + hasReq = true select { case res = <-s.resultCh: + hasRes = true default: - needGenRes = true } - } else { + case res = <-s.resultCh: + hasRes = true select { - case res = <-s.resultCh: - case <-s.ctx.Done(): + case req = <-s.requestCh: + hasReq = true + default: + } + } + // Either req or res should be ready at this time. + + if hasRes { + if res.breakStream { s.errorState = s.ctx.Err() return tsoRequestResult{}, s.errorState + } else { + // Do not allow manually assigning result. + if s.autoGenerateResult { + panic("trying manually specifying result for mockTSOStreamImpl when it's auto-generating mode") + } } + } else if s.autoGenerateResult { + res = s.autoGenResult(req.count) + hasRes = true } - if !res.breakStream { - var req requestMsg + if !hasReq { + // If req is not ready, the res must be ready. So it's certain that it don't need to be canceled by breakStream. 
select { - case req = <-s.requestCh: case <-s.ctx.Done(): s.errorState = s.ctx.Err() return tsoRequestResult{}, s.errorState + case req = <-s.requestCh: + hasReq = true } - if needGenRes { - - physical := s.resGenPhysical - logical := s.resGenLogical + req.count - if logical >= (1 << 18) { - physical += logical >> 18 - logical &= (1 << 18) - 1 - } - - s.resGenPhysical = physical - s.resGenLogical = logical - - res = resultMsg{ - r: tsoRequestResult{ - physical: s.resGenPhysical, - logical: s.resGenLogical, - count: uint32(req.count), - suffixBits: 0, - respKeyspaceGroupID: 0, - }, - } + } else if !hasRes { + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + case res = <-s.resultCh: + hasRes = true } } + + // Both res and req should be ready here. if res.err != nil { s.errorState = res.err } return res.r, res.err } +func (s *mockTSOStreamImpl) autoGenResult(count int64) resultMsg { + physical := s.resGenPhysical + logical := s.resGenLogical + count + if logical >= (1 << 18) { + physical += logical >> 18 + logical &= (1 << 18) - 1 + } + + s.resGenPhysical = physical + s.resGenLogical = logical + + return resultMsg{ + r: tsoRequestResult{ + physical: s.resGenPhysical, + logical: s.resGenLogical, + count: uint32(count), + suffixBits: 0, + respKeyspaceGroupID: 0, + }, + } +} + func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { s.resultCh <- resultMsg{ r: tsoRequestResult{ From 1f2dcd862f183daf44bd4be68592d51833ab4160 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 15:25:45 +0800 Subject: [PATCH 16/34] Fix disptacher benchmark Signed-off-by: MyonKeminta --- client/tso_dispatcher_test.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go index b1a87d0a2cd..20904d41f0a 100644 --- a/client/tso_dispatcher_test.go +++ b/client/tso_dispatcher_test.go @@ -49,10 +49,7 @@ func (m 
*mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc s return true } ctx, cancel := context.WithCancel(ctx) - stream := &tsoStream{ - serverURL: mockStreamURL, - stream: newMockTSOStreamImpl(ctx, true), - } + stream := newTSOStream(ctx, mockStreamURL, newMockTSOStreamImpl(ctx, true)) connectionCtxs.LoadOrStore(mockStreamURL, &tsoConnectionContext{ctx, cancel, mockStreamURL, stream}) return true } From 123a869fb5349867b80f458517b4c53c7a6d089f Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 15:27:09 +0800 Subject: [PATCH 17/34] Fix incorrect error generating Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index c543d74c827..e47fe75eca9 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -118,7 +118,10 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { if hasRes { if res.breakStream { - s.errorState = s.ctx.Err() + if res.err == nil { + panic("breaking mockTSOStreamImpl without error") + } + s.errorState = res.err return tsoRequestResult{}, s.errorState } else { // Do not allow manually assigning result. 
From bff9eaedef3c2d0b2629bc6caeb734c18d7b49fb Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 15:31:09 +0800 Subject: [PATCH 18/34] Disable logs in BenchmarkTSOStreamSendRecv Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 25277dcc2aa..90039bd4b55 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -21,9 +21,11 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/log" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/client/errs" + "go.uber.org/zap/zapcore" ) const mockStreamURL = "mock:///" @@ -453,6 +455,8 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { } func BenchmarkTSOStreamSendRecv(b *testing.B) { + log.SetLevel(zapcore.FatalLevel) + streamInner := newMockTSOStreamImpl(context.Background(), true) stream := newTSOStream(context.Background(), mockStreamURL, streamInner) defer func() { From ef6e365c7a9a3061db8012b79bfdab9d1ce806a9 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 16:14:33 +0800 Subject: [PATCH 19/34] Fix lint Signed-off-by: MyonKeminta --- client/tso_dispatcher_test.go | 4 ++-- client/tso_stream_test.go | 15 +++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go index b1a87d0a2cd..7b7ac8f736b 100644 --- a/client/tso_dispatcher_test.go +++ b/client/tso_dispatcher_test.go @@ -39,11 +39,11 @@ func (m *mockTSOServiceProvider) getOption() *option { return m.option } -func (m *mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { +func (*mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { return NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) } -func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs *sync.Map) bool { +func 
(*mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs *sync.Map) bool { _, ok := connectionCtxs.Load(mockStreamURL) if ok { return true diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index e47fe75eca9..4709fe04f4c 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -123,11 +123,9 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { } s.errorState = res.err return tsoRequestResult{}, s.errorState - } else { + } else if s.autoGenerateResult { // Do not allow manually assigning result. - if s.autoGenerateResult { - panic("trying manually specifying result for mockTSOStreamImpl when it's auto-generating mode") - } + panic("trying manually specifying result for mockTSOStreamImpl when it's auto-generating mode") } } else if s.autoGenerateResult { res = s.autoGenResult(req.count) @@ -141,7 +139,8 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { s.errorState = s.ctx.Err() return tsoRequestResult{}, s.errorState case req = <-s.requestCh: - hasReq = true + // Skip the assignment to make linter happy. + // hasReq = true } } else if !hasRes { select { @@ -149,7 +148,8 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { s.errorState = s.ctx.Err() return tsoRequestResult{}, s.errorState case res = <-s.resultCh: - hasRes = true + // Skip the assignment to make linter happy. 
+ // hasRes = true } } @@ -182,6 +182,7 @@ func (s *mockTSOStreamImpl) autoGenResult(count int64) resultMsg { } } +// nolint:unused func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { s.resultCh <- resultMsg{ r: tsoRequestResult{ @@ -194,12 +195,14 @@ func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count ui } } +// nolint:unused func (s *mockTSOStreamImpl) returnError(err error) { s.resultCh <- resultMsg{ err: err, } } +// nolint:unused func (s *mockTSOStreamImpl) breakStream(err error) { s.resultCh <- resultMsg{ err: err, From 388d705601b33493a26bd3a7191823abf08d84a9 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 17:03:24 +0800 Subject: [PATCH 20/34] Make the comments more clear; unset req.streamID when getting new tsoRequest from the pool Signed-off-by: MyonKeminta --- client/tso_client.go | 1 + client/tso_dispatcher.go | 11 ++++++++--- client/tso_stream.go | 4 ++-- client/tso_stream_test.go | 3 --- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/client/tso_client.go b/client/tso_client.go index 2f3b949f017..f1538a7f164 100644 --- a/client/tso_client.go +++ b/client/tso_client.go @@ -203,6 +203,7 @@ func (c *tsoClient) getTSORequest(ctx context.Context, dcLocation string) *tsoRe req.physical = 0 req.logical = 0 req.dcLocation = dcLocation + req.streamID = "" return req } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index efbe246a82a..a1e0b03a1fa 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -331,9 +331,12 @@ tsoBatchLoop: err = td.processRequests(stream, dc, batchController, done) // If error happens during tso stream handling, reset stream and run the next trial. if err == nil { - // If the request is started successfully, the `batchController` will be put back to the pool when the - // request is finished (either successful or not). 
In this case, set the `batchController` to nil so that - // another one will be fetched from the pool. + // A nil error returned by `processRequests` indicates that the request batch is started successfully. + // In this case, the `batchController` will be put back to the pool when the request is finished + // asynchronously (either successful or not). This infers that the current `batchController` object will + // be asynchronously accessed after the `processRequests` call. As a result, we need to use another + // `batchController` for collecting the next batch. Do to this, we set the `batchController` to nil so that + // another one will be fetched from the pool at the beginning of the batching loop. // Otherwise, the `batchController` won't be processed in other goroutines concurrently, and it can be // reused in the next loop safely. batchController = nil @@ -493,6 +496,8 @@ func (td *tsoDispatcher) processRequests( ) cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + // As golang doesn't allow double-closing a channel, here is implicitly a check that the callback + // is never called twice or called while it's also being cancelled elsewhere. close(done) defer td.batchBufferPool.Put(tbc) diff --git a/client/tso_stream.go b/client/tso_stream.go index 726bf502485..a6e23bbf1d3 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -270,7 +270,7 @@ func (s *tsoStream) processRequests( case streamStateClosing: s.state.Store(prevState) err := s.GetRecvError() - log.Info("sending to closed tsoStream", zap.Error(err)) + log.Info("sending to closed tsoStream", zap.String("stream", s.streamID), zap.Error(err)) if err == nil { err = errors.WithStack(errs.ErrClientTSOStreamClosed) } @@ -322,7 +322,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { if finishWithErr == nil { // The loop must exit with a non-nil error (including io.EOF and context.Canceled). This should be // unreachable code. 
- log.Fatal("tsoStream.recvLoop exited without error info") + log.Fatal("tsoStream.recvLoop exited without error info", zap.String("stream", s.streamID)) } if hasReq { diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 13583ee3839..564de37d6db 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -189,7 +189,6 @@ func (s *mockTSOStreamImpl) autoGenResult(count int64) resultMsg { } } -// nolint:unused func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { s.resultCh <- resultMsg{ r: tsoRequestResult{ @@ -202,14 +201,12 @@ func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count ui } } -// nolint:unused func (s *mockTSOStreamImpl) returnError(err error) { s.resultCh <- resultMsg{ err: err, } } -// nolint:unused func (s *mockTSOStreamImpl) breakStream(err error) { s.resultCh <- resultMsg{ err: err, From a183aca0a2ecb886c885f1a3d584266ee80f96fb Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 17:42:54 +0800 Subject: [PATCH 21/34] Fix lint Signed-off-by: MyonKeminta --- client/tso_stream_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 564de37d6db..b09c54baf3a 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -469,7 +469,10 @@ func BenchmarkTSOStreamSendRecv(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - err := stream.processRequests(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + err := stream.processRequests(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, _ uint32, err error) { + if err != nil { + panic(err) + } select { case resCh <- result: default: From 4296c3f8a035026d1ca1ec2af5fac69f41afc599 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Thu, 12 Sep 2024 19:49:34 +0800 Subject: [PATCH 22/34] Add latency estimation Signed-off-by: MyonKeminta --- 
client/tso_stream.go | 74 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 71 insertions(+), 3 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index a6e23bbf1d3..0cddd2c9747 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "io" + "math" "sync" "sync/atomic" "time" @@ -210,6 +211,8 @@ type tsoStream struct { state atomic.Int32 stoppedWithErr atomic.Pointer[error] + estimatedLatencyMicros atomic.Uint64 + ongoingRequestCountGauge prometheus.Gauge ongoingRequests atomic.Int32 } @@ -359,6 +362,24 @@ func (s *tsoStream) recvLoop(ctx context.Context) { s.ongoingRequestCountGauge.Set(0) }() + // For calculating the estimated RPC latency. + const ( + filterCutoffFreq float64 = 1.0 + filterNewSampleWeightUpperbound = 0.2 + ) + // The filter applies on logarithm of the latency of each TSO RPC in microseconds. + filter := newRCFilter(filterCutoffFreq, filterNewSampleWeightUpperbound) + + updateEstimatedLatency := func(sampleTime time.Time, latency time.Duration) { + if latency < 0 { + // Unreachable + return + } + currentSample := math.Log(float64(latency.Microseconds())) + filteredValue := filter.update(sampleTime, currentSample) + s.estimatedLatencyMicros.Store(uint64(math.Exp(filteredValue))) + } + recvLoop: for { select { @@ -379,14 +400,15 @@ recvLoop: hasReq = false } - durationSeconds := time.Since(currentReq.startTime).Seconds() + latency := time.Since(currentReq.startTime) + latencySeconds := latency.Seconds() if err != nil { // If a request is pending and error occurs, observe the duration it has cost. // Note that it's also possible that the stream is broken due to network without being requested. In this // case, `Recv` may return an error while no request is pending. 
if hasReq { - requestFailedDurationTSO.Observe(durationSeconds) + requestFailedDurationTSO.Observe(latencySeconds) } if err == io.EOF { finishWithErr = errors.WithStack(errs.ErrClientTSOStreamClosed) @@ -399,9 +421,9 @@ recvLoop: break recvLoop } - latencySeconds := durationSeconds requestDurationTSO.Observe(latencySeconds) tsoBatchSize.Observe(float64(res.count)) + updateEstimatedLatency(currentReq.startTime, latency) if res.count != uint32(currentReq.count) { finishWithErr = errors.WithStack(errTSOLength) @@ -417,6 +439,17 @@ recvLoop: } } +// EstimatedRPCLatency returns an estimation of the duration of each TSO RPC. The result is never less than 100us, +// which is also the value returned if the stream has never handled any RPC. +func (s *tsoStream) EstimatedRPCLatency() time.Duration { + latencyUs := s.estimatedLatencyMicros.Load() + // Clamp the estimation to be at least 100us. + if latencyUs < 100 { + latencyUs = 100 + } + return time.Microsecond * time.Duration(latencyUs) +} + // GetRecvError returns the error (if any) that has been encountered when receiving response asynchronously.
func (s *tsoStream) GetRecvError() error { perr := s.stoppedWithErr.Load() @@ -430,3 +463,38 @@ func (s *tsoStream) GetRecvError() error { func (s *tsoStream) WaitForClosed() { s.wg.Wait() } + +type rcFilter struct { + rc float64 + newSampleWeightUpperBound float64 + value float64 + lastSampleTime time.Time + firstSampleArrived bool +} + +func newRCFilter(cutoff float64, newSampleWeightUpperBound float64) rcFilter { + rc := 1.0 / (2.0 * math.Pi * cutoff) + return rcFilter{ + rc: rc, + newSampleWeightUpperBound: newSampleWeightUpperBound, + } +} + +func (f *rcFilter) update(sampleTime time.Time, newSample float64) float64 { + // Handle the first sample + if !f.firstSampleArrived { + f.firstSampleArrived = true + f.lastSampleTime = sampleTime + f.value = newSample + return newSample + } + + // Delta time + dt := sampleTime.Sub(f.lastSampleTime).Seconds() + // Current sample represented and calculated in log(microseconds) + alpha := math.Min(dt/(f.rc+dt), f.newSampleWeightUpperBound) + f.value = (1-alpha)*f.value + alpha*newSample + + f.lastSampleTime = sampleTime + return f.value +} From 9b5b054d84de5ff8897c33ff0dce90e539335e3b Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Sat, 14 Sep 2024 18:53:24 +0800 Subject: [PATCH 23/34] Support concurrent RPC Signed-off-by: MyonKeminta --- client/client.go | 6 ++ client/metrics.go | 10 +++ client/option.go | 15 ++++ client/tso_batch_controller.go | 34 ++++++++- client/tso_dispatcher.go | 129 ++++++++++++++++++++++++++++++++- 5 files changed, 191 insertions(+), 3 deletions(-) diff --git a/client/client.go b/client/client.go index aafe4aba77f..bc7cdf9d576 100644 --- a/client/client.go +++ b/client/client.go @@ -796,6 +796,12 @@ func (c *client) UpdateOption(option DynamicOption, value any) error { return errors.New("[pd] invalid value type for EnableFollowerHandle option, it should be bool") } c.option.setEnableFollowerHandle(enable) + case TSOClientRPCConcurrency: + value, ok := value.(int) + if !ok { + return 
errors.New("[pd] invalid value type for TSOClientRPCConcurrency option, it should be int") + } + c.option.setTSOClientRPCConcurrency(value) default: return errors.New("[pd] unsupported client option") } diff --git a/client/metrics.go b/client/metrics.go index a83b4a36407..d1b375aea8a 100644 --- a/client/metrics.go +++ b/client/metrics.go @@ -47,6 +47,7 @@ var ( tsoBatchSendLatency prometheus.Histogram requestForwarded *prometheus.GaugeVec ongoingRequestCountGauge *prometheus.GaugeVec + estimateTSOLatencyGauge *prometheus.GaugeVec ) func initMetrics(constLabels prometheus.Labels) { @@ -127,6 +128,14 @@ func initMetrics(constLabels prometheus.Labels) { Help: "Current count of ongoing batch tso requests", ConstLabels: constLabels, }, []string{"stream"}) + estimateTSOLatencyGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "pd_client", + Subsystem: "request", + Name: "estimate_tso_latency", + Help: "Estimated latency of an RTT of getting TSO", + ConstLabels: constLabels, + }, []string{"stream"}) } var ( @@ -236,4 +245,5 @@ func registerMetrics() { prometheus.MustRegister(tsoBatchSize) prometheus.MustRegister(tsoBatchSendLatency) prometheus.MustRegister(requestForwarded) + prometheus.MustRegister(estimateTSOLatencyGauge) } diff --git a/client/option.go b/client/option.go index 0109bfc4ed0..3f2b7119b52 100644 --- a/client/option.go +++ b/client/option.go @@ -29,6 +29,7 @@ const ( defaultMaxTSOBatchWaitInterval time.Duration = 0 defaultEnableTSOFollowerProxy = false defaultEnableFollowerHandle = false + defaultTSOClientRPCConcurrency = 1 ) // DynamicOption is used to distinguish the dynamic option type. @@ -43,6 +44,8 @@ const ( EnableTSOFollowerProxy // EnableFollowerHandle is the follower handle option. EnableFollowerHandle + // TSOClientRPCConcurrency controls the amount of ongoing TSO RPC requests at the same time in a single TSO client. 
+ TSOClientRPCConcurrency dynamicOptionCount ) @@ -77,6 +80,7 @@ func newOption() *option { co.dynamicOptions[MaxTSOBatchWaitInterval].Store(defaultMaxTSOBatchWaitInterval) co.dynamicOptions[EnableTSOFollowerProxy].Store(defaultEnableTSOFollowerProxy) co.dynamicOptions[EnableFollowerHandle].Store(defaultEnableFollowerHandle) + co.dynamicOptions[TSOClientRPCConcurrency].Store(defaultTSOClientRPCConcurrency) return co } @@ -127,3 +131,14 @@ func (o *option) setEnableTSOFollowerProxy(enable bool) { func (o *option) getEnableTSOFollowerProxy() bool { return o.dynamicOptions[EnableTSOFollowerProxy].Load().(bool) } + +func (o *option) setTSOClientRPCConcurrency(value int) { + old := o.getTSOClientRPCConcurrency() + if value != old { + o.dynamicOptions[TSOClientRPCConcurrency].Store(value) + } +} + +func (o *option) getTSOClientRPCConcurrency() int { + return o.dynamicOptions[TSOClientRPCConcurrency].Load().(int) +} diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index 32191889160..51bba01a2cd 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -47,6 +47,7 @@ func newTSOBatchController(maxBatchSize int) *tsoBatchController { // It returns nil error if everything goes well, otherwise a non-nil error which means we should stop the service. // It's guaranteed that if this function failed after collecting some requests, then these requests will be cancelled // when the function returns, so the caller don't need to clear them manually. +// If maxBatchWaitInterval is not enabled, func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequestCh <-chan *tsoRequest, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) { var tokenAcquired bool defer func() { @@ -63,7 +64,7 @@ func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequ // Wait until BOTH the first request and the token have arrived. 
// TODO: `tbc.collectedRequestCount` should never be non-empty here. Consider do assertion here. tbc.collectedRequestCount = 0 - for { + for tbc.collectedRequestCount < tbc.maxBatchSize { select { case <-ctx.Done(): return ctx.Err() @@ -146,6 +147,37 @@ fetchPendingRequestsLoop: return nil } +// fetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly +// before calling this function. +func (tbc *tsoBatchController) fetchRequestsWithTimer(ctx context.Context, tsoRequestCh <-chan *tsoRequest, timer *time.Timer) error { +batchingLoop: + for tbc.collectedRequestCount < tbc.maxBatchSize { + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-tsoRequestCh: + tbc.pushRequest(req) + case <-timer.C: + break batchingLoop + } + } + + // Try to collect more requests in non-blocking way. +nonWaitingBatchLoop: + for tbc.collectedRequestCount < tbc.maxBatchSize { + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-tsoRequestCh: + tbc.pushRequest(req) + default: + break nonWaitingBatchLoop + } + } + + return nil +} + func (tbc *tsoBatchController) pushRequest(tsoReq *tsoRequest) { tbc.collectedRequests[tbc.collectedRequestCount] = tsoReq tbc.collectedRequestCount++ diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index a1e0b03a1fa..a8402608bee 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -69,6 +69,8 @@ type tsoServiceProvider interface { updateConnectionCtxs(ctx context.Context, dc string, connectionCtxs *sync.Map) bool } +const dispatcherCheckRPCConcurrencyInterval = time.Second * 5 + type tsoDispatcher struct { ctx context.Context cancel context.CancelFunc @@ -87,7 +89,10 @@ type tsoDispatcher struct { // A token must be acquired here before sending an RPC request, and the token must be put back after finishing the // RPC. This is used like a semaphore, but we don't use semaphore directly here as it cannot be selected with // other channels. 
- tokenCh chan struct{} + tokenCh chan struct{} + lastCheckConcurrencyTime time.Time + tokenCount int + rpcConcurrency int updateConnectionCtxsCh chan struct{} } @@ -219,6 +224,12 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { // Loop through each batch of TSO requests and send them for processing. streamLoopTimer := time.NewTimer(option.timeout) defer streamLoopTimer.Stop() + + // Create a not-started-timer to be used for collecting batches for concurrent RPC. + batchingTimer := time.NewTimer(0) + <-batchingTimer.C + defer batchingTimer.Stop() + bo := retry.InitialBackoffer(updateMemberBackOffBaseTime, updateMemberTimeout, updateMemberBackOffBaseTime) tsoBatchLoop: for { @@ -233,8 +244,18 @@ tsoBatchLoop: batchController = td.batchBufferPool.Get().(*tsoBatchController) } - // Start to collect the TSO requests. maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval() + + currentBatchStartTime := time.Now() + // Update concurrency settings if needed. + if err = td.checkTSORPCConcurrency(ctx, maxBatchWaitInterval, currentBatchStartTime); err != nil { + // checkTSORPCConcurrency can only fail due to `ctx` being invalidated. + log.Info("[tso] stop checking tso rpc concurrency configurations due to context canceled", + zap.String("dc-location", dc), zap.Error(err)) + return + } + + // Start to collect the TSO requests. // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, // otherwise the upper caller may get blocked on waiting for the results. if err = batchController.fetchPendingRequests(ctx, td.tsoRequestCh, td.tokenCh, maxBatchWaitInterval); err != nil { @@ -318,6 +339,39 @@ tsoBatchLoop: break streamChoosingLoop } + + // If concurrent RPC is enabled, the time for collecting each request batch is expected to be + // estimatedRPCDuration / concurrency. Note the time mentioned here is counted from starting trying to collect + // the batch, instead of the time when the first request arrives. 
+ // Here, if the elapsed time since starting collecting this batch didn't reach the expected batch time, then + // continue collecting. + if td.isConcurrentRPCEnabled() { + estimatedLatency := stream.EstimatedRPCLatency() + estimateTSOLatencyGauge.WithLabelValues(streamURL).Set(estimatedLatency.Seconds()) + + totalBatchTime := estimatedLatency / time.Duration(td.rpcConcurrency) + waitTimerStart := time.Now() + remainingBatchTime := totalBatchTime - waitTimerStart.Sub(currentBatchStartTime) + if remainingBatchTime > 0 { + if !batchingTimer.Stop() { + select { + case <-batchingTimer.C: + default: + } + } + batchingTimer.Reset(remainingBatchTime) + + err = batchController.fetchRequestsWithTimer(ctx, td.tsoRequestCh, batchingTimer) + if err != nil { + // fetchRequestsWithTimer can only fail due to `ctx` being invalidated. + log.Info("[tso] stop fetching the pending tso requests due to context canceled", + zap.String("dc-location", dc), zap.Error(err)) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) + return + } + } + } + done := make(chan struct{}) dl := newTSDeadline(option.timeout, done, cancel) select { @@ -580,3 +634,74 @@ func (td *tsoDispatcher) compareAndSwapTS( } td.lastTSOInfo = curTSOInfo } + +// checkTSORPCConcurrency checks configurations about TSO RPC concurrency, and adjusts the token count if needed. +// Some other options (EnableTSOFollowerProxy and MaxTSOBatchWaitInterval) may affect the availability of concurrent +// RPC requests. As the dispatcher loop loads MaxTSOBatchWaitInterval in each iteration, pass it directly to this +// function. Other configurations will be loaded within this function when needed. +// +// Behavior of the function: +// - As concurrent TSO RPC requests are an optimization with the opposite purpose to that of EnableTSOFollowerProxy +// and MaxTSOBatchWaitInterval, once either EnableTSOFollowerProxy or MaxTSOBatchWaitInterval is enabled, the +// concurrency will always be set to 1 no matter how the user configured it.
+// - Normally, this function takes effect in a limited frequency controlled by dispatcherCheckRPCConcurrencyInterval. +// However, if the RPC concurrency is set to more than 1, and MaxTSOBatchWaitInterval is changed from disabled into +// enabled (0 -> positive), this function takes effect immediately to disable concurrent RPC requests. +// - After this function takes effect, the final decision of concurrency and token count will be set to +// td.rpcConcurrency and td.tokenCount; and tokens available in td.tokenCh will also be adjusted. +func (td *tsoDispatcher) checkTSORPCConcurrency(ctx context.Context, maxBatchWaitInterval time.Duration, now time.Time) error { + // If we currently enabled concurrent TSO RPC requests, but `maxBatchWaitInterval` is a positive value, it must + // because that MaxTSOBatchWaitInterval is just enabled. In this case, disable concurrent TSO RPC requests + // immediately, because MaxTSOBatchWaitInterval and concurrent RPC requests has opposite purpose. + immediatelyUpdate := td.rpcConcurrency > 1 && maxBatchWaitInterval > 0 + + if !immediatelyUpdate && now.Sub(td.lastCheckConcurrencyTime) < dispatcherCheckRPCConcurrencyInterval { + return nil + } + td.lastCheckConcurrencyTime = now + + newConcurrency := td.provider.getOption().getTSOClientRPCConcurrency() + if maxBatchWaitInterval > 0 || td.provider.getOption().getEnableTSOFollowerProxy() { + newConcurrency = 1 + } + + if newConcurrency == td.rpcConcurrency { + return nil + } + + log.Info("[tso] switching tso rpc concurrency", zap.Int("old", td.rpcConcurrency), zap.Int("new", newConcurrency)) + td.rpcConcurrency = newConcurrency + + // Find a proper token count. + // When the concurrency is set to 1, there's only 1 token, which means only 1 RPC request can run at the same + // time. + // When the concurrency is set to more than 1, the time interval between sending two batches of requests is + // controlled by an estimation of an average RPC duration. 
But as the duration of an RPC may jitter in the network, + // and an RPC request may finish earlier or later. So we allow there to be the actual number of concurrent ongoing + // request to be fluctuating. So in this case, the token count will be set to 2 times the expected concurrency. + newTokenCount := newConcurrency + if newConcurrency > 1 { + newTokenCount = newConcurrency * 2 + } + + if newTokenCount > td.tokenCount { + for td.tokenCount < newTokenCount { + td.tokenCh <- struct{}{} + td.tokenCount++ + } + } else if newTokenCount < td.tokenCount { + for td.tokenCount > newTokenCount { + select { + case <-ctx.Done(): + return ctx.Err() + case <-td.tokenCh: + } + td.tokenCount-- + } + } + return nil +} + +func (td *tsoDispatcher) isConcurrentRPCEnabled() bool { + return td.rpcConcurrency > 1 +} From 4bb33a263725cca7e6c12d16f94d53b6fec6e527 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Sat, 14 Sep 2024 20:26:15 +0800 Subject: [PATCH 24/34] Fix token not acquired when maxBatchSize is reached before token is ready Signed-off-by: MyonKeminta --- client/tso_batch_controller.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index 51bba01a2cd..08f0bccc403 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -64,7 +64,18 @@ func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequ // Wait until BOTH the first request and the token have arrived. // TODO: `tbc.collectedRequestCount` should never be non-empty here. Consider do assertion here. tbc.collectedRequestCount = 0 - for tbc.collectedRequestCount < tbc.maxBatchSize { + for { + // If the batch size reaches the maxBatchSize limit but the token haven't arrived yet, don't receive more + // requests, and return when token is ready. 
+ if tbc.collectedRequestCount >= tbc.maxBatchSize && !tokenAcquired { + select { + case <-ctx.Done(): + return ctx.Err() + case <-tokenCh: + return nil + } + } + select { case <-ctx.Done(): return ctx.Err() From e7590ec5200e48b336972734a33cb44128847823 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 23 Sep 2024 13:12:08 +0800 Subject: [PATCH 25/34] fix lint Signed-off-by: MyonKeminta --- client/tso_stream.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index cde961e891e..94d27662cf3 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -369,7 +369,7 @@ func (s *tsoStream) recvLoop(ctx context.Context) { // For calculating the estimated RPC latency. const ( filterCutoffFreq float64 = 1.0 - filterNewSampleWeightUpperbound = 0.2 + filterNewSampleWeightUpperbound float64 = 0.2 ) // The filter applies on logarithm of the latency of each TSO RPC in microseconds. filter := newRCFilter(filterCutoffFreq, filterNewSampleWeightUpperbound) From e5f9355878d04db2587bbb1d712a4a9af95d4bdc Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 23 Sep 2024 16:10:10 +0800 Subject: [PATCH 26/34] Add test to the rcFilter Signed-off-by: MyonKeminta --- client/tso_dispatcher.go | 1 + client/tso_stream.go | 14 +++++++-- client/tso_stream_test.go | 62 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index a8402608bee..2fada15809e 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -570,6 +570,7 @@ func (td *tsoDispatcher) processRequests( } // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits) + // Do the check before releasing the token. 
td.compareAndSwapTS(curTSOInfo, firstLogical) td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits, stream.streamID) } diff --git a/client/tso_stream.go b/client/tso_stream.go index 94d27662cf3..0779f7fb1b9 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -468,6 +468,14 @@ func (s *tsoStream) WaitForClosed() { s.wg.Wait() } +// rcFilter is a simple implementation of a discrete-time low-pass filter. +// Ref: https://en.wikipedia.org/wiki/Low-pass_filter#Simple_infinite_impulse_response_filter +// There are some differences between this implementation and the wikipedia one: +// - Time-interval between each two samples is not necessarily a constant. We allow non-even sample interval by simply +// calculating the alpha (which is calculated by `dt / (rc + dt)`) dynamically for each sample, at the expense of +// losing some mathematical strictness. +// - Support specifying the upper bound of the weight of the new sample when updating. This prevents the output +// from jumping drastically when the samples come in at a low frequency. type rcFilter struct { rc float64 newSampleWeightUpperBound float64 @@ -476,6 +484,8 @@ type rcFilter struct { firstSampleArrived bool } +// newRCFilter initializes an rcFilter. `cutoff` is the cutoff frequency in Hertz. `newSampleWeightUpperBound` controls +// the upper limit of the weight of each incoming sample (pass 1 for unlimited). func newRCFilter(cutoff float64, newSampleWeightUpperBound float64) rcFilter { rc := 1.0 / (2.0 * math.Pi * cutoff) return rcFilter{ @@ -493,9 +503,9 @@ func (f *rcFilter) update(sampleTime time.Time, newSample float64) float64 { return newSample } - // Delta time + // Delta time. dt := sampleTime.Sub(f.lastSampleTime).Seconds() - // Current sample represented and calculated in log(microseconds) + // `alpha` is the weight of the new sample, limited with `newSampleWeightUpperBound`.
alpha := math.Min(dt/(f.rc+dt), f.newSampleWeightUpperBound) f.value = (1-alpha)*f.value + alpha*newSample diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index b09c54baf3a..5016c3cef1f 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -17,6 +17,7 @@ package pd import ( "context" "io" + "math" "testing" "time" @@ -454,6 +455,67 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { } } +//func (s *testTSOStreamSuite) TestEstimatedLatency() { +// s.inner.returnResult(100, 0, 1) +// res := s.getResult(s.mustProcessRequestWithResultCh(1)) +// s.NoError(res.err) +// s.Equal(int64(100), res.result.physical) +// s.Equal(int64(0), res.result.logical) +// s.InDelta() +//} + +func TestRCFilter(t *testing.T) { + re := require.New(t) + // Test basic calculation with frequency 1 + f := newRCFilter(1, 1) + now := time.Now() + // The first sample initializes the value. + re.Equal(10.0, f.update(now, 10)) + now = now.Add(time.Second) + expectedValue := 10 / (2*math.Pi + 1) + re.InEpsilon(expectedValue, f.update(now, 0), 1e-8) + expectedValue = expectedValue*(1/(2*math.Pi))/(1/(2*math.Pi)+2) + 100*2/(1/(2*math.Pi)+2) + now = now.Add(time.Second * 2) + re.InEpsilon(expectedValue, f.update(now, 100), 1e-8) + + // Test newSampleWeightUpperBound + f = newRCFilter(10, 0.5) + now = time.Now() + re.Equal(0.0, f.update(now, 0)) + now = now.Add(time.Second) + re.InEpsilon(1.0, f.update(now, 2), 1e-8) + now = now.Add(time.Second * 2) + re.InEpsilon(3.0, f.update(now, 5), 1e-8) + + // Test another cutoff frequency and weight upperbound. 
+ f = newRCFilter(1/(2*math.Pi), 0.9) + now = time.Now() + re.Equal(1.0, f.update(now, 1)) + now = now.Add(time.Second) + re.InEpsilon(2.0, f.update(now, 3), 1e-8) + now = now.Add(time.Second * 2) + re.InEpsilon(6.0, f.update(now, 8), 1e-8) + now = now.Add(time.Minute) + re.InEpsilon(15.0, f.update(now, 16), 1e-8) + + // Test with dense samples + f = newRCFilter(1/(2*math.Pi), 0.9) + now = time.Now() + re.Equal(0.0, f.update(now, 0)) + lastOutput := 0.0 + // 10000 even samples in 1 second. + for i := 0; i < 10000; i++ { + now = now.Add(time.Microsecond * 100) + output := f.update(now, 1.0) + re.Greater(output, lastOutput) + re.Less(output, 1.0) + lastOutput = output + } + // Regarding the above samples as being close enough to a continuous function, the output after 1 second + // should be 1 - exp(-t/RC) = 1 - exp(-t). Here RC = 1/(2*pi*cutoff) = 1. + re.InDelta(0.63, lastOutput, 0.02) +} + func BenchmarkTSOStreamSendRecv(b *testing.B) { log.SetLevel(zapcore.FatalLevel) From 1b022ced471ecef96d1306baa801cb4b46185c24 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 23 Sep 2024 18:20:45 +0800 Subject: [PATCH 27/34] Adapt the monotonicity check for parallel RPC Signed-off-by: MyonKeminta --- client/tso_dispatcher.go | 52 ++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 2fada15809e..43049476d29 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -20,6 +20,7 @@ import ( "math/rand" "runtime/trace" "sync" + "sync/atomic" "time" "github.com/opentracing/opentracing-go" @@ -81,7 +82,7 @@ type tsoDispatcher struct { connectionCtxs *sync.Map tsoRequestCh chan *tsoRequest tsDeadlineCh chan *deadline - lastTSOInfo *tsoInfo + latestTSOInfo atomic.Pointer[tsoInfo] // For reusing tsoBatchController objects batchBufferPool *sync.Pool @@ -549,6 +550,9 @@ func (td *tsoDispatcher) processRequests( reqKeyspaceGroupID =
svcDiscovery.GetKeyspaceGroupID() ) + // Load latest allocated ts for monotonicity assertion. + tsoInfoBeforeReq := td.latestTSOInfo.Load() + cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { // As golang doesn't allow double-closing a channel, here is implicitly a check that the callback // is never called twice or called while it's also being cancelled elsewhere. @@ -571,7 +575,7 @@ func (td *tsoDispatcher) processRequests( // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits) // Do the check before releasing the token. - td.compareAndSwapTS(curTSOInfo, firstLogical) + td.checkMonotonicity(tsoInfoBeforeReq, curTSOInfo, firstLogical) td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits, stream.streamID) } @@ -597,32 +601,27 @@ func (td *tsoDispatcher) doneCollectedRequests(tbc *tsoBatchController, physical tbc.finishCollectedRequests(physical, firstLogical, suffixBits, streamID, nil) } -func (td *tsoDispatcher) compareAndSwapTS( - curTSOInfo *tsoInfo, firstLogical int64, +func (td *tsoDispatcher) checkMonotonicity( + lastTSOInfo *tsoInfo, curTSOInfo *tsoInfo, firstLogical int64, ) { - if td.lastTSOInfo != nil { - var ( - lastTSOInfo = td.lastTSOInfo - dc = td.dc - physical = curTSOInfo.physical - keyspaceID = td.provider.getServiceDiscovery().GetKeyspaceID() - ) - if td.lastTSOInfo.respKeyspaceGroupID != curTSOInfo.respKeyspaceGroupID { + keyspaceID := td.provider.getServiceDiscovery().GetKeyspaceID() + if lastTSOInfo != nil { + if lastTSOInfo.respKeyspaceGroupID != curTSOInfo.respKeyspaceGroupID { log.Info("[tso] keyspace group changed", - zap.String("dc-location", dc), + zap.String("dc-location", td.dc), zap.Uint32("old-group-id", lastTSOInfo.respKeyspaceGroupID), zap.Uint32("new-group-id", curTSOInfo.respKeyspaceGroupID)) } // The TSO we get is a range 
like [largestLogical-count+1, largestLogical], so we save the last TSO's largest logical // to compare with the new TSO's first logical. For example, if we have a TSO resp with logical 10, count 5, then - // all TSOs we get will be [6, 7, 8, 9, 10]. lastTSOInfo.logical stores the logical part of the largest ts returned + // all TSOs we get will be [6, 7, 8, 9, 10]. latestTSOInfo.logical stores the logical part of the largest ts returned // last time. - if tsoutil.TSLessEqual(physical, firstLogical, lastTSOInfo.physical, lastTSOInfo.logical) { + if tsoutil.TSLessEqual(curTSOInfo.physical, firstLogical, lastTSOInfo.physical, lastTSOInfo.logical) { log.Panic("[tso] timestamp fallback", - zap.String("dc-location", dc), + zap.String("dc-location", td.dc), zap.Uint32("keyspace", keyspaceID), zap.String("last-ts", fmt.Sprintf("(%d, %d)", lastTSOInfo.physical, lastTSOInfo.logical)), - zap.String("cur-ts", fmt.Sprintf("(%d, %d)", physical, firstLogical)), + zap.String("cur-ts", fmt.Sprintf("(%d, %d)", curTSOInfo.physical, firstLogical)), zap.String("last-tso-server", lastTSOInfo.tsoServer), zap.String("cur-tso-server", curTSOInfo.tsoServer), zap.Uint32("last-keyspace-group-in-request", lastTSOInfo.reqKeyspaceGroupID), @@ -633,7 +632,24 @@ func (td *tsoDispatcher) compareAndSwapTS( zap.Time("cur-response-received-at", curTSOInfo.respReceivedAt)) } } - td.lastTSOInfo = curTSOInfo + + if td.latestTSOInfo.CompareAndSwap(nil, curTSOInfo) { + // If latestTSOInfo is missing, simply store it and exit. + return + } + + // Replace if we are holding a larger ts than that has been recorded. + for { + old := td.latestTSOInfo.Load() + if tsoutil.TSLessEqual(curTSOInfo.physical, curTSOInfo.logical, old.physical, old.logical) { + // The current one is large enough. Skip. + break + } + if td.latestTSOInfo.CompareAndSwap(old, curTSOInfo) { + // Successfully replaced. 
+ break + } + } } // checkTSORPCConcurrency checks configurations about TSO RPC concurrency, and adjust the token count if needed. From c4c9b4f64686a63d80a5ff7db5ad5d32f6ffb092 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Mon, 23 Sep 2024 19:35:38 +0800 Subject: [PATCH 28/34] Add test TestEstimatedLatency Signed-off-by: MyonKeminta --- client/tso_stream.go | 4 ++- client/tso_stream_test.go | 70 ++++++++++++++++++++++++++++++++++----- 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/client/tso_stream.go b/client/tso_stream.go index 0779f7fb1b9..5576d89c0cb 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -231,6 +231,8 @@ var streamIDAlloc atomic.Int32 const invalidStreamID = "" +const maxPendingRequestsInTSOStream = 64 + func newTSOStream(ctx context.Context, serverURL string, stream grpcTSOStreamAdapter) *tsoStream { streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1)) // To make error handling in `tsoDispatcher` work, the internal `cancel` and external `cancel` is better to be @@ -241,7 +243,7 @@ func newTSOStream(ctx context.Context, serverURL string, stream grpcTSOStreamAda stream: stream, streamID: streamID, - pendingRequests: make(chan batchedRequests, 64), + pendingRequests: make(chan batchedRequests, maxPendingRequestsInTSOStream), cancel: cancel, diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index 5016c3cef1f..ed857eb714d 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -455,14 +455,68 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { } } -//func (s *testTSOStreamSuite) TestEstimatedLatency() { -// s.inner.returnResult(100, 0, 1) -// res := s.getResult(s.mustProcessRequestWithResultCh(1)) -// s.NoError(res.err) -// s.Equal(int64(100), res.result.physical) -// s.Equal(int64(0), res.result.logical) -// s.InDelta() -//} +func (s *testTSOStreamSuite) TestEstimatedLatency() { + s.inner.returnResult(100, 0, 1) + res := 
s.getResult(s.mustProcessRequestWithResultCh(1)) + s.re.NoError(res.err) + s.re.Equal(int64(100), res.result.physical) + s.re.Equal(int64(0), res.result.logical) + estimation := s.stream.EstimatedRPCLatency().Seconds() + s.re.Greater(estimation, 0.0) + s.re.InDelta(0.0, estimation, 0.01) + + // For each began request, record its startTime and send it to the result returning goroutine. + reqStartTimeCh := make(chan time.Time, maxPendingRequestsInTSOStream) + // Limit concurrent requests to be less than the capacity of tsoStream.pendingRequests. + tokenCh := make(chan struct{}, maxPendingRequestsInTSOStream-1) + for i := 0; i < 40; i++ { + tokenCh <- struct{}{} + } + // Return a result after 50ms delay for each requests + const delay = time.Millisecond * 50 + // The goroutine to delay and return the result. + go func() { + allocated := int64(1) + for reqStartTime := range reqStartTimeCh { + now := time.Now() + elapsed := now.Sub(reqStartTime) + if elapsed < delay { + time.Sleep(delay - elapsed) + } + s.inner.returnResult(100, allocated, 1) + allocated++ + } + }() + + // Limit the test time within 1s + startTime := time.Now() + resCh := make(chan (<-chan callbackInvocation), 100) + // The sending goroutine + go func() { + for time.Since(startTime) < time.Second { + <-tokenCh + reqStartTimeCh <- time.Now() + r := s.mustProcessRequestWithResultCh(1) + resCh <- r + } + close(reqStartTimeCh) + close(resCh) + }() + // Check the result + index := 0 + for r := range resCh { + // The first is 1 + index++ + res := s.getResult(r) + tokenCh <- struct{}{} + s.re.NoError(res.err) + s.re.Equal(int64(100), res.result.physical) + s.re.Equal(int64(index), res.result.logical) + } + + s.re.Greater(s.stream.EstimatedRPCLatency(), time.Duration(int64(0.9*float64(delay)))) + s.re.Less(s.stream.EstimatedRPCLatency(), time.Duration(math.Floor(1.1*float64(delay)))) +} func TestRCFilter(t *testing.T) { re := require.New(t) From f451595beade87c0bba45fa107d89e9cb7406da8 Mon Sep 17 00:00:00 2001 
From: MyonKeminta Date: Wed, 25 Sep 2024 03:27:55 +0800 Subject: [PATCH 29/34] Add tests for concurrency limiting Signed-off-by: MyonKeminta --- client/tso_batch_controller.go | 1 - client/tso_dispatcher.go | 24 +++- client/tso_dispatcher_test.go | 229 ++++++++++++++++++++++++++++++++- client/tso_request.go | 16 ++- client/tso_stream.go | 12 ++ client/tso_stream_test.go | 60 +++++++-- 6 files changed, 314 insertions(+), 28 deletions(-) diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index 08f0bccc403..b810e108667 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -47,7 +47,6 @@ func newTSOBatchController(maxBatchSize int) *tsoBatchController { // It returns nil error if everything goes well, otherwise a non-nil error which means we should stop the service. // It's guaranteed that if this function failed after collecting some requests, then these requests will be cancelled // when the function returns, so the caller don't need to clear them manually. -// If maxBatchWaitInterval is not enabled, func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequestCh <-chan *tsoRequest, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) { var tokenAcquired bool defer func() { diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 43049476d29..563bcc9fc6e 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -121,7 +121,7 @@ func newTSODispatcher( provider: provider, connectionCtxs: &sync.Map{}, tsoRequestCh: tsoRequestCh, - tsDeadlineCh: make(chan *deadline, 1), + tsDeadlineCh: make(chan *deadline, tokenChCapacity), batchBufferPool: &sync.Pool{ New: func() any { return newTSOBatchController(maxBatchSize * 2) @@ -193,9 +193,6 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { batchController *tsoBatchController ) - // Currently only 1 concurrency is supported. Put one token in. 
- td.tokenCh <- struct{}{} - log.Info("[tso] tso dispatcher created", zap.String("dc-location", dc)) // Clean up the connectionCtxs when the dispatcher exits. defer func() { @@ -206,7 +203,10 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { return true }) if batchController != nil && batchController.collectedRequestCount != 0 { - log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop", zap.Any("panic", recover())) + if r := recover(); r != nil { + panic(r) + } + log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop") } tsoErr := errors.WithStack(errClosing) td.revokePendingRequests(tsoErr) @@ -341,12 +341,17 @@ tsoBatchLoop: break streamChoosingLoop } + noDelay := false + failpoint.Inject("tsoDispatcherConcurrentModeNoDelay", func() { + noDelay = true + }) + // If concurrent RPC is enabled, the time for collecting each request batch is expected to be // estimatedRPCDuration / concurrency. Note the time mentioned here is counted from starting trying to collect // the batch, instead of the time when the first request arrives. // Here, if the elapsed time since starting collecting this batch didn't reach the expected batch time, then // continue collecting. - if td.isConcurrentRPCEnabled() { + if td.isConcurrentRPCEnabled() && !noDelay { estimatedLatency := stream.EstimatedRPCLatency() estimateTSOLatencyGauge.WithLabelValues(streamURL).Set(estimatedLatency.Seconds()) @@ -364,7 +369,7 @@ tsoBatchLoop: err = batchController.fetchRequestsWithTimer(ctx, td.tsoRequestCh, batchingTimer) if err != nil { - // There should not be + // There should not be other kinds of errors. 
log.Info("[tso] stop fetching the pending tso requests due to context canceled", zap.String("dc-location", dc), zap.Error(err)) td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) @@ -672,6 +677,11 @@ func (td *tsoDispatcher) checkTSORPCConcurrency(ctx context.Context, maxBatchWai // immediately, because MaxTSOBatchWaitInterval and concurrent RPC requests has opposite purpose. immediatelyUpdate := td.rpcConcurrency > 1 && maxBatchWaitInterval > 0 + // Allow always updating for test purpose. + failpoint.Inject("tsoDispatcherAlwaysCheckConcurrency", func() { + immediatelyUpdate = true + }) + if !immediatelyUpdate && now.Sub(td.lastCheckConcurrencyTime) < dispatcherCheckRPCConcurrencyInterval { return nil } diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go index b8f0fcef208..2aa082a4a64 100644 --- a/client/tso_dispatcher_test.go +++ b/client/tso_dispatcher_test.go @@ -18,20 +18,26 @@ import ( "context" "fmt" "sync" + "sync/atomic" "testing" "time" + "github.com/pingcap/failpoint" "github.com/pingcap/log" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "go.uber.org/zap/zapcore" ) type mockTSOServiceProvider struct { - option *option + option *option + createStream func(ctx context.Context) *tsoStream } -func newMockTSOServiceProvider(option *option) *mockTSOServiceProvider { +func newMockTSOServiceProvider(option *option, createStream func(ctx context.Context) *tsoStream) *mockTSOServiceProvider { return &mockTSOServiceProvider{ - option: option, + option: option, + createStream: createStream, } } @@ -43,17 +49,228 @@ func (*mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { return NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) } -func (*mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs *sync.Map) bool { +func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs 
*sync.Map) bool { _, ok := connectionCtxs.Load(mockStreamURL) if ok { return true } ctx, cancel := context.WithCancel(ctx) - stream := newTSOStream(ctx, mockStreamURL, newMockTSOStreamImpl(ctx, true)) + var stream *tsoStream + if m.createStream == nil { + stream = newTSOStream(ctx, mockStreamURL, newMockTSOStreamImpl(ctx, resultModeGenerated)) + } else { + stream = m.createStream(ctx) + } connectionCtxs.LoadOrStore(mockStreamURL, &tsoConnectionContext{ctx, cancel, mockStreamURL, stream}) return true } +type testTSODispatcherSuite struct { + suite.Suite + re *require.Assertions + + streamInner *mockTSOStreamImpl + stream *tsoStream + dispatcher *tsoDispatcher + dispatcherWg sync.WaitGroup + option *option + + reqPool *sync.Pool +} + +func (s *testTSODispatcherSuite) SetupTest() { + s.re = require.New(s.T()) + s.option = newOption() + s.option.timeout = time.Hour + // As the internal logic of the tsoDispatcher allows it to create streams multiple times, but our tests needs + // single stable access to the inner stream, we do not allow it to create it more than once in these tests. + created := new(atomic.Bool) + createStream := func(ctx context.Context) *tsoStream { + if !created.CompareAndSwap(false, true) { + s.re.FailNow("testTSODispatcherSuite: trying to create stream more than once, which is unsupported in this tests") + } + s.streamInner = newMockTSOStreamImpl(ctx, resultModeGenerateOnSignal) + s.stream = newTSOStream(ctx, mockStreamURL, s.streamInner) + return s.stream + } + s.dispatcher = newTSODispatcher(context.Background(), globalDCLocation, defaultMaxTSOBatchSize, newMockTSOServiceProvider(s.option, createStream)) + s.reqPool = &sync.Pool{ + New: func() any { + return &tsoRequest{ + done: make(chan error, 1), + physical: 0, + logical: 0, + dcLocation: globalDCLocation, + } + }, + } + + s.dispatcherWg.Add(1) + go s.dispatcher.handleDispatcher(&s.dispatcherWg) + + // Perform a request to ensure the stream must be created. 
+ + { + ctx := context.Background() + req := s.sendReq(ctx) + s.reqMustNotReady(req) + s.streamInner.generateNext() + s.reqMustReady(req) + } + s.re.NotNil(s.stream) +} + +func (s *testTSODispatcherSuite) TearDownTest() { + s.dispatcher.close() + s.streamInner.stop() + s.dispatcherWg.Wait() + s.stream.WaitForClosed() + s.streamInner = nil + s.stream = nil + s.dispatcher = nil + s.reqPool = nil +} + +func (s *testTSODispatcherSuite) getReq(ctx context.Context) *tsoRequest { + req := s.reqPool.Get().(*tsoRequest) + req.clientCtx = context.Background() + req.requestCtx = ctx + req.physical = 0 + req.logical = 0 + req.start = time.Now() + req.pool = s.reqPool + return req +} + +func (s *testTSODispatcherSuite) sendReq(ctx context.Context) *tsoRequest { + req := s.getReq(ctx) + s.dispatcher.push(req) + return req +} + +func (s *testTSODispatcherSuite) reqMustNotReady(req *tsoRequest) { + _, _, err := req.waitTimeout(time.Millisecond * 50) + s.re.Error(err) + s.re.ErrorIs(err, context.DeadlineExceeded) +} + +func (s *testTSODispatcherSuite) reqMustReady(req *tsoRequest) (physical int64, logical int64) { + physical, logical, err := req.waitTimeout(time.Second) + s.re.NoError(err) + return physical, logical +} + +func TestTSODispatcherTestSuite(t *testing.T) { + suite.Run(t, new(testTSODispatcherSuite)) +} + +func (s *testTSODispatcherSuite) TestBasic() { + ctx := context.Background() + req := s.sendReq(ctx) + s.reqMustNotReady(req) + s.streamInner.generateNext() + s.reqMustReady(req) +} + +func (s *testTSODispatcherSuite) checkIdleTokenCount(expectedTotal int) { + // When the tsoDispatcher is idle, the dispatcher loop will acquire a token and wait for requests. Therefore + // there should be N-1 free tokens remaining. 
+ spinStart := time.Now() + for time.Since(spinStart) < time.Second { + if s.dispatcher.tokenCount != expectedTotal { + continue + } + if len(s.dispatcher.tokenCh) == expectedTotal-1 { + break + } + } + s.re.Equal(expectedTotal, s.dispatcher.tokenCount) + s.re.Equal(expectedTotal-1, len(s.dispatcher.tokenCh)) +} + +func (s *testTSODispatcherSuite) testStaticConcurrencyImpl(concurrency int) { + ctx := context.Background() + s.option.setTSOClientRPCConcurrency(concurrency) + + // Make sure the state of the mock stream is clear. Unexpected batching may make the requests sent to the stream + // less than expected, causing there are more `generateNext` signals or generated results. + s.re.Equal(0, len(s.streamInner.resultCh)) + + // The dispatcher may block on fetching requests, which is after checking concurrency option. Perform a request + // to make sure the concurrency setting takes effect. + req := s.sendReq(ctx) + s.reqMustNotReady(req) + s.streamInner.generateNext() + s.reqMustReady(req) + + // For concurrent mode, the actual token count is twice the concurrency. + // Note that the concurrency is a hint, and it's allowed to have more than `concurrency` requests running. + tokenCount := concurrency + if concurrency > 1 { + tokenCount = concurrency * 2 + } + s.checkIdleTokenCount(tokenCount) + + // As the failpoint `tsoDispatcherConcurrentModeNoDelay` is set, tsoDispatcher won't collect requests in blocking + // way. And as `reqMustNotReady` delays for a while, requests shouldn't be batched as long as there are free tokens. + // The first N requests (N=tokenCount) will each be a single batch, occupying a token. The last 3 are blocked, + // and will be batched together once there is a free token. + reqs := make([]*tsoRequest, 0, tokenCount+3) + + for i := 0; i < tokenCount+3; i++ { + req := s.sendReq(ctx) + s.reqMustNotReady(req) + reqs = append(reqs, req) + } + + //time.Sleep(time.Hour) + // The dispatcher won't process more request batches if tokens are used up. 
+ // Note that `reqMustNotReady` contains a delay, which makes it nearly impossible that dispatcher is processing the + // second batch but not finished yet. + // Also note that in current implementation, the tsoStream tries to receive the next result before checking + // the `tsoStream.pendingRequests` queue. Changing this behavior may need to update this test. + for i := 0; i < tokenCount+3; i++ { + expectedPending := tokenCount + 1 - i + if expectedPending > tokenCount { + expectedPending = tokenCount + } + if expectedPending < 0 { + expectedPending = 0 + } + + // Spin for a while as the dispatcher loop may have not finished sending next batch to pendingRequests + spinStart := time.Now() + for time.Since(spinStart) < time.Second { + if expectedPending == len(s.stream.pendingRequests) { + break + } + } + s.re.Equal(expectedPending, len(s.stream.pendingRequests)) + + req := reqs[i] + // The last 3 requests should be in a single batch. Don't need to generate new results for the last 2. + if i <= tokenCount { + s.reqMustNotReady(req) + s.streamInner.generateNext() + } + s.reqMustReady(req) + } +} + +func (s *testTSODispatcherSuite) TestConcurrentRPC() { + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeNoDelay", "return")) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherAlwaysCheckConcurrency", "return")) + defer func() { + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeNoDelay")) + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoDispatcherAlwaysCheckConcurrency")) + }() + + s.testStaticConcurrencyImpl(1) + s.testStaticConcurrencyImpl(2) + s.testStaticConcurrencyImpl(4) + s.testStaticConcurrencyImpl(16) +} + func BenchmarkTSODispatcherHandleRequests(b *testing.B) { log.SetLevel(zapcore.FatalLevel) @@ -80,7 +297,7 @@ func BenchmarkTSODispatcherHandleRequests(b *testing.B) { return req } - dispatcher := newTSODispatcher(ctx, globalDCLocation, 
defaultMaxTSOBatchSize, newMockTSOServiceProvider(newOption())) + dispatcher := newTSODispatcher(ctx, globalDCLocation, defaultMaxTSOBatchSize, newMockTSOServiceProvider(newOption(), nil)) var wg sync.WaitGroup wg.Add(1) diff --git a/client/tso_request.go b/client/tso_request.go index fb2ae2bb92e..5c959673a8b 100644 --- a/client/tso_request.go +++ b/client/tso_request.go @@ -60,6 +60,11 @@ func (req *tsoRequest) tryDone(err error) { // Wait will block until the TSO result is ready. func (req *tsoRequest) Wait() (physical int64, logical int64, err error) { + return req.waitCtx(req.requestCtx) +} + +// waitCtx waits for the TSO result with specified ctx, while not using req.requestCtx. +func (req *tsoRequest) waitCtx(ctx context.Context) (physical int64, logical int64, err error) { // If tso command duration is observed very high, the reason could be it // takes too long for Wait() be called. start := time.Now() @@ -78,13 +83,20 @@ func (req *tsoRequest) Wait() (physical int64, logical int64, err error) { cmdDurationWait.Observe(now.Sub(start).Seconds()) cmdDurationTSO.Observe(now.Sub(req.start).Seconds()) return - case <-req.requestCtx.Done(): - return 0, 0, errors.WithStack(req.requestCtx.Err()) + case <-ctx.Done(): + return 0, 0, errors.WithStack(ctx.Err()) case <-req.clientCtx.Done(): return 0, 0, errors.WithStack(req.clientCtx.Err()) } } +// waitTimeout waits for the TSO result for limited time. Currently only for test purposes. 
+func (req *tsoRequest) waitTimeout(timeout time.Duration) (physical int64, logical int64, err error) { + ctx, cancel := context.WithTimeout(req.requestCtx, timeout) + defer cancel() + return req.waitCtx(ctx) +} + type tsoRequestFastFail struct { err error } diff --git a/client/tso_stream.go b/client/tso_stream.go index 5576d89c0cb..9235e7d6e20 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -24,6 +24,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/kvproto/pkg/tsopb" "github.com/pingcap/log" @@ -448,6 +449,17 @@ recvLoop: // EstimatedRPCLatency returns an estimation of the duration of each TSO RPC. If the stream has never handled any RPC, // this function returns 0. func (s *tsoStream) EstimatedRPCLatency() time.Duration { + failpoint.Inject("tsoStreamSimulateEstimatedRPCLatency", func(val failpoint.Value) { + if s, ok := val.(string); ok { + duration, err := time.ParseDuration(s) + if err != nil { + panic(err) + } + failpoint.Return(duration) + } else { + panic("invalid failpoint value for `tsoStreamSimulateEstimatedRPCLatency`: expected string") + } + }) latencyUs := s.estimatedLatencyMicros.Load() // Limit it at least 100us if latencyUs < 100 { diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go index ed857eb714d..ab6f2786ff3 100644 --- a/client/tso_stream_test.go +++ b/client/tso_stream_test.go @@ -43,6 +43,14 @@ type resultMsg struct { breakStream bool } +type resultMode int + +const ( + resultModeManual resultMode = iota + resultModeGenerated + resultModeGenerateOnSignal +) + type mockTSOStreamImpl struct { ctx context.Context requestCh chan requestMsg @@ -50,21 +58,21 @@ type mockTSOStreamImpl struct { keyspaceID uint32 errorState error - autoGenerateResult bool + resultMode resultMode // Current progress of generating TSO results resGenPhysical, resGenLogical int64 } -func newMockTSOStreamImpl(ctx context.Context, autoGenerateResult 
bool) *mockTSOStreamImpl { +func newMockTSOStreamImpl(ctx context.Context, resultMode resultMode) *mockTSOStreamImpl { return &mockTSOStreamImpl{ ctx: ctx, requestCh: make(chan requestMsg, 64), resultCh: make(chan resultMsg, 64), keyspaceID: 0, - autoGenerateResult: autoGenerateResult, - resGenPhysical: 10000, - resGenLogical: 0, + resultMode: resultMode, + resGenPhysical: 10000, + resGenLogical: 0, } } @@ -83,6 +91,17 @@ func (s *mockTSOStreamImpl) Send(clusterID uint64, _keyspaceID, keyspaceGroupID } func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { + var needGenerateResult, needResultSignal bool + switch s.resultMode { + case resultModeManual: + needResultSignal = true + case resultModeGenerated: + needGenerateResult = true + case resultModeGenerateOnSignal: + needResultSignal = true + needGenerateResult = true + } + // This stream have ever receive an error, it returns the error forever. if s.errorState != nil { return tsoRequestResult{}, s.errorState @@ -131,12 +150,12 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { } s.errorState = res.err return tsoRequestResult{}, s.errorState - } else if s.autoGenerateResult { + } else if !needResultSignal { // Do not allow manually assigning result. panic("trying manually specifying result for mockTSOStreamImpl when it's auto-generating mode") } - } else if s.autoGenerateResult { - res = s.autoGenResult(req.count) + } else if !needResultSignal { + // Mark hasRes as true to skip receiving from resultCh. The actual value of the result will be generated later. hasRes = true } @@ -161,6 +180,10 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { } } + if needGenerateResult { + res = s.autoGenResult(req.count) + } + // Both res and req should be ready here. 
if res.err != nil { s.errorState = res.err @@ -169,11 +192,14 @@ func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { } func (s *mockTSOStreamImpl) autoGenResult(count int64) resultMsg { + if count >= (1 << 18) { + panic("requested count too large") + } physical := s.resGenPhysical logical := s.resGenLogical + count if logical >= (1 << 18) { - physical += logical >> 18 - logical &= (1 << 18) - 1 + physical += 1 + logical = count } s.resGenPhysical = physical @@ -191,6 +217,9 @@ func (s *mockTSOStreamImpl) autoGenResult(count int64) resultMsg { } func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { + if s.resultMode != resultModeManual { + panic("trying to manually specifying tso result on generating mode") + } s.resultCh <- resultMsg{ r: tsoRequestResult{ physical: physical, @@ -202,6 +231,13 @@ func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count ui } } +func (s *mockTSOStreamImpl) generateNext() { + if s.resultMode != resultModeGenerateOnSignal { + panic("trying to signal generation when the stream is not generate-on-signal mode") + } + s.resultCh <- resultMsg{} +} + func (s *mockTSOStreamImpl) returnError(err error) { s.resultCh <- resultMsg{ err: err, @@ -234,7 +270,7 @@ type testTSOStreamSuite struct { func (s *testTSOStreamSuite) SetupTest() { s.re = require.New(s.T()) - s.inner = newMockTSOStreamImpl(context.Background(), false) + s.inner = newMockTSOStreamImpl(context.Background(), resultModeManual) s.stream = newTSOStream(context.Background(), mockStreamURL, s.inner) } @@ -573,7 +609,7 @@ func TestRCFilter(t *testing.T) { func BenchmarkTSOStreamSendRecv(b *testing.B) { log.SetLevel(zapcore.FatalLevel) - streamInner := newMockTSOStreamImpl(context.Background(), true) + streamInner := newMockTSOStreamImpl(context.Background(), resultModeGenerated) stream := newTSOStream(context.Background(), mockStreamURL, streamInner) defer func() { streamInner.stop() From 
708bbe9fc4c97649eb0ddec4a1737dba058e04c4 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 25 Sep 2024 04:10:34 +0800 Subject: [PATCH 30/34] Add tests for calculating batch delay time Signed-off-by: MyonKeminta --- client/tso_dispatcher.go | 29 +++++++++++++++------- client/tso_dispatcher_test.go | 45 +++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 563bcc9fc6e..956a7daace6 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -17,6 +17,7 @@ package pd import ( "context" "fmt" + "math" "math/rand" "runtime/trace" "sync" @@ -203,10 +204,8 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { return true }) if batchController != nil && batchController.collectedRequestCount != 0 { - if r := recover(); r != nil { - panic(r) - } - log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop") + // If you encounter this failure, please check the stack in the logs to see if it's a panic. + log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop", zap.Any("panic", recover())) } tsoErr := errors.WithStack(errClosing) td.revokePendingRequests(tsoErr) @@ -351,14 +350,28 @@ tsoBatchLoop: // the batch, instead of the time when the first request arrives. // Here, if the elapsed time since starting collecting this batch didn't reach the expected batch time, then // continue collecting. 
- if td.isConcurrentRPCEnabled() && !noDelay { + if td.isConcurrentRPCEnabled() { estimatedLatency := stream.EstimatedRPCLatency() estimateTSOLatencyGauge.WithLabelValues(streamURL).Set(estimatedLatency.Seconds()) + goalBatchTime := estimatedLatency / time.Duration(td.rpcConcurrency) + + failpoint.Inject("tsoDispatcherConcurrentModeAssertDelayDuration", func(val failpoint.Value) { + if s, ok := val.(string); ok { + expected, err := time.ParseDuration(s) + if err != nil { + panic(err) + } + if math.Abs(expected.Seconds()-goalBatchTime.Seconds()) > 1e-6 { + log.Fatal("tsoDispatcher: trying to delay for unexpected duration for the batch", zap.Duration("goalBatchTime", goalBatchTime), zap.Duration("expectedBatchTime", expected)) + } + } else { + panic("invalid value for failpoint tsoDispatcherConcurrentModeAssertDelayDuration: expected string") + } + }) - totalBatchTime := estimatedLatency / time.Duration(td.rpcConcurrency) waitTimerStart := time.Now() - remainingBatchTime := totalBatchTime - waitTimerStart.Sub(currentBatchStartTime) - if remainingBatchTime > 0 { + remainingBatchTime := goalBatchTime - waitTimerStart.Sub(currentBatchStartTime) + if remainingBatchTime > 0 && !noDelay { if !batchingTimer.Stop() { select { case <-batchingTimer.C: diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go index 2aa082a4a64..2f7a1ef6e1e 100644 --- a/client/tso_dispatcher_test.go +++ b/client/tso_dispatcher_test.go @@ -32,6 +32,7 @@ import ( type mockTSOServiceProvider struct { option *option createStream func(ctx context.Context) *tsoStream + updateConnMu sync.Mutex } func newMockTSOServiceProvider(option *option, createStream func(ctx context.Context) *tsoStream) *mockTSOServiceProvider { @@ -50,6 +51,11 @@ func (*mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { } func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs *sync.Map) bool { + // Avoid concurrent updating in the background updating 
goroutine and active updating in the dispatcher loop when + // stream is missing. + m.updateConnMu.Lock() + defer m.updateConnMu.Unlock() + _, ok := connectionCtxs.Load(mockStreamURL) if ok { return true @@ -271,6 +277,45 @@ func (s *testTSODispatcherSuite) TestConcurrentRPC() { s.testStaticConcurrencyImpl(16) } +func (s *testTSODispatcherSuite) TestBatchDelaying() { + ctx := context.Background() + s.option.setTSOClientRPCConcurrency(2) + + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeNoDelay", "return")) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoStreamSimulateEstimatedRPCLatency", `return("12ms")`)) + defer func() { + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeNoDelay")) + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoStreamSimulateEstimatedRPCLatency")) + }() + + // Make sure concurrency option takes effect. + req := s.sendReq(ctx) + s.streamInner.generateNext() + s.reqMustReady(req) + + // Trigger the check. + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeAssertDelayDuration", `return("6ms")`)) + defer func() { + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeAssertDelayDuration")) + }() + req = s.sendReq(ctx) + s.streamInner.generateNext() + s.reqMustReady(req) + + // Try other concurrency. 
+ s.option.setTSOClientRPCConcurrency(3) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeAssertDelayDuration", `return("4ms")`)) + req = s.sendReq(ctx) + s.streamInner.generateNext() + s.reqMustReady(req) + + s.option.setTSOClientRPCConcurrency(4) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeAssertDelayDuration", `return("3ms")`)) + req = s.sendReq(ctx) + s.streamInner.generateNext() + s.reqMustReady(req) +} + func BenchmarkTSODispatcherHandleRequests(b *testing.B) { log.SetLevel(zapcore.FatalLevel) From 4a795fed2abc90b783bfeb31b35b7c0451b6ff08 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 25 Sep 2024 04:14:28 +0800 Subject: [PATCH 31/34] Fix lint Signed-off-by: MyonKeminta --- client/tso_dispatcher_test.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go index 2f7a1ef6e1e..f22823d9698 100644 --- a/client/tso_dispatcher_test.go +++ b/client/tso_dispatcher_test.go @@ -191,7 +191,7 @@ func (s *testTSODispatcherSuite) checkIdleTokenCount(expectedTotal int) { } } s.re.Equal(expectedTotal, s.dispatcher.tokenCount) - s.re.Equal(expectedTotal-1, len(s.dispatcher.tokenCh)) + s.re.Len(s.dispatcher.tokenCh, expectedTotal-1) } func (s *testTSODispatcherSuite) testStaticConcurrencyImpl(concurrency int) { @@ -200,7 +200,7 @@ func (s *testTSODispatcherSuite) testStaticConcurrencyImpl(concurrency int) { // Make sure the state of the mock stream is clear. Unexpected batching may make the requests sent to the stream // less than expected, causing there are more `generateNext` signals or generated results. - s.re.Equal(0, len(s.streamInner.resultCh)) + s.re.Empty(s.streamInner.resultCh) // The dispatcher may block on fetching requests, which is after checking concurrency option. Perform a request // to make sure the concurrency setting takes effect. 
@@ -229,7 +229,6 @@ func (s *testTSODispatcherSuite) testStaticConcurrencyImpl(concurrency int) { reqs = append(reqs, req) } - //time.Sleep(time.Hour) // The dispatcher won't process more request batches if tokens are used up. // Note that `reqMustNotReady` contains a delay, which makes it nearly impossible that dispatcher is processing the // second batch but not finished yet. @@ -251,7 +250,7 @@ func (s *testTSODispatcherSuite) testStaticConcurrencyImpl(concurrency int) { break } } - s.re.Equal(expectedPending, len(s.stream.pendingRequests)) + s.re.Len(s.stream.pendingRequests, expectedPending) req := reqs[i] // The last 3 requests should be in a single batch. Don't need to generate new results for the last 2. From 0bfe55f73deaa519d1b5beae1081fe0b93189185 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 25 Sep 2024 12:34:25 +0800 Subject: [PATCH 32/34] Address comments; add more comments to explain checkMonotonicity Signed-off-by: MyonKeminta --- client/tso_dispatcher.go | 14 +++++++++++++- client/tso_stream.go | 7 ++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 956a7daace6..009b017aa4d 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -63,6 +63,7 @@ type tsoInfo struct { respReceivedAt time.Time physical int64 logical int64 + sourceStreamID string } type tsoServiceProvider interface { @@ -589,6 +590,7 @@ func (td *tsoDispatcher) processRequests( respReceivedAt: time.Now(), physical: result.physical, logical: result.logical, + sourceStreamID: stream.streamID, } // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. 
firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits) @@ -619,6 +621,14 @@ func (td *tsoDispatcher) doneCollectedRequests(tbc *tsoBatchController, physical tbc.finishCollectedRequests(physical, firstLogical, suffixBits, streamID, nil) } +// checkMonotonicity checks whether the monotonicity of the TSO allocation is violated. +// It asserts (curTSOInfo, firstLogical) must be larger than lastTSOInfo, and updates td.latestTSOInfo if it grows. +// +// Note that when concurrent RPC is enabled, the lastTSOInfo may not be the latest value stored in td.latestTSOInfo +// field. Instead, it's the value that was loaded just before the current RPC request's beginning. The reason is, +// if two requests processing time has overlap, they don't have a strong order, and the later-finished one may be +// allocated later (with larger value) than another. We only need to guarantee request A returns larger ts than B +// if request A *starts* after request B *finishes*. 
func (td *tsoDispatcher) checkMonotonicity( lastTSOInfo *tsoInfo, curTSOInfo *tsoInfo, firstLogical int64, ) { @@ -647,7 +657,9 @@ func (td *tsoDispatcher) checkMonotonicity( zap.Uint32("last-keyspace-group-in-response", lastTSOInfo.respKeyspaceGroupID), zap.Uint32("cur-keyspace-group-in-response", curTSOInfo.respKeyspaceGroupID), zap.Time("last-response-received-at", lastTSOInfo.respReceivedAt), - zap.Time("cur-response-received-at", curTSOInfo.respReceivedAt)) + zap.Time("cur-response-received-at", curTSOInfo.respReceivedAt), + zap.String("last-stream-id", lastTSOInfo.sourceStreamID), + zap.String("cur-stream-id", curTSOInfo.sourceStreamID)) } } diff --git a/client/tso_stream.go b/client/tso_stream.go index 9235e7d6e20..b8118b04468 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -230,9 +230,10 @@ const ( var streamIDAlloc atomic.Int32 -const invalidStreamID = "" - -const maxPendingRequestsInTSOStream = 64 +const ( + invalidStreamID = "" + maxPendingRequestsInTSOStream = 64 +) func newTSOStream(ctx context.Context, serverURL string, stream grpcTSOStreamAdapter) *tsoStream { streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1)) From 0bfe9a69293de6efe3576bc89e3f5dfe7b8c1ead Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 25 Sep 2024 12:51:11 +0800 Subject: [PATCH 33/34] Fix data race Signed-off-by: MyonKeminta --- client/tso_dispatcher_test.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go index f22823d9698..bf038e7b7f3 100644 --- a/client/tso_dispatcher_test.go +++ b/client/tso_dispatcher_test.go @@ -90,13 +90,16 @@ func (s *testTSODispatcherSuite) SetupTest() { s.option.timeout = time.Hour // As the internal logic of the tsoDispatcher allows it to create streams multiple times, but our tests needs // single stable access to the inner stream, we do not allow it to create it more than once in these tests. 
+ creating := new(atomic.Bool) + // To avoid data race on reading `stream` and `streamInner` fields. created := new(atomic.Bool) createStream := func(ctx context.Context) *tsoStream { - if !created.CompareAndSwap(false, true) { + if !creating.CompareAndSwap(false, true) { s.re.FailNow("testTSODispatcherSuite: trying to create stream more than once, which is unsupported in this tests") } s.streamInner = newMockTSOStreamImpl(ctx, resultModeGenerateOnSignal) s.stream = newTSOStream(ctx, mockStreamURL, s.streamInner) + created.Store(true) return s.stream } s.dispatcher = newTSODispatcher(context.Background(), globalDCLocation, defaultMaxTSOBatchSize, newMockTSOServiceProvider(s.option, createStream)) @@ -120,9 +123,14 @@ func (s *testTSODispatcherSuite) SetupTest() { ctx := context.Background() req := s.sendReq(ctx) s.reqMustNotReady(req) + // Wait until created + for !created.Load() { + time.Sleep(time.Millisecond) + } s.streamInner.generateNext() s.reqMustReady(req) } + s.re.True(created.Load()) s.re.NotNil(s.stream) } From f67467ea5fbaf91bf1da768cda5a8c1765cafb73 Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 25 Sep 2024 17:29:20 +0800 Subject: [PATCH 34/34] Update estimated latency metric in tsoStream instead of tsoDispatcher Signed-off-by: MyonKeminta --- client/tso_dispatcher.go | 1 - client/tso_stream.go | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 009b017aa4d..7febf194f3c 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -353,7 +353,6 @@ tsoBatchLoop: // continue collecting. 
if td.isConcurrentRPCEnabled() { estimatedLatency := stream.EstimatedRPCLatency() - estimateTSOLatencyGauge.WithLabelValues(streamURL).Set(estimatedLatency.Seconds()) goalBatchTime := estimatedLatency / time.Duration(td.rpcConcurrency) failpoint.Inject("tsoDispatcherConcurrentModeAssertDelayDuration", func(val failpoint.Value) { diff --git a/client/tso_stream.go b/client/tso_stream.go index b8118b04468..142ad71c6b9 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -385,7 +385,10 @@ func (s *tsoStream) recvLoop(ctx context.Context) { } currentSample := math.Log(float64(latency.Microseconds())) filteredValue := filter.update(sampleTime, currentSample) - s.estimatedLatencyMicros.Store(uint64(math.Exp(filteredValue))) + micros := math.Exp(filteredValue) + s.estimatedLatencyMicros.Store(uint64(micros)) + // Update the metrics in seconds. + estimateTSOLatencyGauge.WithLabelValues(s.streamID).Set(micros * 1e-6) } recvLoop: