From c586e5c0a21bbf7c82ee53ad45caeb8e7d4a9776 Mon Sep 17 00:00:00 2001
From: Bin Shi <39923490+binshi-bing@users.noreply.github.com>
Date: Thu, 29 Jun 2023 20:44:11 -0700
Subject: [PATCH 1/2] Fix data race between read APIs and finishSplit/finishMerge in keyspace group manager (#6723)

close tikv/pd#6721

checkTSOMerge and checkTSOSplit read from kgm.getKeyspaceGroupMeta, while
finishMergeKeyspaceGroup and finishSplitKeyspaceGroup update kgm, so return a
copy to avoid the data race.

Signed-off-by: Bin Shi
---
 pkg/tso/keyspace_group_manager.go | 98 +++++++++++++++++++------------
 1 file changed, 60 insertions(+), 38 deletions(-)

diff --git a/pkg/tso/keyspace_group_manager.go b/pkg/tso/keyspace_group_manager.go
index a82376430fa..0291bc5863d 100644
--- a/pkg/tso/keyspace_group_manager.go
+++ b/pkg/tso/keyspace_group_manager.go
@@ -112,6 +112,39 @@ func (s *state) getKeyspaceGroupMeta(
 	return s.ams[groupID], s.kgs[groupID]
 }
 
+func (s *state) checkTSOSplit(
+	targetGroupID uint32,
+) (splitTargetAM, splitSourceAM *AllocatorManager, err error) {
+	s.RLock()
+	defer s.RUnlock()
+	splitTargetAM, splitTargetGroup := s.ams[targetGroupID], s.kgs[targetGroupID]
+	// Only the split target keyspace group needs to check the TSO split.
+	if !splitTargetGroup.IsSplitTarget() {
+		return nil, nil, nil // it isn't in the split state
+	}
+	sourceGroupID := splitTargetGroup.SplitSource()
+	splitSourceAM, splitSourceGroup := s.ams[sourceGroupID], s.kgs[sourceGroupID]
+	if splitSourceAM == nil || splitSourceGroup == nil {
+		log.Error("the split source keyspace group is not initialized",
+			zap.Uint32("source", sourceGroupID))
+		return nil, nil, errs.ErrKeyspaceGroupNotInitialized.FastGenByArgs(sourceGroupID)
+	}
+	return splitTargetAM, splitSourceAM, nil
+}
+
+// Reject any request if the keyspace group is in merging state,
+// we need to wait for the merging checker to finish the TSO merging.
+func (s *state) checkTSOMerge(
+	groupID uint32,
+) error {
+	s.RLock()
+	defer s.RUnlock()
+	if s.kgs[groupID] == nil || !s.kgs[groupID].IsMerging() {
+		return nil
+	}
+	return errs.ErrKeyspaceGroupIsMerging.FastGenByArgs(groupID)
+}
+
 // getKeyspaceGroupMetaWithCheck returns the keyspace group meta of the given keyspace.
 // It also checks if the keyspace is served by the given keyspace group. If not, it returns the meta
 // of the keyspace group to which the keyspace currently belongs and returns NotServed (by the given
@@ -957,7 +990,7 @@ func (kgm *KeyspaceGroupManager) HandleTSORequest(
 	if err != nil {
 		return pdpb.Timestamp{}, curKeyspaceGroupID, err
 	}
-	err = kgm.checkTSOMerge(curKeyspaceGroupID)
+	err = kgm.state.checkTSOMerge(curKeyspaceGroupID)
 	if err != nil {
 		return pdpb.Timestamp{}, curKeyspaceGroupID, err
 	}
@@ -1032,19 +1065,11 @@ func (kgm *KeyspaceGroupManager) checkTSOSplit(
 	keyspaceGroupID uint32,
 	dcLocation string,
 ) error {
-	splitAM, splitGroup := kgm.getKeyspaceGroupMeta(keyspaceGroupID)
-	// Only the split target keyspace group needs to check the TSO split.
-	if !splitGroup.IsSplitTarget() {
-		return nil
-	}
-	splitSource := splitGroup.SplitSource()
-	splitSourceAM, splitSourceGroup := kgm.getKeyspaceGroupMeta(splitSource)
-	if splitSourceAM == nil || splitSourceGroup == nil {
-		log.Error("the split source keyspace group is not initialized",
-			zap.Uint32("source", splitSource))
-		return errs.ErrKeyspaceGroupNotInitialized.FastGenByArgs(splitSource)
+	splitTargetAM, splitSourceAM, err := kgm.state.checkTSOSplit(keyspaceGroupID)
+	if err != nil || splitTargetAM == nil {
+		return err
 	}
-	splitAllocator, err := splitAM.GetAllocator(dcLocation)
+	splitTargetAllocator, err := splitTargetAM.GetAllocator(dcLocation)
 	if err != nil {
 		return err
 	}
@@ -1052,7 +1077,7 @@ func (kgm *KeyspaceGroupManager) checkTSOSplit(
 	if err != nil {
 		return err
 	}
-	splitTSO, err := splitAllocator.GenerateTSO(1)
+	splitTargetTSO, err := splitTargetAllocator.GenerateTSO(1)
 	if err != nil {
 		return err
 	}
@@ -1061,19 +1086,19 @@ func (kgm *KeyspaceGroupManager) checkTSOSplit(
 		return err
 	}
 	// If the split source TSO is not greater than the newly split TSO, we don't need to do anything.
-	if tsoutil.CompareTimestamp(&splitSourceTSO, &splitTSO) <= 0 {
+	if tsoutil.CompareTimestamp(&splitSourceTSO, &splitTargetTSO) <= 0 {
 		log.Info("the split source tso is less than the newly split tso",
 			zap.Int64("split-source-tso-physical", splitSourceTSO.Physical),
 			zap.Int64("split-source-tso-logical", splitSourceTSO.Logical),
-			zap.Int64("split-tso-physical", splitTSO.Physical),
-			zap.Int64("split-tso-logical", splitTSO.Logical))
+			zap.Int64("split-tso-physical", splitTargetTSO.Physical),
+			zap.Int64("split-tso-logical", splitTargetTSO.Logical))
 		// Finish the split state directly.
 		return kgm.finishSplitKeyspaceGroup(keyspaceGroupID)
 	}
 	// If the split source TSO is greater than the newly split TSO, we need to update the split
 	// TSO to make sure the following TSO will be greater than the split keyspaces ever had
 	// in the past.
-	err = splitAllocator.SetTSO(tsoutil.GenerateTS(&pdpb.Timestamp{
+	err = splitTargetAllocator.SetTSO(tsoutil.GenerateTS(&pdpb.Timestamp{
 		Physical: splitSourceTSO.Physical + 1,
 		Logical:  splitSourceTSO.Logical,
 	}), true, true)
@@ -1083,8 +1108,8 @@ func (kgm *KeyspaceGroupManager) checkTSOSplit(
 	log.Info("the split source tso is greater than the newly split tso",
 		zap.Int64("split-source-tso-physical", splitSourceTSO.Physical),
 		zap.Int64("split-source-tso-logical", splitSourceTSO.Logical),
-		zap.Int64("split-tso-physical", splitTSO.Physical),
-		zap.Int64("split-tso-logical", splitTSO.Logical))
+		zap.Int64("split-tso-physical", splitTargetTSO.Physical),
+		zap.Int64("split-tso-logical", splitTargetTSO.Logical))
 	// Finish the split state.
 	return kgm.finishSplitKeyspaceGroup(keyspaceGroupID)
 }
@@ -1116,9 +1141,13 @@ func (kgm *KeyspaceGroupManager) finishSplitKeyspaceGroup(id uint32) error {
 			zap.Int("status-code", statusCode))
 		return errs.ErrSendRequest.FastGenByArgs()
 	}
-	// Pre-update the split keyspace group split state in memory.
-	splitGroup.SplitState = nil
-	kgm.kgs[id] = splitGroup
+	// Pre-update the split keyspace group's split state in memory.
+	// Note: to avoid data race with state read APIs, we always replace the group in memory as a whole.
+	// For now, we only have scenarios to update split state/merge state, and the other fields are always
+	// loaded from etcd without any modification, so we can simply copy the group and replace the state.
+	newSplitGroup := *splitGroup
+	newSplitGroup.SplitState = nil
+	kgm.kgs[id] = &newSplitGroup
 	return nil
 }
 
@@ -1146,9 +1175,14 @@ func (kgm *KeyspaceGroupManager) finishMergeKeyspaceGroup(id uint32) error {
 			zap.Int("status-code", statusCode))
 		return errs.ErrSendRequest.FastGenByArgs()
 	}
-	// Pre-update the split keyspace group split state in memory.
-	mergeTarget.MergeState = nil
-	kgm.kgs[id] = mergeTarget
+
+	// Pre-update the merge target keyspace group's merge state in memory.
+	// Note: to avoid data race with state read APIs, we always replace the group in memory as a whole.
+	// For now, we only have scenarios to update split state/merge state, and the other fields are always
+	// loaded from etcd without any modification, so we can simply copy the group and replace the state.
+	newTargetGroup := *mergeTarget
+	newTargetGroup.MergeState = nil
+	kgm.kgs[id] = &newTargetGroup
 	return nil
 }
 
@@ -1286,15 +1320,3 @@ func (kgm *KeyspaceGroupManager) mergingChecker(ctx context.Context, mergeTarget
 			return
 		}
 	}
-
-// Reject any request if the keyspace group is in merging state,
-// we need to wait for the merging checker to finish the TSO merging.
-func (kgm *KeyspaceGroupManager) checkTSOMerge(
-	keyspaceGroupID uint32,
-) error {
-	_, group := kgm.getKeyspaceGroupMeta(keyspaceGroupID)
-	if !group.IsMerging() {
-		return nil
-	}
-	return errs.ErrKeyspaceGroupIsMerging.FastGenByArgs(keyspaceGroupID)
-}

From 0fe5eb40ba38c1139b9cc0551a2b3a22982fe356 Mon Sep 17 00:00:00 2001
From: lhy1024
Date: Fri, 30 Jun 2023 16:32:12 +0800
Subject: [PATCH 2/2] tso: fix memory leak introduced by timer.After (#6730)

close tikv/pd#6719, ref tikv/pd#6720

Signed-off-by: lhy1024
Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
---
 pkg/timerpool/pool.go | 43 ++++++++++++++++++
 pkg/timerpool/pool_test.go | 70 +++++++++++++++++++++++++++++
 pkg/utils/tsoutil/tso_dispatcher.go | 17 ++++---
 server/grpc_service.go | 12 +++--
 4 files changed, 133 insertions(+), 9 deletions(-)
 create mode 100644 pkg/timerpool/pool.go
 create mode 100644 pkg/timerpool/pool_test.go

diff --git a/pkg/timerpool/pool.go b/pkg/timerpool/pool.go
new file mode 100644
index 00000000000..28ffacfc629
--- /dev/null
+++ b/pkg/timerpool/pool.go
@@ -0,0 +1,43 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Note: This file is copied from https://go-review.googlesource.com/c/go/+/276133
+
+package timerpool
+
+import (
+	"sync"
+	"time"
+)
+
+// GlobalTimerPool is a global pool for reusing *time.Timer.
+var GlobalTimerPool TimerPool
+
+// TimerPool is a wrapper of sync.Pool which caches *time.Timer for reuse.
+type TimerPool struct {
+	pool sync.Pool
+}
+
+// Get returns a timer with a given duration.
+func (tp *TimerPool) Get(d time.Duration) *time.Timer {
+	if v := tp.pool.Get(); v != nil {
+		timer := v.(*time.Timer)
+		timer.Reset(d)
+		return timer
+	}
+	return time.NewTimer(d)
+}
+
+// Put tries to call timer.Stop() before putting it back into pool,
+// if the timer.Stop() returns false (it has either already expired or been stopped),
+// have a shot at draining the channel with residual time if there is one.
+func (tp *TimerPool) Put(timer *time.Timer) {
+	if !timer.Stop() {
+		select {
+		case <-timer.C:
+		default:
+		}
+	}
+	tp.pool.Put(timer)
+}
diff --git a/pkg/timerpool/pool_test.go b/pkg/timerpool/pool_test.go
new file mode 100644
index 00000000000..d6dffc723a9
--- /dev/null
+++ b/pkg/timerpool/pool_test.go
@@ -0,0 +1,70 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Note: This file is copied from https://go-review.googlesource.com/c/go/+/276133
+
+package timerpool
+
+import (
+	"testing"
+	"time"
+)
+
+func TestTimerPool(t *testing.T) {
+	var tp TimerPool
+
+	for i := 0; i < 100; i++ {
+		timer := tp.Get(20 * time.Millisecond)
+
+		select {
+		case <-timer.C:
+			t.Errorf("timer expired too early")
+			continue
+		default:
+		}
+
+		select {
+		case <-time.After(100 * time.Millisecond):
+			t.Errorf("timer didn't expire on time")
+		case <-timer.C:
+		}
+
+		tp.Put(timer)
+	}
+}
+
+const timeout = 10 * time.Millisecond
+
+func BenchmarkTimerUtilization(b *testing.B) {
+	b.Run("TimerWithPool", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			t := GlobalTimerPool.Get(timeout)
+			GlobalTimerPool.Put(t)
+		}
+	})
+	b.Run("TimerWithoutPool", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			t := time.NewTimer(timeout)
+			t.Stop()
+		}
+	})
+}
+
+func BenchmarkTimerPoolParallel(b *testing.B) {
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			t := GlobalTimerPool.Get(timeout)
+			GlobalTimerPool.Put(t)
+		}
+	})
+}
+
+func BenchmarkTimerNativeParallel(b *testing.B) {
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			t := time.NewTimer(timeout)
+			t.Stop()
+		}
+	})
+}
diff --git a/pkg/utils/tsoutil/tso_dispatcher.go b/pkg/utils/tsoutil/tso_dispatcher.go
index 69baf4b1e41..f9585ba5cdd 100644
--- a/pkg/utils/tsoutil/tso_dispatcher.go
+++ b/pkg/utils/tsoutil/tso_dispatcher.go
@@ -24,6 +24,7 @@ import (
 	"github.com/pingcap/log"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/tikv/pd/pkg/errs"
+	"github.com/tikv/pd/pkg/timerpool"
 	"github.com/tikv/pd/pkg/utils/etcdutil"
 	"github.com/tikv/pd/pkg/utils/logutil"
 	"go.uber.org/zap"
@@ -197,7 +198,7 @@ func (s *TSODispatcher) finishRequest(requests []Request, physical, firstLogical
 
 // TSDeadline is used to watch the deadline of each tso request.
 type TSDeadline struct {
-	timer <-chan time.Time
+	timer  *time.Timer
 	done   chan struct{}
 	cancel context.CancelFunc
 }
@@ -208,8 +209,9 @@ func NewTSDeadline(
 	done chan struct{},
 	cancel context.CancelFunc,
 ) *TSDeadline {
+	timer := timerpool.GlobalTimerPool.Get(timeout)
 	return &TSDeadline{
-		timer:  time.After(timeout),
+		timer:  timer,
 		done:   done,
 		cancel: cancel,
 	}
@@ -224,13 +226,15 @@ func WatchTSDeadline(ctx context.Context, tsDeadlineCh <-chan *TSDeadline) {
 		select {
 		case d := <-tsDeadlineCh:
 			select {
-			case <-d.timer:
+			case <-d.timer.C:
 				log.Error("tso proxy request processing is canceled due to timeout",
 					errs.ZapError(errs.ErrProxyTSOTimeout))
 				d.cancel()
+				timerpool.GlobalTimerPool.Put(d.timer)
 			case <-d.done:
-				continue
+				timerpool.GlobalTimerPool.Put(d.timer)
 			case <-ctx.Done():
+				timerpool.GlobalTimerPool.Put(d.timer)
 				return
 			}
 		case <-ctx.Done():
@@ -241,11 +245,12 @@ func WatchTSDeadline(ctx context.Context, tsDeadlineCh <-chan *TSDeadline) {
 
 func checkStream(streamCtx context.Context, cancel context.CancelFunc, done chan struct{}) {
 	defer logutil.LogPanic()
-
+	timer := time.NewTimer(3 * time.Second)
+	defer timer.Stop()
 	select {
 	case <-done:
 		return
-	case <-time.After(3 * time.Second):
+	case <-timer.C:
 		cancel()
 	case <-streamCtx.Done():
 	}
diff --git a/server/grpc_service.go b/server/grpc_service.go
index 1badabb19d8..f66bd37ed11 100644
--- a/server/grpc_service.go
+++ b/server/grpc_service.go
@@ -606,13 +606,15 @@ func (s *tsoServer) Send(m *pdpb.TsoResponse) error {
 		})
 		done <- s.stream.Send(m)
 	}()
+	timer := time.NewTimer(tsoutil.DefaultTSOProxyTimeout)
+	defer timer.Stop()
 	select {
 	case err := <-done:
 		if err != nil {
 			atomic.StoreInt32(&s.closed, 1)
 		}
 		return errors.WithStack(err)
-	case <-time.After(tsoutil.DefaultTSOProxyTimeout):
+	case <-timer.C:
 		atomic.StoreInt32(&s.closed, 1)
 		return ErrForwardTSOTimeout
 	}
@@ -633,6 +635,8 @@ func (s *tsoServer) Recv(timeout time.Duration) (*pdpb.TsoRequest, error) {
 		request, err := s.stream.Recv()
 		requestCh <- &pdpbTSORequest{request: request, err: err}
 	}()
+	timer := time.NewTimer(timeout)
+	defer timer.Stop()
 	select {
 	case req := <-requestCh:
 		if req.err != nil {
@@ -640,7 +644,7 @@ func (s *tsoServer) Recv(timeout time.Duration) (*pdpb.TsoRequest, error) {
 			return nil, errors.WithStack(req.err)
 		}
 		return req.request, nil
-	case <-time.After(timeout):
+	case <-timer.C:
 		atomic.StoreInt32(&s.closed, 1)
 		return nil, ErrTSOProxyRecvFromClientTimeout
 	}
@@ -2173,10 +2177,12 @@ func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient
 // TODO: If goroutine here timeout when tso stream created successfully, we need to handle it correctly.
 func checkStream(streamCtx context.Context, cancel context.CancelFunc, done chan struct{}) {
 	defer logutil.LogPanic()
+	timer := time.NewTimer(3 * time.Second)
+	defer timer.Stop()
 	select {
 	case <-done:
 		return
-	case <-time.After(3 * time.Second):
+	case <-timer.C:
 		cancel()
 	case <-streamCtx.Done():
 	}
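
The concurrency pattern patch 1 relies on, shown as a minimal self-contained sketch rather than the PD code itself (the Group and Registry types below are hypothetical stand-ins): readers only ever observe whole group values taken under the read lock, and writers never mutate a published struct in place; they clone it, modify the clone, and swap the pointer under the write lock, which is the copy-then-replace idea behind finishSplitKeyspaceGroup and finishMergeKeyspaceGroup.

package main

import (
	"fmt"
	"sync"
)

// Group stands in for a keyspace group's metadata (hypothetical type).
type Group struct {
	ID         uint32
	SplitState *string // nil means the group is not splitting
}

// Registry mirrors the guarded map the keyspace group manager keeps (hypothetical type).
type Registry struct {
	sync.RWMutex
	groups map[uint32]*Group
}

// IsSplitting reads under the read lock. It can never observe a half-updated
// Group, because writers publish a fresh pointer instead of mutating fields.
func (r *Registry) IsSplitting(id uint32) bool {
	r.RLock()
	defer r.RUnlock()
	g := r.groups[id]
	return g != nil && g.SplitState != nil
}

// FinishSplit copies the current group, clears the state on the copy, and
// swaps the pointer under the write lock, the same copy-then-replace move
// used by finishSplitKeyspaceGroup/finishMergeKeyspaceGroup in the patch.
func (r *Registry) FinishSplit(id uint32) {
	r.Lock()
	defer r.Unlock()
	old := r.groups[id]
	if old == nil {
		return
	}
	updated := *old // shallow copy; only the state field changes below
	updated.SplitState = nil
	r.groups[id] = &updated
}

func main() {
	state := "splitting"
	r := &Registry{groups: map[uint32]*Group{1: {ID: 1, SplitState: &state}}}
	fmt.Println(r.IsSplitting(1)) // true
	r.FinishSplit(1)
	fmt.Println(r.IsSplitting(1)) // false
}

The design point is that a reader still holding an old *Group pointer keeps seeing a consistent snapshot even while the writer publishes the replacement, which is why the patch copies the group instead of clearing the state field in place.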
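Patch 2's motivation, sketched with illustrative code that is not the PD sources: `case <-time.After(d)` inside a select allocates a timer that cannot be reclaimed until it fires, even when another case wins, so hot request paths accumulate pending timers; the fix is an explicit *time.Timer the caller stops itself, optionally recycled through a pool such as the timerpool package added above.

package main

import (
	"context"
	"fmt"
	"time"
)

// waitOrTimeout mirrors the shape of checkStream after the fix: instead of
// `case <-time.After(3 * time.Second)`, whose timer stays alive until it fires
// even when another case wins, we create a timer we can stop ourselves.
func waitOrTimeout(ctx context.Context, done <-chan struct{}, cancel context.CancelFunc) {
	timer := time.NewTimer(3 * time.Second)
	defer timer.Stop() // releases the timer promptly when done or ctx wins the select
	select {
	case <-done:
		return
	case <-timer.C:
		cancel() // the stream did not become ready in time
	case <-ctx.Done():
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	done := make(chan struct{})
	close(done) // pretend the stream became ready immediately
	waitOrTimeout(ctx, done, cancel)
	fmt.Println("returned without leaving a pending 3s timer behind")
}

On the hottest path, the TSO proxy deadline handling, the patch goes one step further and reuses timers via timerpool.GlobalTimerPool.Get and Put so each deadline avoids allocating a fresh timer altogether.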