From 7b4409c1194665bc1eabc739f1ca97e12c92599c Mon Sep 17 00:00:00 2001 From: you06 Date: Thu, 9 Jan 2025 00:19:10 +0900 Subject: [PATCH] Try to validate read ts for all RPC requests (#1513) (#1546) Signed-off-by: MyonKeminta Signed-off-by: you06 Co-authored-by: MyonKeminta <9948422+MyonKeminta@users.noreply.github.com> --- internal/locate/region_cache_test.go | 5 +- internal/locate/region_request.go | 54 ++++++- internal/locate/region_request3_test.go | 8 +- internal/locate/region_request_state_test.go | 3 +- internal/locate/region_request_test.go | 15 +- oracle/oracle.go | 38 ++++- oracle/oracles/export_test.go | 2 +- oracle/oracles/local.go | 15 +- oracle/oracles/mock.go | 19 ++- oracle/oracles/pd.go | 54 ++++--- oracle/oracles/pd_test.go | 160 +++++++++++++++---- rawkv/rawkv.go | 9 +- tikv/kv.go | 2 +- tikv/region.go | 5 +- tikv/split_region.go | 2 +- txnkv/transaction/commit.go | 2 +- txnkv/transaction/pessimistic.go | 2 +- txnkv/transaction/prewrite.go | 2 +- txnkv/txnsnapshot/client_helper.go | 5 +- txnkv/txnsnapshot/scan.go | 2 +- 20 files changed, 314 insertions(+), 90 deletions(-) diff --git a/internal/locate/region_cache_test.go b/internal/locate/region_cache_test.go index 0a6b55da0..2410b7a56 100644 --- a/internal/locate/region_cache_test.go +++ b/internal/locate/region_cache_test.go @@ -54,6 +54,7 @@ import ( "github.com/tikv/client-go/v2/internal/mockstore/mocktikv" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" ) @@ -1004,7 +1005,7 @@ func (s *testRegionCacheSuite) TestRegionEpochOnTiFlash() { s.Equal(ctxTiFlash.Peer.Id, s.peer1) ctxTiFlash.Peer.Role = metapb.PeerRole_Learner r := ctxTiFlash.Meta - reqSend := NewRegionRequestSender(s.cache, nil) + reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{}) regionErr := &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{CurrentRegions: []*metapb.Region{r}}} reqSend.onRegionError(s.bo, ctxTiFlash, nil, regionErr) @@ -1640,7 +1641,7 @@ func (s *testRegionCacheSuite) TestShouldNotRetryFlashback() { ctx, err := s.cache.GetTiKVRPCContext(retry.NewBackofferWithVars(context.Background(), 100, nil), loc.Region, kv.ReplicaReadLeader, 0) s.NotNil(ctx) s.NoError(err) - reqSend := NewRegionRequestSender(s.cache, nil) + reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{}) shouldRetry, err := reqSend.onRegionError(s.bo, ctx, nil, &errorpb.Error{FlashbackInProgress: &errorpb.FlashbackInProgress{}}) s.Error(err) s.False(shouldRetry) diff --git a/internal/locate/region_request.go b/internal/locate/region_request.go index b7c5d8fbd..44ad2e09b 100644 --- a/internal/locate/region_request.go +++ b/internal/locate/region_request.go @@ -46,6 +46,7 @@ import ( "sync/atomic" "time" + "github.com/tikv/client-go/v2/oracle" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -105,6 +106,7 @@ type RegionRequestSender struct { regionCache *RegionCache apiVersion kvrpcpb.APIVersion client client.Client + readTSValidator oracle.ReadTSValidator storeAddr string rpcError error replicaSelector *replicaSelector @@ -193,11 +195,12 @@ func RecordRegionRequestRuntimeStats(stats map[tikvrpc.CmdType]*RPCRuntimeStats, } // NewRegionRequestSender creates a new sender. -func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender { +func NewRegionRequestSender(regionCache *RegionCache, client client.Client, readTSValidator oracle.ReadTSValidator) *RegionRequestSender { return &RegionRequestSender{ - regionCache: regionCache, - apiVersion: regionCache.codec.GetAPIVersion(), - client: client, + regionCache: regionCache, + apiVersion: regionCache.codec.GetAPIVersion(), + client: client, + readTSValidator: readTSValidator, } } @@ -1261,6 +1264,11 @@ func (s *RegionRequestSender) SendReqCtx( } } + if err = s.validateReadTS(bo.GetCtx(), req); err != nil { + logutil.Logger(bo.GetCtx()).Error("validate read ts failed for request", zap.Stringer("reqType", req.Type), zap.Stringer("req", req.Req.(fmt.Stringer)), zap.Stringer("context", &req.Context), zap.Stack("stack"), zap.Error(err)) + return nil, nil, 0, err + } + // If the MaxExecutionDurationMs is not set yet, we set it to be the RPC timeout duration // so TiKV can give up the requests whose response TiDB cannot receive due to timeout. if req.Context.MaxExecutionDurationMs == 0 { @@ -2179,6 +2187,44 @@ func (s *RegionRequestSender) onRegionError( return false, nil } +func (s *RegionRequestSender) validateReadTS(ctx context.Context, req *tikvrpc.Request) error { + if req.StoreTp == tikvrpc.TiDB { + // Skip the checking if the store type is TiDB. + return nil + } + + var readTS uint64 + switch req.Type { + case tikvrpc.CmdGet, tikvrpc.CmdScan, tikvrpc.CmdBatchGet, tikvrpc.CmdCop, tikvrpc.CmdCopStream, tikvrpc.CmdBatchCop, tikvrpc.CmdScanLock: + readTS = req.GetStartTS() + + // TODO: Check transactional write requests that has implicit read. + // case tikvrpc.CmdPessimisticLock: + // readTS = req.PessimisticLock().GetForUpdateTs() + // case tikvrpc.CmdPrewrite: + // inner := req.Prewrite() + // readTS = inner.GetForUpdateTs() + // if readTS == 0 { + // readTS = inner.GetStartVersion() + // } + // case tikvrpc.CmdCheckTxnStatus: + // inner := req.CheckTxnStatus() + // // TiKV uses the greater one of these three fields to update the max_ts. + // readTS = inner.GetLockTs() + // if inner.GetCurrentTs() != math.MaxUint64 && inner.GetCurrentTs() > readTS { + // readTS = inner.GetCurrentTs() + // } + // if inner.GetCallerStartTs() != math.MaxUint64 && inner.GetCallerStartTs() > readTS { + // readTS = inner.GetCallerStartTs() + // } + // case tikvrpc.CmdCheckSecondaryLocks, tikvrpc.CmdCleanup, tikvrpc.CmdBatchRollback: + // readTS = req.GetStartTS() + default: + return nil + } + return s.readTSValidator.ValidateReadTS(ctx, readTS, req.StaleRead, &oracle.Option{TxnScope: req.TxnScope}) +} + type staleReadMetricsCollector struct { } diff --git a/internal/locate/region_request3_test.go b/internal/locate/region_request3_test.go index cd9a8e3cb..f9ed22f8f 100644 --- a/internal/locate/region_request3_test.go +++ b/internal/locate/region_request3_test.go @@ -43,6 +43,7 @@ import ( "time" "unsafe" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" @@ -82,7 +83,9 @@ func (s *testRegionRequestToThreeStoresSuite) SetupTest() { s.cache = NewRegionCache(pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) - s.regionRequestSender = NewRegionRequestSender(s.cache, client) + s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{}) + + s.NoError(failpoint.Enable("tikvclient/doNotRecoverStoreHealthCheckPanic", "return")) } func (s *testRegionRequestToThreeStoresSuite) TearDownTest() { @@ -147,7 +150,8 @@ func (s *testRegionRequestToThreeStoresSuite) loadAndGetLeaderStore() (*Store, s } func (s *testRegionRequestToThreeStoresSuite) TestForwarding() { - s.regionRequestSender.regionCache.enableForwarding = true + sender := NewRegionRequestSender(s.cache, s.regionRequestSender.client, oracle.NoopReadTSValidator{}) + sender.regionCache.enableForwarding = true // First get the leader's addr from region cache leaderStore, leaderAddr := s.loadAndGetLeaderStore() diff --git a/internal/locate/region_request_state_test.go b/internal/locate/region_request_state_test.go index d1b5c6201..f2c940881 100644 --- a/internal/locate/region_request_state_test.go +++ b/internal/locate/region_request_state_test.go @@ -34,6 +34,7 @@ import ( "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" ) @@ -76,7 +77,7 @@ func (s *testRegionCacheStaleReadSuite) SetupTest() { s.cache = NewRegionCache(pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) - s.regionRequestSender = NewRegionRequestSender(s.cache, client) + s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{}) s.setClient() s.injection = testRegionCacheFSMSuiteInjection{ unavailableStoreIDs: make(map[uint64]struct{}), diff --git a/internal/locate/region_request_test.go b/internal/locate/region_request_test.go index cabec4741..30c035bd1 100644 --- a/internal/locate/region_request_test.go +++ b/internal/locate/region_request_test.go @@ -60,7 +60,9 @@ import ( "github.com/tikv/client-go/v2/internal/client/mock_server" "github.com/tikv/client-go/v2/internal/mockstore/mocktikv" "github.com/tikv/client-go/v2/internal/retry" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" + pd "github.com/tikv/pd/client" pderr "github.com/tikv/pd/client/errs" "google.golang.org/grpc" ) @@ -75,6 +77,7 @@ type testRegionRequestToSingleStoreSuite struct { store uint64 peer uint64 region uint64 + pdCli pd.Client cache *RegionCache bo *retry.Backoffer regionRequestSender *RegionRequestSender @@ -85,11 +88,11 @@ func (s *testRegionRequestToSingleStoreSuite) SetupTest() { s.mvccStore = mocktikv.MustNewMVCCStore() s.cluster = mocktikv.NewCluster(s.mvccStore) s.store, s.peer, s.region = mocktikv.BootstrapWithSingleStore(s.cluster) - pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)} - s.cache = NewRegionCache(pdCli) + s.pdCli = &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)} + s.cache = NewRegionCache(s.pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) - s.regionRequestSender = NewRegionRequestSender(s.cache, client) + s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{}) } func (s *testRegionRequestToSingleStoreSuite) TearDownTest() { @@ -567,7 +570,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa }() cli := client.NewRPCClient() - sender := NewRegionRequestSender(s.cache, cli) + sender := NewRegionRequestSender(s.cache, cli, oracle.NoopReadTSValidator{}) req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ Key: []byte("key"), Value: []byte("value"), @@ -586,7 +589,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa Client: client.NewRPCClient(), redirectAddr: addr, } - sender = NewRegionRequestSender(s.cache, client1) + sender = NewRegionRequestSender(s.cache, client1, oracle.NoopReadTSValidator{}) sender.SendReq(s.bo, req, region.Region, 3*time.Second) // cleanup @@ -772,7 +775,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestBatchClientSendLoopPanic() { cancel() }() req := tikvrpc.NewRequest(tikvrpc.CmdCop, &coprocessor.Request{Data: []byte("a"), StartTs: 1}) - regionRequestSender := NewRegionRequestSender(s.cache, fnClient) + regionRequestSender := NewRegionRequestSender(s.cache, fnClient, oracle.NoopReadTSValidator{}) regionRequestSender.regionCache.testingKnobs.mockRequestLiveness.Store((*livenessFunc)(&tf)) regionRequestSender.SendReq(bo, req, region.Region, client.ReadTimeoutShort) } diff --git a/oracle/oracle.go b/oracle/oracle.go index 7ace335ec..88de9d3ae 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -36,6 +36,7 @@ package oracle import ( "context" + "fmt" "time" ) @@ -64,12 +65,17 @@ type Oracle interface { GetExternalTimestamp(ctx context.Context) (uint64, error) SetExternalTimestamp(ctx context.Context, ts uint64) error - // ValidateSnapshotReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts - // that has been allocated by the oracle, so that it's safe to use this ts to perform snapshot read, stale read, - // etc. + ReadTSValidator +} + +// ReadTSValidator is the interface for providing the ability for verifying whether a timestamp is safe to be used +// for readings, as part of the `Oracle` interface. +type ReadTSValidator interface { + // ValidateReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts + // that has been allocated by the oracle, so that it's safe to use this ts to perform read operations. // Note that this method only checks the ts from the oracle's perspective. It doesn't check whether the snapshot // has been GCed. - ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *Option) error + ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error } // Future is a future which promises to return a timestamp. @@ -121,3 +127,27 @@ func GoTimeToTS(t time.Time) uint64 { func GoTimeToLowerLimitStartTS(now time.Time, maxTxnTimeUse int64) uint64 { return GoTimeToTS(now.Add(-time.Duration(maxTxnTimeUse) * time.Millisecond)) } + +// NoopReadTSValidator is a dummy implementation of ReadTSValidator that always let the validation pass. +// Only use this when using RPCs that are not related to ts (e.g. rawkv), or in tests where `Oracle` is not available +// and the validation is not necessary. +type NoopReadTSValidator struct{} + +func (NoopReadTSValidator) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error { + return nil +} + +type ErrFutureTSRead struct { + ReadTS uint64 + CurrentTS uint64 +} + +func (e ErrFutureTSRead) Error() string { + return fmt.Sprintf("cannot set read timestamp to a future time, readTS: %d, currentTS: %d", e.ReadTS, e.CurrentTS) +} + +type ErrLatestStaleRead struct{} + +func (ErrLatestStaleRead) Error() string { + return "cannot set read ts to max uint64 for stale read" +} diff --git a/oracle/oracles/export_test.go b/oracle/oracles/export_test.go index 08df25783..78e7c0a8b 100644 --- a/oracle/oracles/export_test.go +++ b/oracle/oracles/export_test.go @@ -65,6 +65,6 @@ func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) { case *pdOracle: lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, &atomic.Pointer[lastTSO]{}) lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO]) - lastTSPointer.Store(&lastTSO{tso: ts, arrival: ts}) + lastTSPointer.Store(&lastTSO{tso: ts, arrival: oracle.GetTimeFromTS(ts)}) } } diff --git a/oracle/oracles/local.go b/oracle/oracles/local.go index 1e6b747c9..e916286ac 100644 --- a/oracle/oracles/local.go +++ b/oracle/oracles/local.go @@ -36,6 +36,7 @@ package oracles import ( "context" + "math" "sync" "time" @@ -136,13 +137,23 @@ func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error) return l.getExternalTimestamp(ctx) } -func (l *localOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { +func (l *localOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error { + if readTS == math.MaxUint64 { + if isStaleRead { + return oracle.ErrLatestStaleRead{} + } + return nil + } + currentTS, err := l.GetTimestamp(ctx, opt) if err != nil { return errors.Errorf("fail to validate read timestamp: %v", err) } if currentTS < readTS { - return errors.Errorf("cannot set read timestamp to a future time") + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } } return nil } diff --git a/oracle/oracles/mock.go b/oracle/oracles/mock.go index 183b4c2d6..da8874d5c 100644 --- a/oracle/oracles/mock.go +++ b/oracle/oracles/mock.go @@ -36,6 +36,7 @@ package oracles import ( "context" + "math" "sync" "time" @@ -122,13 +123,27 @@ func (o *MockOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *or return o.GetTimestampAsync(ctx, opt) } -func (o *MockOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { +func (o *MockOracle) SetLowResolutionTimestampUpdateInterval(time.Duration) error { + return nil +} + +func (o *MockOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error { + if readTS == math.MaxUint64 { + if isStaleRead { + return oracle.ErrLatestStaleRead{} + } + return nil + } + currentTS, err := o.GetTimestamp(ctx, opt) if err != nil { return errors.Errorf("fail to validate read timestamp: %v", err) } if currentTS < readTS { - return errors.Errorf("cannot set read timestamp to a future time") + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } } return nil } diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index 6e7fb9b6f..805d1b5c7 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -37,6 +37,7 @@ package oracles import ( "context" "fmt" + "math" "strings" "sync" "sync/atomic" @@ -149,7 +150,7 @@ type pdOracle struct { // When the low resolution ts is not new enough and there are many concurrent stane read / snapshot read // operations that needs to validate the read ts, we can use this to avoid too many concurrent GetTS calls by - // reusing a result for different `ValidateSnapshotReadTS` calls. This can be done because that + // reusing a result for different `ValidateReadTS` calls. This can be done because that // we don't require the ts for validation to be strictly the latest one. // Note that the result can't be reused for different txnScopes. The txnScope is used as the key. tsForValidation singleflight.Group @@ -158,7 +159,7 @@ type pdOracle struct { // lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched. type lastTSO struct { tso uint64 - arrival uint64 + arrival time.Time } type PDOracleOptions struct { @@ -272,17 +273,13 @@ func (o *pdOracle) getTimestamp(ctx context.Context, txnScope string) (uint64, e return oracle.ComposeTS(physical, logical), nil } -func (o *pdOracle) getArrivalTimestamp() uint64 { - return oracle.GoTimeToTS(time.Now()) -} - func (o *pdOracle) setLastTS(ts uint64, txnScope string) { if txnScope == "" { txnScope = oracle.GlobalTxnScope } current := &lastTSO{ tso: ts, - arrival: o.getArrivalTimestamp(), + arrival: time.Now(), } lastTSInterface, ok := o.lastTSMap.Load(txnScope) if !ok { @@ -294,9 +291,12 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) { lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO]) for { last := lastTSPointer.Load() - if current.tso <= last.tso || current.arrival <= last.arrival { + if current.tso <= last.tso { return } + if last.arrival.After(current.arrival) { + current.arrival = last.arrival + } if lastTSPointer.CompareAndSwap(last, current) { return } @@ -561,8 +561,11 @@ func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64 if !ok { return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope) } - ts, arrivalTS := last.tso, last.arrival - arrivalTime := oracle.GetTimeFromTS(arrivalTS) + return o.getStaleTimestampWithLastTS(last, prevSecond) +} + +func (o *pdOracle) getStaleTimestampWithLastTS(last *lastTSO, prevSecond uint64) (uint64, error) { + ts, arrivalTime := last.tso, last.arrival physicalTime := oracle.GetTimeFromTS(ts) if uint64(physicalTime.Unix()) <= prevSecond { return 0, errors.Errorf("invalid prevSecond %v", prevSecond) @@ -617,22 +620,34 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op } } -func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { - latestTS, err := o.GetLowResolutionTimestamp(ctx, opt) - // If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double-check. +func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) (errRet error) { + if readTS == math.MaxUint64 { + if isStaleRead { + return oracle.ErrLatestStaleRead{} + } + return nil + } + + latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope) + // If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check. // But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function // loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls. - if err != nil || readTS > latestTS { + if !exists || readTS > latestTSInfo.tso { currentTS, err := o.getCurrentTSForValidation(ctx, opt) if err != nil { return errors.Errorf("fail to validate read timestamp: %v", err) } - o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) + if isStaleRead { + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) + } if readTS > currentTS { - return errors.Errorf("cannot set read timestamp to a future time") + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } } - } else { - estimatedCurrentTS, err := o.getStaleTimestamp(opt.TxnScope, 0) + } else if isStaleRead { + estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0) if err != nil { logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval", zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope)) @@ -643,6 +658,9 @@ func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, op return nil } +// adjustUpdateLowResolutionTSIntervalWithRequestedStaleness triggers adjustments the update interval of low resolution +// ts, if necessary, to suite the usage of stale read. +// This method is not supposed to be called when performing non-stale-read operations. func (o *pdOracle) adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS uint64, currentTS uint64, now time.Time) { requiredStaleness := oracle.GetTimeFromTS(currentTS).Sub(oracle.GetTimeFromTS(readTS)) diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index 48739fd5e..25345f3b8 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -237,40 +237,54 @@ func TestAdaptiveUpdateTSInterval(t *testing.T) { assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) } -func TestValidateSnapshotReadTS(t *testing.T) { - pdClient := MockPdClient{} - o, err := NewPdOracle(&pdClient, &PDOracleOptions{ - UpdateInterval: time.Second * 2, - }) - assert.NoError(t, err) - defer o.Close() - - ctx := context.Background() - opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} - ts, err := o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - assert.GreaterOrEqual(t, ts, uint64(1)) +func TestValidateReadTS(t *testing.T) { + testImpl := func(staleRead bool) { + pdClient := MockPdClient{} + o, err := NewPdOracle(&pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + }) + assert.NoError(t, err) + defer o.Close() + + ctx := context.Background() + opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} + + // Always returns error for MaxUint64 + err = o.ValidateReadTS(ctx, math.MaxUint64, staleRead, opt) + if staleRead { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } - err = o.ValidateSnapshotReadTS(ctx, 1, opt) - assert.NoError(t, err) - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to - // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. - err = o.ValidateSnapshotReadTS(ctx, ts+1, opt) - assert.NoError(t, err) - // It can't pass if the readTS is newer than previous ts + 2. - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - err = o.ValidateSnapshotReadTS(ctx, ts+2, opt) - assert.Error(t, err) + ts, err := o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + assert.GreaterOrEqual(t, ts, uint64(1)) + + err = o.ValidateReadTS(ctx, 1, staleRead, opt) + assert.NoError(t, err) + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to + // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. + err = o.ValidateReadTS(ctx, ts+1, staleRead, opt) + assert.NoError(t, err) + // It can't pass if the readTS is newer than previous ts + 2. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + err = o.ValidateReadTS(ctx, ts+2, staleRead, opt) + assert.Error(t, err) + + // Simulate other PD clients requests a timestamp. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + pdClient.logicalTimestamp.Add(2) + err = o.ValidateReadTS(ctx, ts+3, staleRead, opt) + assert.NoError(t, err) + } - // Simulate other PD clients requests a timestamp. - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - pdClient.logicalTimestamp.Add(2) - err = o.ValidateSnapshotReadTS(ctx, ts+3, opt) - assert.NoError(t, err) + testImpl(true) + testImpl(false) } type MockPDClientWithPause struct { @@ -292,7 +306,7 @@ func (c *MockPDClientWithPause) Resume() { c.mu.Unlock() } -func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { +func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) { pdClient := &MockPDClientWithPause{} o, err := NewPdOracle(pdClient, &PDOracleOptions{ UpdateInterval: time.Second * 2, @@ -304,7 +318,7 @@ func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { asyncValidate := func(ctx context.Context, readTS uint64) chan error { ch := make(chan error, 1) go func() { - err := o.ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + err := o.ValidateReadTS(ctx, readTS, true, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) ch <- err }() return ch @@ -313,7 +327,7 @@ func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { noResult := func(ch chan error) { select { case <-ch: - assert.FailNow(t, "a ValidateSnapshotReadTS operation is not blocked while it's expected to be blocked") + assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked") default: } } @@ -391,3 +405,79 @@ func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { } } } + +func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) { + oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + o := oracleInterface.(*pdOracle) + defer o.Close() + + ctx := context.Background() + opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} + + // Validating read ts for non-stale-read requests must not trigger updating the adaptive update interval of + // low resolution ts. + mustNoNotify := func() { + select { + case <-o.adaptiveUpdateIntervalState.shrinkIntervalCh: + assert.Fail(t, "expects not notifying shrinking update interval immediately, but message was received") + default: + } + } + + ts, err := o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + assert.GreaterOrEqual(t, ts, uint64(1)) + + err = o.ValidateReadTS(ctx, ts, false, opt) + assert.NoError(t, err) + mustNoNotify() + + // It loads `ts + 1` from the mock PD, and the check cannot pass. + err = o.ValidateReadTS(ctx, ts+2, false, opt) + assert.Error(t, err) + mustNoNotify() + + // Do the check again. It loads `ts + 2` from the mock PD, and the check passes. + err = o.ValidateReadTS(ctx, ts+2, false, opt) + assert.NoError(t, err) + mustNoNotify() +} + +func TestSetLastTSAlwaysPushTS(t *testing.T) { + oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + o := oracleInterface.(*pdOracle) + defer o.Close() + + var wg sync.WaitGroup + cancel := make(chan struct{}) + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx := context.Background() + for { + select { + case <-cancel: + return + default: + } + ts, err := o.GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + assert.NoError(t, err) + lastTS, found := o.getLastTS(oracle.GlobalTxnScope) + assert.True(t, found) + assert.GreaterOrEqual(t, lastTS, ts) + } + }() + } + time.Sleep(time.Second) + close(cancel) + wg.Wait() +} diff --git a/rawkv/rawkv.go b/rawkv/rawkv.go index cebd534d4..2a841fb09 100644 --- a/rawkv/rawkv.go +++ b/rawkv/rawkv.go @@ -48,6 +48,7 @@ import ( "github.com/tikv/client-go/v2/internal/locate" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/tikvrpc" pd "github.com/tikv/pd/client" @@ -686,7 +687,7 @@ func (c *Client) CompareAndSwap(ctx context.Context, key, previousValue, newValu func (c *Client) sendReq(ctx context.Context, key []byte, req *tikvrpc.Request, reverse bool) (*tikvrpc.Response, *locate.KeyLocation, error) { bo := retry.NewBackofferWithVars(ctx, rawkvMaxBackoff, nil) - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) for { var loc *locate.KeyLocation var err error @@ -783,7 +784,7 @@ func (c *Client) doBatchReq(bo *retry.Backoffer, batch kvrpc.Batch, options *raw }) } - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) req.MaxExecutionDurationMs = uint64(client.MaxWriteExecutionTime.Milliseconds()) resp, _, err := sender.SendReq(bo, req, batch.RegionID, client.ReadTimeoutShort) @@ -833,7 +834,7 @@ func (c *Client) doBatchReq(bo *retry.Backoffer, batch kvrpc.Batch, options *raw // TODO: Is there any better way to avoid duplicating code with func `sendReq` ? func (c *Client) sendDeleteRangeReq(ctx context.Context, startKey []byte, endKey []byte, opts *rawOptions) (*tikvrpc.Response, []byte, error) { bo := retry.NewBackofferWithVars(ctx, rawkvMaxBackoff, nil) - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) for { loc, err := c.regionCache.LocateKey(bo, startKey) if err != nil { @@ -935,7 +936,7 @@ func (c *Client) doBatchPut(bo *retry.Backoffer, batch kvrpc.Batch, opts *rawOpt Ttl: ttl, }) - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) req.MaxExecutionDurationMs = uint64(client.MaxWriteExecutionTime.Milliseconds()) req.ApiVersion = c.apiVersion resp, _, err := sender.SendReq(bo, req, batch.RegionID, client.ReadTimeoutShort) diff --git a/tikv/kv.go b/tikv/kv.go index 2c179aea2..acf43f274 100644 --- a/tikv/kv.go +++ b/tikv/kv.go @@ -455,7 +455,7 @@ func (s *KVStore) SupportDeleteRange() (supported bool) { func (s *KVStore) SendReq( bo *Backoffer, req *tikvrpc.Request, regionID locate.RegionVerID, timeout time.Duration, ) (*tikvrpc.Response, error) { - sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient()) + sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient(), s.oracle) resp, _, err := sender.SendReq(bo, req, regionID, timeout) return resp, err } diff --git a/tikv/region.go b/tikv/region.go index 6b5e4874d..f32628a3a 100644 --- a/tikv/region.go +++ b/tikv/region.go @@ -41,6 +41,7 @@ import ( "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/locate" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" pd "github.com/tikv/pd/client" ) @@ -165,8 +166,8 @@ func GetStoreTypeByMeta(store *metapb.Store) tikvrpc.EndpointType { } // NewRegionRequestSender creates a new sender. -func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender { - return locate.NewRegionRequestSender(regionCache, client) +func NewRegionRequestSender(regionCache *RegionCache, client client.Client, readTSValidator oracle.ReadTSValidator) *RegionRequestSender { + return locate.NewRegionRequestSender(regionCache, client, readTSValidator) } // LoadShuttingDown atomically loads ShuttingDown. diff --git a/tikv/split_region.go b/tikv/split_region.go index 2844b3889..6f9d1f9dd 100644 --- a/tikv/split_region.go +++ b/tikv/split_region.go @@ -148,7 +148,7 @@ func (s *KVStore) batchSendSingleRegion(bo *Backoffer, batch kvrpc.Batch, scatte RequestSource: util.RequestSourceFromCtx(bo.GetCtx()), }) - sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient()) + sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient(), s.oracle) resp, _, err := sender.SendReq(bo, req, batch.RegionID, client.ReadTimeoutShort) batchResp := kvrpc.BatchResult{Response: resp} diff --git a/txnkv/transaction/commit.go b/txnkv/transaction/commit.go index 9e3eac4fe..3818e3978 100644 --- a/txnkv/transaction/commit.go +++ b/txnkv/transaction/commit.go @@ -95,7 +95,7 @@ func (action actionCommit) handleSingleBatch(c *twoPhaseCommitter, bo *retry.Bac tBegin := time.Now() attempts := 0 - sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) for { attempts++ reqBegin := time.Now() diff --git a/txnkv/transaction/pessimistic.go b/txnkv/transaction/pessimistic.go index 28835baeb..e485462df 100644 --- a/txnkv/transaction/pessimistic.go +++ b/txnkv/transaction/pessimistic.go @@ -184,7 +184,7 @@ func (action actionPessimisticLock) handleSingleBatch( time.Sleep(300 * time.Millisecond) return errors.WithStack(&tikverr.ErrWriteConflict{WriteConflict: nil}) } - sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) startTime := time.Now() resp, _, err := sender.SendReq(bo, req, batch.region, client.ReadTimeoutShort) diagCtx.reqDuration = time.Since(startTime) diff --git a/txnkv/transaction/prewrite.go b/txnkv/transaction/prewrite.go index e83fee3f8..d74b5fe6d 100644 --- a/txnkv/transaction/prewrite.go +++ b/txnkv/transaction/prewrite.go @@ -268,7 +268,7 @@ func (action actionPrewrite) handleSingleBatch( attempts := 0 req := c.buildPrewriteRequest(batch, txnSize) - sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) var resolvingRecordToken *int defer func() { if err != nil { diff --git a/txnkv/txnsnapshot/client_helper.go b/txnkv/txnsnapshot/client_helper.go index 259f21926..acbc22d69 100644 --- a/txnkv/txnsnapshot/client_helper.go +++ b/txnkv/txnsnapshot/client_helper.go @@ -40,6 +40,7 @@ import ( "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/locate" "github.com/tikv/client-go/v2/internal/retry" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/txnkv/txnlock" "github.com/tikv/client-go/v2/util" @@ -63,6 +64,7 @@ type ClientHelper struct { client client.Client resolveLite bool locate.RegionRequestRuntimeStats + oracle oracle.Oracle } // NewClientHelper creates a helper instance. @@ -74,6 +76,7 @@ func NewClientHelper(store kvstore, resolvedLocks *util.TSSet, committedLocks *u committedLocks: committedLocks, client: store.GetTiKVClient(), resolveLite: resolveLite, + oracle: store.GetOracle(), } } @@ -136,7 +139,7 @@ func (ch *ClientHelper) ResolveLocksDone(callerStartTS uint64, token int) { // SendReqCtx wraps the SendReqCtx function and use the resolved lock result in the kvrpcpb.Context. func (ch *ClientHelper) SendReqCtx(bo *retry.Backoffer, req *tikvrpc.Request, regionID locate.RegionVerID, timeout time.Duration, et tikvrpc.EndpointType, directStoreAddr string, opts ...locate.StoreSelectorOption) (*tikvrpc.Response, *locate.RPCContext, string, error) { - sender := locate.NewRegionRequestSender(ch.regionCache, ch.client) + sender := locate.NewRegionRequestSender(ch.regionCache, ch.client, ch.oracle) if len(directStoreAddr) > 0 { sender.SetStoreAddr(directStoreAddr) } diff --git a/txnkv/txnsnapshot/scan.go b/txnkv/txnsnapshot/scan.go index 7b07a920a..f5d7b32f5 100644 --- a/txnkv/txnsnapshot/scan.go +++ b/txnkv/txnsnapshot/scan.go @@ -197,7 +197,7 @@ func (s *Scanner) getData(bo *retry.Backoffer) error { zap.String("nextEndKey", kv.StrKey(s.nextEndKey)), zap.Bool("reverse", s.reverse), zap.Uint64("txnStartTS", s.startTS())) - sender := locate.NewRegionRequestSender(s.snapshot.store.GetRegionCache(), s.snapshot.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(s.snapshot.store.GetRegionCache(), s.snapshot.store.GetTiKVClient(), s.snapshot.store.GetOracle()) var reqEndKey, reqStartKey []byte var loc *locate.KeyLocation var resolvingRecordToken *int