From 1faac0cd4aac605a169af9a7f1e0b4e70280cb06 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Mon, 7 Aug 2023 10:35:40 +0800 Subject: [PATCH] client: add the missing forward ctx (#6895) close tikv/pd#6894 Signed-off-by: lhy1024 --- client/client.go | 15 +++++++++++++++ client/gc_client.go | 3 +++ pkg/replication/replication_mode.go | 12 ++++++++---- tests/integrations/client/gc_client_test.go | 9 ++++++--- tests/integrations/client/global_config_test.go | 16 +++++++++++----- 5 files changed, 43 insertions(+), 12 deletions(-) diff --git a/client/client.go b/client/client.go index b1b6e862e74..d3d3805fc4d 100644 --- a/client/client.go +++ b/client/client.go @@ -1457,6 +1457,9 @@ func trimHTTPPrefix(str string) string { } func (c *client) LoadGlobalConfig(ctx context.Context, names []string, configPath string) ([]GlobalConfigItem, int64, error) { + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) protoClient := c.getClient() if protoClient == nil { return nil, 0, errs.ErrClientGetProtoClient @@ -1486,6 +1489,9 @@ func (c *client) StoreGlobalConfig(ctx context.Context, configPath string, items for i, it := range items { resArr[i] = &pdpb.GlobalConfigItem{Name: it.Name, Value: it.Value, Kind: it.EventType, Payload: it.PayLoad} } + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) protoClient := c.getClient() if protoClient == nil { return errs.ErrClientGetProtoClient @@ -1501,6 +1507,9 @@ func (c *client) WatchGlobalConfig(ctx context.Context, configPath string, revis // TODO: Add retry mechanism // register watch components there globalConfigWatcherCh := make(chan []GlobalConfigItem, 16) + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) protoClient := c.getClient() if protoClient == nil { return nil, errs.ErrClientGetProtoClient @@ -1547,6 +1556,9 @@ func (c *client) WatchGlobalConfig(ctx context.Context, configPath string, revis } func (c *client) GetExternalTimestamp(ctx context.Context) (uint64, error) { + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) protoClient := c.getClient() if protoClient == nil { return 0, errs.ErrClientGetProtoClient @@ -1565,6 +1577,9 @@ func (c *client) GetExternalTimestamp(ctx context.Context) (uint64, error) { } func (c *client) SetExternalTimestamp(ctx context.Context, timestamp uint64) error { + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) protoClient := c.getClient() if protoClient == nil { return errs.ErrClientGetProtoClient diff --git a/client/gc_client.go b/client/gc_client.go index c573836d2ba..b5d64e25129 100644 --- a/client/gc_client.go +++ b/client/gc_client.go @@ -102,6 +102,9 @@ func (c *client) WatchGCSafePointV2(ctx context.Context, revision int64) (chan [ Revision: revision, } + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) protoClient := c.getClient() if protoClient == nil { return nil, errs.ErrClientGetProtoClient diff --git a/pkg/replication/replication_mode.go b/pkg/replication/replication_mode.go index fee87e85b5a..5a52f562e60 100644 --- a/pkg/replication/replication_mode.go +++ b/pkg/replication/replication_mode.go @@ -340,25 +340,29 @@ func (m *ModeManager) Run(ctx context.Context) { go func() { defer wg.Done() + ticker := time.NewTicker(tickInterval) + defer ticker.Stop() for { select { - case <-time.After(tickInterval): + case <-ticker.C: + m.tickUpdateState() case <-ctx.Done(): return } - m.tickUpdateState() } }() go func() { defer wg.Done() + ticker := time.NewTicker(replicateStateInterval) + defer ticker.Stop() for { select { - case <-time.After(replicateStateInterval): + case <-ticker.C: + m.tickReplicateStatus() case <-ctx.Done(): return } - m.tickReplicateStatus() } }() diff --git a/tests/integrations/client/gc_client_test.go b/tests/integrations/client/gc_client_test.go index acb1c458812..24ee8506efd 100644 --- a/tests/integrations/client/gc_client_test.go +++ b/tests/integrations/client/gc_client_test.go @@ -153,12 +153,14 @@ func (suite *gcClientTestSuite) testClientWatchWithRevision(fromNewRevision bool watchChan, err := suite.client.WatchGCSafePointV2(suite.server.Context(), startRevision) suite.NoError(err) - timeout := time.After(time.Second) - + timer := time.NewTimer(time.Second) + defer timer.Stop() isFirstUpdate := true + runTest := false for { select { - case <-timeout: + case <-timer.C: + suite.True(runTest) return case res := <-watchChan: for _, r := range res { @@ -174,6 +176,7 @@ func (suite *gcClientTestSuite) testClientWatchWithRevision(fromNewRevision bool continue } } + runTest = true } } } diff --git a/tests/integrations/client/global_config_test.go b/tests/integrations/client/global_config_test.go index 15034d035a6..6384adbd8f1 100644 --- a/tests/integrations/client/global_config_test.go +++ b/tests/integrations/client/global_config_test.go @@ -278,19 +278,20 @@ func (suite *globalConfigTestSuite) TestClientStore() { } func (suite *globalConfigTestSuite) TestClientWatchWithRevision() { + ctx := suite.server.Context() defer func() { - _, err := suite.server.GetClient().Delete(suite.server.Context(), suite.GetEtcdPath("test")) + _, err := suite.server.GetClient().Delete(ctx, suite.GetEtcdPath("test")) suite.NoError(err) for i := 3; i < 9; i++ { - _, err := suite.server.GetClient().Delete(suite.server.Context(), suite.GetEtcdPath(strconv.Itoa(i))) + _, err := suite.server.GetClient().Delete(ctx, suite.GetEtcdPath(strconv.Itoa(i))) suite.NoError(err) } }() // Mock get revision by loading - r, err := suite.server.GetClient().Put(suite.server.Context(), suite.GetEtcdPath("test"), "test") + r, err := suite.server.GetClient().Put(ctx, suite.GetEtcdPath("test"), "test") suite.NoError(err) - res, revision, err := suite.client.LoadGlobalConfig(suite.server.Context(), nil, globalConfigPath) + res, revision, err := suite.client.LoadGlobalConfig(ctx, nil, globalConfigPath) suite.NoError(err) suite.Len(res, 1) suite.LessOrEqual(r.Header.GetRevision(), revision) @@ -313,14 +314,19 @@ func (suite *globalConfigTestSuite) TestClientWatchWithRevision() { _, err = suite.server.GetClient().Put(suite.server.Context(), suite.GetEtcdPath(strconv.Itoa(i)), strconv.Itoa(i)) suite.NoError(err) } + timer := time.NewTimer(time.Second) + defer timer.Stop() + runTest := false for { select { - case <-time.After(time.Second): + case <-timer.C: + suite.True(runTest) return case res := <-configChan: for _, r := range res { suite.Equal(suite.GetEtcdPath(r.Value), r.Name) } + runTest = true } } }