diff --git a/client/metrics.go b/client/metrics.go index a11362669b3..a83b4a36407 100644 --- a/client/metrics.go +++ b/client/metrics.go @@ -39,13 +39,14 @@ func initAndRegisterMetrics(constLabels prometheus.Labels) { } var ( - cmdDuration *prometheus.HistogramVec - cmdFailedDuration *prometheus.HistogramVec - requestDuration *prometheus.HistogramVec - tsoBestBatchSize prometheus.Histogram - tsoBatchSize prometheus.Histogram - tsoBatchSendLatency prometheus.Histogram - requestForwarded *prometheus.GaugeVec + cmdDuration *prometheus.HistogramVec + cmdFailedDuration *prometheus.HistogramVec + requestDuration *prometheus.HistogramVec + tsoBestBatchSize prometheus.Histogram + tsoBatchSize prometheus.Histogram + tsoBatchSendLatency prometheus.Histogram + requestForwarded *prometheus.GaugeVec + ongoingRequestCountGauge *prometheus.GaugeVec ) func initMetrics(constLabels prometheus.Labels) { @@ -117,6 +118,15 @@ func initMetrics(constLabels prometheus.Labels) { Help: "The status to indicate if the request is forwarded", ConstLabels: constLabels, }, []string{"host", "delegate"}) + + ongoingRequestCountGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "pd_client", + Subsystem: "request", + Name: "ongoing_requests_count", + Help: "Current count of ongoing batch tso requests", + ConstLabels: constLabels, + }, []string{"stream"}) } var ( diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index a713b7a187d..32191889160 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -19,10 +19,7 @@ import ( "runtime/trace" "time" - "github.com/pingcap/errors" - "github.com/pingcap/log" "github.com/tikv/pd/client/tsoutil" - "go.uber.org/zap" ) type tsoBatchController struct { @@ -30,42 +27,77 @@ type tsoBatchController struct { // bestBatchSize is a dynamic size that changed based on the current batch effect. 
bestBatchSize int - tsoRequestCh chan *tsoRequest collectedRequests []*tsoRequest collectedRequestCount int - batchStartTime time.Time + // The time after getting the first request and the token, and before performing extra batching. + extraBatchingStartTime time.Time } -func newTSOBatchController(tsoRequestCh chan *tsoRequest, maxBatchSize int) *tsoBatchController { +func newTSOBatchController(maxBatchSize int) *tsoBatchController { return &tsoBatchController{ maxBatchSize: maxBatchSize, bestBatchSize: 8, /* Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4) */ - tsoRequestCh: tsoRequestCh, collectedRequests: make([]*tsoRequest, maxBatchSize+1), collectedRequestCount: 0, } } // fetchPendingRequests will start a new round of the batch collecting from the channel. -// It returns true if everything goes well, otherwise false which means we should stop the service. -func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, maxBatchWaitInterval time.Duration) error { - var firstRequest *tsoRequest - select { - case <-ctx.Done(): - return ctx.Err() - case firstRequest = <-tbc.tsoRequestCh: - } - // Start to batch when the first TSO request arrives. - tbc.batchStartTime = time.Now() +// It returns a nil error if everything goes well, otherwise a non-nil error which means we should stop the service. +// It's guaranteed that if this function fails after collecting some requests, then these requests will be cancelled +// when the function returns, so the caller doesn't need to clear them manually. +func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequestCh <-chan *tsoRequest, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) { + var tokenAcquired bool + defer func() { + if errRet != nil { + // Something went wrong when collecting a batch of requests. Release the token and cancel collected requests + // if any.
+ if tokenAcquired { + tokenCh <- struct{}{} + } + tbc.finishCollectedRequests(0, 0, 0, invalidStreamID, errRet) + } + }() + + // Wait until BOTH the first request and the token have arrived. + // TODO: `tbc.collectedRequestCount` should never be non-empty here. Consider do assertion here. tbc.collectedRequestCount = 0 - tbc.pushRequest(firstRequest) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-tsoRequestCh: + // Start to batch when the first TSO request arrives. + tbc.pushRequest(req) + // A request arrives but the token is not ready yet. Continue waiting, and also allowing collecting the next + // request if it arrives. + continue + case <-tokenCh: + tokenAcquired = true + } + + // The token is ready. If the first request didn't arrive, wait for it. + if tbc.collectedRequestCount == 0 { + select { + case <-ctx.Done(): + return ctx.Err() + case firstRequest := <-tsoRequestCh: + tbc.pushRequest(firstRequest) + } + } + + // Both token and the first request have arrived. + break + } + + tbc.extraBatchingStartTime = time.Now() // This loop is for trying best to collect more requests, so we use `tbc.maxBatchSize` here. fetchPendingRequestsLoop: for tbc.collectedRequestCount < tbc.maxBatchSize { select { - case tsoReq := <-tbc.tsoRequestCh: + case tsoReq := <-tsoRequestCh: tbc.pushRequest(tsoReq) case <-ctx.Done(): return ctx.Err() @@ -88,7 +120,7 @@ fetchPendingRequestsLoop: defer after.Stop() for tbc.collectedRequestCount < tbc.bestBatchSize { select { - case tsoReq := <-tbc.tsoRequestCh: + case tsoReq := <-tsoRequestCh: tbc.pushRequest(tsoReq) case <-ctx.Done(): return ctx.Err() @@ -103,7 +135,7 @@ fetchPendingRequestsLoop: // we can adjust the `tbc.bestBatchSize` dynamically later. 
for tbc.collectedRequestCount < tbc.maxBatchSize { select { - case tsoReq := <-tbc.tsoRequestCh: + case tsoReq := <-tsoRequestCh: tbc.pushRequest(tsoReq) case <-ctx.Done(): return ctx.Err() @@ -136,31 +168,16 @@ func (tbc *tsoBatchController) adjustBestBatchSize() { } } -func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, err error) { +func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, streamID string, err error) { for i := 0; i < tbc.collectedRequestCount; i++ { tsoReq := tbc.collectedRequests[i] // Retrieve the request context before the request is done to trace without race. requestCtx := tsoReq.requestCtx tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits) + tsoReq.streamID = streamID tsoReq.tryDone(err) trace.StartRegion(requestCtx, "pdclient.tsoReqDequeue").End() } // Prevent the finished requests from being processed again. tbc.collectedRequestCount = 0 } - -func (tbc *tsoBatchController) revokePendingRequests(err error) { - for i := 0; i < len(tbc.tsoRequestCh); i++ { - req := <-tbc.tsoRequestCh - req.tryDone(err) - } -} - -func (tbc *tsoBatchController) clear() { - log.Info("[pd] clear the tso batch controller", - zap.Int("max-batch-size", tbc.maxBatchSize), zap.Int("best-batch-size", tbc.bestBatchSize), - zap.Int("collected-request-count", tbc.collectedRequestCount), zap.Int("pending-request-count", len(tbc.tsoRequestCh))) - tsoErr := errors.WithStack(errClosing) - tbc.finishCollectedRequests(0, 0, 0, tsoErr) - tbc.revokePendingRequests(tsoErr) -} diff --git a/client/tso_client.go b/client/tso_client.go index 2f3b949f017..f1538a7f164 100644 --- a/client/tso_client.go +++ b/client/tso_client.go @@ -203,6 +203,7 @@ func (c *tsoClient) getTSORequest(ctx context.Context, dcLocation string) *tsoRe req.physical = 0 req.logical = 0 req.dcLocation = dcLocation + req.streamID = "" return req } diff --git 
a/client/tso_dispatcher.go b/client/tso_dispatcher.go index a7c99057275..a1e0b03a1fa 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -76,10 +76,18 @@ type tsoDispatcher struct { provider tsoServiceProvider // URL -> *connectionContext - connectionCtxs *sync.Map - batchController *tsoBatchController - tsDeadlineCh chan *deadline - lastTSOInfo *tsoInfo + connectionCtxs *sync.Map + tsoRequestCh chan *tsoRequest + tsDeadlineCh chan *deadline + lastTSOInfo *tsoInfo + // For reusing tsoBatchController objects + batchBufferPool *sync.Pool + + // For controlling amount of concurrently processing RPC requests. + // A token must be acquired here before sending an RPC request, and the token must be put back after finishing the + // RPC. This is used like a semaphore, but we don't use semaphore directly here as it cannot be selected with + // other channels. + tokenCh chan struct{} updateConnectionCtxsCh chan struct{} } @@ -91,24 +99,29 @@ func newTSODispatcher( provider tsoServiceProvider, ) *tsoDispatcher { dispatcherCtx, dispatcherCancel := context.WithCancel(ctx) - tsoBatchController := newTSOBatchController( - make(chan *tsoRequest, maxBatchSize*2), - maxBatchSize, - ) + tsoRequestCh := make(chan *tsoRequest, maxBatchSize*2) failpoint.Inject("shortDispatcherChannel", func() { - tsoBatchController = newTSOBatchController( - make(chan *tsoRequest, 1), - maxBatchSize, - ) + tsoRequestCh = make(chan *tsoRequest, 1) }) + + // A large-enough capacity to hold maximum concurrent RPC requests. In our design, the concurrency is at most 16. 
+ const tokenChCapacity = 64 + tokenCh := make(chan struct{}, tokenChCapacity) + td := &tsoDispatcher{ - ctx: dispatcherCtx, - cancel: dispatcherCancel, - dc: dc, - provider: provider, - connectionCtxs: &sync.Map{}, - batchController: tsoBatchController, - tsDeadlineCh: make(chan *deadline, 1), + ctx: dispatcherCtx, + cancel: dispatcherCancel, + dc: dc, + provider: provider, + connectionCtxs: &sync.Map{}, + tsoRequestCh: tsoRequestCh, + tsDeadlineCh: make(chan *deadline, 1), + batchBufferPool: &sync.Pool{ + New: func() any { + return newTSOBatchController(maxBatchSize * 2) + }, + }, + tokenCh: tokenCh, updateConnectionCtxsCh: make(chan struct{}, 1), } go td.watchTSDeadline() @@ -146,13 +159,21 @@ func (td *tsoDispatcher) scheduleUpdateConnectionCtxs() { } } +func (td *tsoDispatcher) revokePendingRequests(err error) { + for i := 0; i < len(td.tsoRequestCh); i++ { + req := <-td.tsoRequestCh + req.tryDone(err) + } +} + func (td *tsoDispatcher) close() { td.cancel() - td.batchController.clear() + tsoErr := errors.WithStack(errClosing) + td.revokePendingRequests(tsoErr) } func (td *tsoDispatcher) push(request *tsoRequest) { - td.batchController.tsoRequestCh <- request + td.tsoRequestCh <- request } func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { @@ -163,8 +184,12 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { svcDiscovery = provider.getServiceDiscovery() option = provider.getOption() connectionCtxs = td.connectionCtxs - batchController = td.batchController + batchController *tsoBatchController ) + + // Currently only 1 concurrency is supported. Put one token in. + td.tokenCh <- struct{}{} + log.Info("[tso] tso dispatcher created", zap.String("dc-location", dc)) // Clean up the connectionCtxs when the dispatcher exits. defer func() { @@ -174,8 +199,11 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { cc.(*tsoConnectionContext).cancel() return true }) - // Clear the tso batch controller. 
- batchController.clear() + if batchController != nil && batchController.collectedRequestCount != 0 { + log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop", zap.Any("panic", recover())) + } + tsoErr := errors.WithStack(errClosing) + td.revokePendingRequests(tsoErr) wg.Done() }() // Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event. @@ -199,13 +227,17 @@ tsoBatchLoop: return default: } + + // In case error happens, the loop may continue without resetting `batchController` for retrying. + if batchController == nil { + batchController = td.batchBufferPool.Get().(*tsoBatchController) + } + // Start to collect the TSO requests. maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval() // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, // otherwise the upper caller may get blocked on waiting for the results. - if err = batchController.fetchPendingRequests(ctx, maxBatchWaitInterval); err != nil { - // Finish the collected requests if the fetch failed. - batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) + if err = batchController.fetchPendingRequests(ctx, td.tsoRequestCh, td.tokenCh, maxBatchWaitInterval); err != nil { if err == context.Canceled { log.Info("[tso] stop fetching the pending tso requests due to context canceled", zap.String("dc-location", dc)) @@ -246,7 +278,7 @@ tsoBatchLoop: select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. 
- batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) timer.Stop() return case <-streamLoopTimer.C: @@ -254,7 +286,7 @@ tsoBatchLoop: log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err)) svcDiscovery.ScheduleCheckMemberChanged() // Finish the collected requests if the stream is failed to be created. - batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(err)) timer.Stop() continue tsoBatchLoop case <-timer.C: @@ -271,55 +303,90 @@ tsoBatchLoop: stream = nil continue default: - break streamChoosingLoop } + + // Check if any error has occurred on this stream when receiving asynchronously. + if err = stream.GetRecvError(); err != nil { + exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + stream = nil + if exit { + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) + return + } + continue + } + + break streamChoosingLoop } done := make(chan struct{}) dl := newTSDeadline(option.timeout, done, cancel) select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. - batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(ctx.Err())) + td.cancelCollectedRequests(batchController, invalidStreamID, errors.WithStack(ctx.Err())) return case td.tsDeadlineCh <- dl: } // processRequests guarantees that the collected requests could be finished properly. - err = td.processRequests(stream, dc, td.batchController) - close(done) + err = td.processRequests(stream, dc, batchController, done) // If error happens during tso stream handling, reset stream and run the next trial. - if err != nil { + if err == nil { + // A nil error returned by `processRequests` indicates that the request batch is started successfully. 
+ // In this case, the `batchController` will be put back to the pool when the request is finished + // asynchronously (either successful or not). This implies that the current `batchController` object will + // be asynchronously accessed after the `processRequests` call. As a result, we need to use another + // `batchController` for collecting the next batch. To do this, we set the `batchController` to nil so that + // another one will be fetched from the pool at the beginning of the batching loop. + // Otherwise, the `batchController` won't be processed in other goroutines concurrently, and it can be + // reused in the next loop safely. + batchController = nil + } else { + exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + stream = nil + if exit { + return + } + } + } +} + +// handleProcessRequestError handles errors that occur when trying to process a TSO RPC request for the dispatcher loop. +// Returns true if the dispatcher loop is ok to continue. Otherwise, the dispatcher loop should be exited. +func (td *tsoDispatcher) handleProcessRequestError(ctx context.Context, bo *retry.Backoffer, streamURL string, streamCancelFunc context.CancelFunc, err error) bool { + select { + case <-ctx.Done(): + return false + default: + } + + svcDiscovery := td.provider.getServiceDiscovery() + + svcDiscovery.ScheduleCheckMemberChanged() + log.Error("[tso] getTS error after processing requests", + zap.String("dc-location", td.dc), + zap.String("stream-url", streamURL), + zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) + // Remove this stream from the `connectionCtxs` and cancel it due to error. The caller sets its `stream` to nil. + td.connectionCtxs.Delete(streamURL) + streamCancelFunc() + // Because ScheduleCheckMemberChanged is asynchronous, if the leader changes, we better call `updateMember` ASAP.
+ if errs.IsLeaderChange(err) { + if err := bo.Exec(ctx, svcDiscovery.CheckMemberChanged); err != nil { select { case <-ctx.Done(): - return + return false default: } - svcDiscovery.ScheduleCheckMemberChanged() - log.Error("[tso] getTS error after processing requests", - zap.String("dc-location", dc), - zap.String("stream-url", streamURL), - zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) - // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. - connectionCtxs.Delete(streamURL) - cancel() - stream = nil - // Because ScheduleCheckMemberChanged is asynchronous, if the leader changes, we better call `updateMember` ASAP. - if errs.IsLeaderChange(err) { - if err := bo.Exec(ctx, svcDiscovery.CheckMemberChanged); err != nil { - select { - case <-ctx.Done(): - return - default: - } - } - // Because the TSO Follower Proxy could be configured online, - // If we change it from on -> off, background updateConnectionCtxs - // will cancel the current stream, then the EOF error caused by cancel() - // should not trigger the updateConnectionCtxs here. - // So we should only call it when the leader changes. - provider.updateConnectionCtxs(ctx, dc, connectionCtxs) - } } + // Because the TSO Follower Proxy could be configured online, + // If we change it from on -> off, background updateConnectionCtxs + // will cancel the current stream, then the EOF error caused by cancel() + // should not trigger the updateConnectionCtxs here. + // So we should only call it when the leader changes. + td.provider.updateConnectionCtxs(ctx, td.dc, td.connectionCtxs) } + + return true } // updateConnectionCtxs updates the `connectionCtxs` for the specified DC location regularly. @@ -392,9 +459,14 @@ func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext return connectionCtx } +// processRequests sends the RPC request for the batch. 
It's guaranteed that after calling this function, requests +// in the batch must be eventually finished (done or canceled), either synchronously or asynchronously. +// `close(done)` will be called at the same time when finishing the requests. +// If this function returns a non-nil error, the requests will always be canceled synchronously. func (td *tsoDispatcher) processRequests( - stream *tsoStream, dcLocation string, tbc *tsoBatchController, + stream *tsoStream, dcLocation string, tbc *tsoBatchController, done chan struct{}, ) error { + // `done` must be guaranteed to be eventually called. var ( requests = tbc.getCollectedRequests() traceRegions = make([]*trace.Region, 0, len(requests)) @@ -422,28 +494,54 @@ func (td *tsoDispatcher) processRequests( keyspaceID = svcDiscovery.GetKeyspaceID() reqKeyspaceGroupID = svcDiscovery.GetKeyspaceGroupID() ) - respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests( + + cb := func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + // As golang doesn't allow double-closing a channel, here is implicitly a check that the callback + // is never called twice or called while it's also being cancelled elsewhere. + close(done) + + defer td.batchBufferPool.Put(tbc) + if err != nil { + td.cancelCollectedRequests(tbc, stream.streamID, err) + return + } + + curTSOInfo := &tsoInfo{ + tsoServer: stream.getServerURL(), + reqKeyspaceGroupID: reqKeyspaceGroupID, + respKeyspaceGroupID: result.respKeyspaceGroupID, + respReceivedAt: time.Now(), + physical: result.physical, + logical: result.logical, + } + // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. 
+ firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits) + td.compareAndSwapTS(curTSOInfo, firstLogical) + td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits, stream.streamID) + } + + err := stream.processRequests( clusterID, keyspaceID, reqKeyspaceGroupID, - dcLocation, count, tbc.batchStartTime) + dcLocation, count, tbc.extraBatchingStartTime, cb) if err != nil { - tbc.finishCollectedRequests(0, 0, 0, err) + close(done) + + td.cancelCollectedRequests(tbc, stream.streamID, err) return err } - curTSOInfo := &tsoInfo{ - tsoServer: stream.getServerURL(), - reqKeyspaceGroupID: reqKeyspaceGroupID, - respKeyspaceGroupID: respKeyspaceGroupID, - respReceivedAt: time.Now(), - physical: physical, - logical: logical, - } - // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. - firstLogical := tsoutil.AddLogical(logical, -count+1, suffixBits) - td.compareAndSwapTS(curTSOInfo, firstLogical) - tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) return nil } +func (td *tsoDispatcher) cancelCollectedRequests(tbc *tsoBatchController, streamID string, err error) { + td.tokenCh <- struct{}{} + tbc.finishCollectedRequests(0, 0, 0, streamID, err) +} + +func (td *tsoDispatcher) doneCollectedRequests(tbc *tsoBatchController, physical int64, firstLogical int64, suffixBits uint32, streamID string) { + td.tokenCh <- struct{}{} + tbc.finishCollectedRequests(physical, firstLogical, suffixBits, streamID, nil) +} + func (td *tsoDispatcher) compareAndSwapTS( curTSOInfo *tsoInfo, firstLogical int64, ) { diff --git a/client/tso_dispatcher_test.go b/client/tso_dispatcher_test.go new file mode 100644 index 00000000000..b8f0fcef208 --- /dev/null +++ b/client/tso_dispatcher_test.go @@ -0,0 +1,104 @@ +// Copyright 2024 TiKV Project Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/pingcap/log" + "go.uber.org/zap/zapcore" +) + +type mockTSOServiceProvider struct { + option *option +} + +func newMockTSOServiceProvider(option *option) *mockTSOServiceProvider { + return &mockTSOServiceProvider{ + option: option, + } +} + +func (m *mockTSOServiceProvider) getOption() *option { + return m.option +} + +func (*mockTSOServiceProvider) getServiceDiscovery() ServiceDiscovery { + return NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) +} + +func (*mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, _dc string, connectionCtxs *sync.Map) bool { + _, ok := connectionCtxs.Load(mockStreamURL) + if ok { + return true + } + ctx, cancel := context.WithCancel(ctx) + stream := newTSOStream(ctx, mockStreamURL, newMockTSOStreamImpl(ctx, true)) + connectionCtxs.LoadOrStore(mockStreamURL, &tsoConnectionContext{ctx, cancel, mockStreamURL, stream}) + return true +} + +func BenchmarkTSODispatcherHandleRequests(b *testing.B) { + log.SetLevel(zapcore.FatalLevel) + + ctx := context.Background() + + reqPool := &sync.Pool{ + New: func() any { + return &tsoRequest{ + done: make(chan error, 1), + physical: 0, + logical: 0, + dcLocation: globalDCLocation, + } + }, + } + getReq := func() *tsoRequest { + req := reqPool.Get().(*tsoRequest) + req.clientCtx = ctx + req.requestCtx = ctx + req.physical = 0 + 
req.logical = 0 + req.start = time.Now() + req.pool = reqPool + return req + } + + dispatcher := newTSODispatcher(ctx, globalDCLocation, defaultMaxTSOBatchSize, newMockTSOServiceProvider(newOption())) + var wg sync.WaitGroup + wg.Add(1) + + go dispatcher.handleDispatcher(&wg) + defer func() { + dispatcher.close() + wg.Wait() + }() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := getReq() + dispatcher.push(req) + _, _, err := req.Wait() + if err != nil { + panic(fmt.Sprintf("unexpected error from tsoReq: %+v", err)) + } + } + // Don't count the time cost in `defer` + b.StopTimer() +} diff --git a/client/tso_request.go b/client/tso_request.go index b912fa35497..fb2ae2bb92e 100644 --- a/client/tso_request.go +++ b/client/tso_request.go @@ -42,6 +42,9 @@ type tsoRequest struct { logical int64 dcLocation string + // The identifier of the RPC stream in which the request is processed. + streamID string + // Runtime fields. start time.Time pool *sync.Pool diff --git a/client/tso_stream.go b/client/tso_stream.go index 76b6ae3c51c..479beff2c6a 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -16,13 +16,19 @@ package pd import ( "context" + "fmt" "io" + "sync" + "sync/atomic" "time" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/kvproto/pkg/tsopb" + "github.com/pingcap/log" + "github.com/prometheus/client_golang/prometheus" "github.com/tikv/pd/client/errs" + "go.uber.org/zap" "google.golang.org/grpc" ) @@ -62,7 +68,7 @@ func (b *pdTSOStreamBuilder) build(ctx context.Context, cancel context.CancelFun stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return &tsoStream{stream: pdTSOStreamAdapter{stream}, serverURL: b.serverURL}, nil + return newTSOStream(ctx, b.serverURL, pdTSOStreamAdapter{stream}), nil } return nil, err } @@ -81,7 +87,7 @@ func (b *tsoTSOStreamBuilder) build( stream, err := b.client.Tso(ctx) done <- struct{}{} if err == nil { - return &tsoStream{stream: 
tsoTSOStreamAdapter{stream}, serverURL: b.serverURL}, nil + return newTSOStream(ctx, b.serverURL, tsoTSOStreamAdapter{stream}), nil } return nil, err } @@ -176,51 +182,255 @@ func (s tsoTSOStreamAdapter) Recv() (tsoRequestResult, error) { }, nil } +type onFinishedCallback func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) + +type batchedRequests struct { + startTime time.Time + count int64 + reqKeyspaceGroupID uint32 + callback onFinishedCallback +} + +// tsoStream represents an abstracted stream for requesting TSO. +// This type designed decoupled with users of this type, so tsoDispatcher won't be directly accessed here. +// Also in order to avoid potential memory allocations that might happen when passing closures as the callback, +// we instead use the `batchedRequestsNotifier` as the abstraction, and accepts generic type instead of dynamic interface +// type. type tsoStream struct { serverURL string // The internal gRPC stream. // - `pdpb.PD_TsoClient` for a leader/follower in the PD cluster. // - `tsopb.TSO_TsoClient` for a primary/secondary in the TSO cluster. stream grpcTSOStreamAdapter + // An identifier of the tsoStream object for metrics reporting and diagnosing. + streamID string + + pendingRequests chan batchedRequests + + cancel context.CancelFunc + wg sync.WaitGroup + + // For syncing between sender and receiver to guarantee all requests are finished when closing. 
+ state atomic.Int32 + stoppedWithErr atomic.Pointer[error] + + ongoingRequestCountGauge prometheus.Gauge + ongoingRequests atomic.Int32 +} + +const ( + streamStateIdle int32 = iota + streamStateSending + streamStateClosing +) + +var streamIDAlloc atomic.Int32 + +const invalidStreamID = "" + +func newTSOStream(ctx context.Context, serverURL string, stream grpcTSOStreamAdapter) *tsoStream { + streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1)) + // To make error handling in `tsoDispatcher` work, the internal `cancel` and the external `cancel` should be kept + // distinct. + ctx, cancel := context.WithCancel(ctx) + s := &tsoStream{ + serverURL: serverURL, + stream: stream, + streamID: streamID, + + pendingRequests: make(chan batchedRequests, 64), + + cancel: cancel, + + ongoingRequestCountGauge: ongoingRequestCountGauge.WithLabelValues(streamID), + } + s.wg.Add(1) + go s.recvLoop(ctx) + return s } func (s *tsoStream) getServerURL() string { return s.serverURL } +// processRequests starts an RPC to get a batch of timestamps without waiting for the result. When the result is ready, +// it will be passed to `callback`. +// +// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines. +// +// It's guaranteed that the `callback` will be called, unless the request fails to be scheduled (i.e. this function +// returns a non-nil error), in which case the callback will be ignored.
func (s *tsoStream) processRequests( - clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time, -) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) { + clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time, callback onFinishedCallback, +) error { start := time.Now() - if err = s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { - if err == io.EOF { - err = errs.ErrClientTSOStreamClosed - } else { - err = errors.WithStack(err) + + // Check if the stream is closing or closed, in which case no more requests should be put in. + // Note that the prevState should be restored very soon, as the receiver may check + prevState := s.state.Swap(streamStateSending) + switch prevState { + case streamStateIdle: + // Expected case + case streamStateClosing: + s.state.Store(prevState) + err := s.GetRecvError() + log.Info("sending to closed tsoStream", zap.String("stream", s.streamID), zap.Error(err)) + if err == nil { + err = errors.WithStack(errs.ErrClientTSOStreamClosed) } - return + return err + case streamStateSending: + log.Fatal("unexpected concurrent sending on tsoStream", zap.String("stream", s.streamID)) + default: + log.Fatal("unknown tsoStream state", zap.String("stream", s.streamID), zap.Int32("state", prevState)) + } + + select { + case s.pendingRequests <- batchedRequests{ + startTime: start, + count: count, + reqKeyspaceGroupID: keyspaceGroupID, + callback: callback, + }: + default: + s.state.Store(prevState) + return errors.New("unexpected channel full") + } + s.state.Store(prevState) + + if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil { + // As the request is already put into `pendingRequests`, the request should finally be canceled by the recvLoop. 
+ // So skip returning error here to avoid + // if err == io.EOF { + // return errors.WithStack(errs.ErrClientTSOStreamClosed) + // } + // return errors.WithStack(err) + log.Warn("failed to send RPC request through tsoStream", zap.String("stream", s.streamID), zap.Error(err)) + return nil } tsoBatchSendLatency.Observe(time.Since(batchStartTime).Seconds()) - res, err := s.stream.Recv() - duration := time.Since(start).Seconds() - if err != nil { - requestFailedDurationTSO.Observe(duration) - if err == io.EOF { - err = errs.ErrClientTSOStreamClosed - } else { - err = errors.WithStack(err) + s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(1))) + return nil +} + +func (s *tsoStream) recvLoop(ctx context.Context) { + var finishWithErr error + var currentReq batchedRequests + var hasReq bool + + defer func() { + if r := recover(); r != nil { + log.Fatal("tsoStream.recvLoop internal panic", zap.Stack("stacktrace"), zap.Any("panicMessage", r)) } - return + + if finishWithErr == nil { + // The loop must exit with a non-nil error (including io.EOF and context.Canceled). This should be + // unreachable code. + log.Fatal("tsoStream.recvLoop exited without error info", zap.String("stream", s.streamID)) + } + + if hasReq { + // There's an unfinished request, cancel it, otherwise it will be blocked forever. + currentReq.callback(tsoRequestResult{}, currentReq.reqKeyspaceGroupID, finishWithErr) + } + + s.stoppedWithErr.Store(&finishWithErr) + s.cancel() + for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) { + switch state := s.state.Load(); state { + case streamStateIdle, streamStateSending: + // streamStateSending should switch to streamStateIdle very quickly. Spin until successfully setting to + // streamStateClosing. 
+ continue + case streamStateClosing: + log.Warn("unexpected double closing of tsoStream", zap.String("stream", s.streamID)) + default: + log.Fatal("unknown tsoStream state", zap.String("stream", s.streamID), zap.Int32("state", state)) + } + } + + log.Info("tsoStream.recvLoop ended", zap.String("stream", s.streamID), zap.Error(finishWithErr)) + + close(s.pendingRequests) + + // Cancel remaining pending requests. + for req := range s.pendingRequests { + req.callback(tsoRequestResult{}, req.reqKeyspaceGroupID, errors.WithStack(finishWithErr)) + } + + s.wg.Done() + s.ongoingRequests.Store(0) + s.ongoingRequestCountGauge.Set(0) + }() + +recvLoop: + for { + select { + case <-ctx.Done(): + finishWithErr = context.Canceled + break recvLoop + default: + } + + res, err := s.stream.Recv() + + // Try to load the corresponding `batchedRequests`. If `Recv` is successful, there must be a request pending + // in the queue. + select { + case currentReq = <-s.pendingRequests: + hasReq = true + default: + hasReq = false + } + + durationSeconds := time.Since(currentReq.startTime).Seconds() + + if err != nil { + // If a request is pending and error occurs, observe the duration it has cost. + // Note that it's also possible that the stream is broken due to network without being requested. In this + // case, `Recv` may return an error while no request is pending. 
+ if hasReq { + requestFailedDurationTSO.Observe(durationSeconds) + } + if err == io.EOF { + finishWithErr = errors.WithStack(errs.ErrClientTSOStreamClosed) + } else { + finishWithErr = errors.WithStack(err) + } + break recvLoop + } else if !hasReq { + finishWithErr = errors.New("tsoStream timing order broken") + break recvLoop + } + + latencySeconds := durationSeconds + requestDurationTSO.Observe(latencySeconds) + tsoBatchSize.Observe(float64(res.count)) + + if res.count != uint32(currentReq.count) { + finishWithErr = errors.WithStack(errTSOLength) + break recvLoop + } + + currentReq.callback(res, currentReq.reqKeyspaceGroupID, nil) + // After finishing the requests, unset these variables which will be checked in the defer block. + currentReq = batchedRequests{} + hasReq = false + + s.ongoingRequestCountGauge.Set(float64(s.ongoingRequests.Add(-1))) } - requestDurationTSO.Observe(duration) - tsoBatchSize.Observe(float64(count)) +} - if res.count != uint32(count) { - err = errors.WithStack(errTSOLength) - return +// GetRecvError returns the error (if any) that has been encountered when receiving response asynchronously. +func (s *tsoStream) GetRecvError() error { + perr := s.stoppedWithErr.Load() + if perr == nil { + return nil } + return *perr +} - respKeyspaceGroupID = res.respKeyspaceGroupID - physical, logical, suffixBits = res.physical, res.logical, res.suffixBits - return +// WaitForClosed blocks until the stream is closed and the inner loop exits. +func (s *tsoStream) WaitForClosed() { + s.wg.Wait() } diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go new file mode 100644 index 00000000000..b09c54baf3a --- /dev/null +++ b/client/tso_stream_test.go @@ -0,0 +1,488 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "context" + "io" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/log" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/tikv/pd/client/errs" + "go.uber.org/zap/zapcore" +) + +const mockStreamURL = "mock:///" + +type requestMsg struct { + clusterID uint64 + keyspaceGroupID uint32 + count int64 +} + +type resultMsg struct { + r tsoRequestResult + err error + breakStream bool +} + +type mockTSOStreamImpl struct { + ctx context.Context + requestCh chan requestMsg + resultCh chan resultMsg + keyspaceID uint32 + errorState error + + autoGenerateResult bool + // Current progress of generating TSO results + resGenPhysical, resGenLogical int64 +} + +func newMockTSOStreamImpl(ctx context.Context, autoGenerateResult bool) *mockTSOStreamImpl { + return &mockTSOStreamImpl{ + ctx: ctx, + requestCh: make(chan requestMsg, 64), + resultCh: make(chan resultMsg, 64), + keyspaceID: 0, + + autoGenerateResult: autoGenerateResult, + resGenPhysical: 10000, + resGenLogical: 0, + } +} + +func (s *mockTSOStreamImpl) Send(clusterID uint64, _keyspaceID, keyspaceGroupID uint32, _dcLocation string, count int64) error { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + default: + } + s.requestCh <- requestMsg{ + clusterID: clusterID, + keyspaceGroupID: keyspaceGroupID, + count: count, + } + return nil +} + +func (s *mockTSOStreamImpl) Recv() (tsoRequestResult, error) { + // This stream have ever receive an error, it returns the error forever. 
+ if s.errorState != nil { + return tsoRequestResult{}, s.errorState + } + + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + default: + } + + var ( + res resultMsg + hasRes bool + req requestMsg + hasReq bool + ) + + // Try to match a pair of request and result from each channel and allowing breaking the stream at any time. + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + case req = <-s.requestCh: + hasReq = true + select { + case res = <-s.resultCh: + hasRes = true + default: + } + case res = <-s.resultCh: + hasRes = true + select { + case req = <-s.requestCh: + hasReq = true + default: + } + } + // Either req or res should be ready at this time. + + if hasRes { + if res.breakStream { + if res.err == nil { + panic("breaking mockTSOStreamImpl without error") + } + s.errorState = res.err + return tsoRequestResult{}, s.errorState + } else if s.autoGenerateResult { + // Do not allow manually assigning result. + panic("trying manually specifying result for mockTSOStreamImpl when it's auto-generating mode") + } + } else if s.autoGenerateResult { + res = s.autoGenResult(req.count) + hasRes = true + } + + if !hasReq { + // If req is not ready, the res must be ready. So it's certain that it don't need to be canceled by breakStream. + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + case req = <-s.requestCh: + // Skip the assignment to make linter happy. + // hasReq = true + } + } else if !hasRes { + select { + case <-s.ctx.Done(): + s.errorState = s.ctx.Err() + return tsoRequestResult{}, s.errorState + case res = <-s.resultCh: + // Skip the assignment to make linter happy. + // hasRes = true + } + } + + // Both res and req should be ready here. 
+ if res.err != nil { + s.errorState = res.err + } + return res.r, res.err +} + +func (s *mockTSOStreamImpl) autoGenResult(count int64) resultMsg { + physical := s.resGenPhysical + logical := s.resGenLogical + count + if logical >= (1 << 18) { + physical += logical >> 18 + logical &= (1 << 18) - 1 + } + + s.resGenPhysical = physical + s.resGenLogical = logical + + return resultMsg{ + r: tsoRequestResult{ + physical: s.resGenPhysical, + logical: s.resGenLogical, + count: uint32(count), + suffixBits: 0, + respKeyspaceGroupID: 0, + }, + } +} + +func (s *mockTSOStreamImpl) returnResult(physical int64, logical int64, count uint32) { + s.resultCh <- resultMsg{ + r: tsoRequestResult{ + physical: physical, + logical: logical, + count: count, + suffixBits: 0, + respKeyspaceGroupID: s.keyspaceID, + }, + } +} + +func (s *mockTSOStreamImpl) returnError(err error) { + s.resultCh <- resultMsg{ + err: err, + } +} + +func (s *mockTSOStreamImpl) breakStream(err error) { + s.resultCh <- resultMsg{ + err: err, + breakStream: true, + } +} + +func (s *mockTSOStreamImpl) stop() { + s.breakStream(io.EOF) +} + +type callbackInvocation struct { + result tsoRequestResult + err error +} + +type testTSOStreamSuite struct { + suite.Suite + re *require.Assertions + + inner *mockTSOStreamImpl + stream *tsoStream +} + +func (s *testTSOStreamSuite) SetupTest() { + s.re = require.New(s.T()) + s.inner = newMockTSOStreamImpl(context.Background(), false) + s.stream = newTSOStream(context.Background(), mockStreamURL, s.inner) +} + +func (s *testTSOStreamSuite) TearDownTest() { + s.inner.stop() + s.stream.WaitForClosed() + s.inner = nil + s.stream = nil +} + +func TestTSOStreamTestSuite(t *testing.T) { + suite.Run(t, new(testTSOStreamSuite)) +} + +func (s *testTSOStreamSuite) noResult(ch <-chan callbackInvocation) { + select { + case res := <-ch: + s.re.FailNowf("result received unexpectedly", "received result: %+v", res) + case <-time.After(time.Millisecond * 20): + } +} + +func (s *testTSOStreamSuite) 
getResult(ch <-chan callbackInvocation) callbackInvocation { + select { + case res := <-ch: + return res + case <-time.After(time.Second * 10000): + s.re.FailNow("result not ready in time") + panic("result not ready in time") + } +} + +func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) (<-chan callbackInvocation, error) { + ch := make(chan callbackInvocation, 1) + err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) { + if err == nil { + s.re.Equal(uint32(3), reqKeyspaceGroupID) + s.re.Equal(uint32(0), result.suffixBits) + } + ch <- callbackInvocation{ + result: result, + err: err, + } + }) + if err != nil { + return nil, err + } + return ch, nil +} + +func (s *testTSOStreamSuite) mustProcessRequestWithResultCh(count int64) <-chan callbackInvocation { + ch, err := s.processRequestWithResultCh(count) + s.re.NoError(err) + return ch +} + +func (s *testTSOStreamSuite) TestTSOStreamBasic() { + ch := s.mustProcessRequestWithResultCh(1) + s.noResult(ch) + s.inner.returnResult(10, 1, 1) + res := s.getResult(ch) + + s.re.NoError(res.err) + s.re.Equal(int64(10), res.result.physical) + s.re.Equal(int64(1), res.result.logical) + s.re.Equal(uint32(1), res.result.count) + + ch = s.mustProcessRequestWithResultCh(2) + s.noResult(ch) + s.inner.returnResult(20, 3, 2) + res = s.getResult(ch) + + s.re.NoError(res.err) + s.re.Equal(int64(20), res.result.physical) + s.re.Equal(int64(3), res.result.logical) + s.re.Equal(uint32(2), res.result.count) + + ch = s.mustProcessRequestWithResultCh(3) + s.noResult(ch) + s.inner.returnError(errors.New("mock rpc error")) + res = s.getResult(ch) + s.re.Error(res.err) + s.re.Equal("mock rpc error", res.err.Error()) + + // After an error from the (simulated) RPC stream, the tsoStream should be in a broken status and can't accept + // new request anymore. 
+ err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(_result tsoRequestResult, _reqKeyspaceGroupID uint32, _err error) { + panic("unreachable") + }) + s.re.Error(err) +} + +func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests int) { + var resultCh []<-chan callbackInvocation + + for i := 0; i < pendingRequests; i++ { + ch := s.mustProcessRequestWithResultCh(1) + resultCh = append(resultCh, ch) + s.noResult(ch) + } + + s.inner.breakStream(err) + closedCh := make(chan struct{}) + go func() { + s.stream.WaitForClosed() + closedCh <- struct{}{} + }() + select { + case <-closedCh: + case <-time.After(time.Second): + s.re.FailNow("stream receiver loop didn't exit") + } + + for _, ch := range resultCh { + res := s.getResult(ch) + s.re.Error(res.err) + if err == io.EOF { + s.re.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + } else { + s.re.ErrorIs(res.err, err) + } + } +} + +func (s *testTSOStreamSuite) TestTSOStreamBrokenWithEOFNoPendingReq() { + s.testTSOStreamBrokenImpl(io.EOF, 0) +} + +func (s *testTSOStreamSuite) TestTSOStreamCanceledNoPendingReq() { + s.testTSOStreamBrokenImpl(context.Canceled, 0) +} + +func (s *testTSOStreamSuite) TestTSOStreamBrokenWithEOFWithPendingReq() { + s.testTSOStreamBrokenImpl(io.EOF, 5) +} + +func (s *testTSOStreamSuite) TestTSOStreamCanceledWithPendingReq() { + s.testTSOStreamBrokenImpl(context.Canceled, 5) +} + +func (s *testTSOStreamSuite) TestTSOStreamFIFO() { + var resultChs []<-chan callbackInvocation + const count = 5 + for i := 0; i < count; i++ { + ch := s.mustProcessRequestWithResultCh(int64(i + 1)) + resultChs = append(resultChs, ch) + } + + for _, ch := range resultChs { + s.noResult(ch) + } + + for i := 0; i < count; i++ { + s.inner.returnResult(int64((i+1)*10), int64(i), uint32(i+1)) + } + + for i, ch := range resultChs { + res := s.getResult(ch) + s.re.NoError(res.err) + s.re.Equal(int64((i+1)*10), res.result.physical) + s.re.Equal(int64(i), res.result.logical) + 
s.re.Equal(uint32(i+1), res.result.count) + } +} + +func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() { + resultChCh := make(chan (<-chan callbackInvocation), 10000) + const totalCount = 10000 + + // Continuously start requests + go func() { + for i := 1; i <= totalCount; i++ { + // Retry loop + for { + ch, err := s.processRequestWithResultCh(int64(i)) + if err != nil { + // If the capacity of the request queue is exhausted, it returns this error. As a test, we simply + // spin and retry it until it has enough space, as a coverage of the almost-full case. But note that + // this should not happen in production, in which case the caller of tsoStream should have its own + // limit of concurrent RPC requests. + s.Contains(err.Error(), "unexpected channel full") + continue + } + + resultChCh <- ch + break + } + } + }() + + // Continuously send results + go func() { + for i := int64(1); i <= totalCount; i++ { + s.inner.returnResult(i*10, i%(1<<18), uint32(i)) + } + s.inner.breakStream(io.EOF) + }() + + // Check results + for i := int64(1); i <= totalCount; i++ { + ch := <-resultChCh + res := s.getResult(ch) + s.re.NoError(res.err) + s.re.Equal(i*10, res.result.physical) + s.re.Equal(i%(1<<18), res.result.logical) + s.re.Equal(uint32(i), res.result.count) + } + + // After handling all these requests, the stream is ended by an EOF error. The next request won't succeed. + // So, either the `processRequests` function returns an error or the callback is called with an error. 
+ ch, err := s.processRequestWithResultCh(1) + if err != nil { + s.re.ErrorIs(err, errs.ErrClientTSOStreamClosed) + } else { + res := s.getResult(ch) + s.re.Error(res.err) + s.re.ErrorIs(res.err, errs.ErrClientTSOStreamClosed) + } +} + +func BenchmarkTSOStreamSendRecv(b *testing.B) { + log.SetLevel(zapcore.FatalLevel) + + streamInner := newMockTSOStreamImpl(context.Background(), true) + stream := newTSOStream(context.Background(), mockStreamURL, streamInner) + defer func() { + streamInner.stop() + stream.WaitForClosed() + }() + + now := time.Now() + resCh := make(chan tsoRequestResult, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := stream.processRequests(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, _ uint32, err error) { + if err != nil { + panic(err) + } + select { + case resCh <- result: + default: + panic("channel not cleared in the last iteration") + } + }) + if err != nil { + panic(err) + } + <-resCh + } + b.StopTimer() +} diff --git a/tests/integrations/mcs/tso/keyspace_group_manager_test.go b/tests/integrations/mcs/tso/keyspace_group_manager_test.go index 09d8011c6c8..60ec4843130 100644 --- a/tests/integrations/mcs/tso/keyspace_group_manager_test.go +++ b/tests/integrations/mcs/tso/keyspace_group_manager_test.go @@ -16,6 +16,8 @@ package tso import ( "context" + "errors" + "fmt" "math/rand" "net/http" "strings" @@ -473,10 +475,11 @@ func (suite *tsoKeyspaceGroupManagerTestSuite) dispatchClient( strings.Contains(errMsg, clierrs.NotLeaderErr) || strings.Contains(errMsg, clierrs.NotServedErr) || strings.Contains(errMsg, "ErrKeyspaceNotAssigned") || - strings.Contains(errMsg, "ErrKeyspaceGroupIsMerging") { + strings.Contains(errMsg, "ErrKeyspaceGroupIsMerging") || + errors.Is(err, clierrs.ErrClientTSOStreamClosed) { continue } - re.FailNow(errMsg) + re.FailNow(fmt.Sprintf("%+v", err)) } if physical == lastPhysical { re.Greater(logical, lastLogical)