Skip to content

Commit

Permalink
client: add the missing forward ctx (tikv#6895)
Browse files Browse the repository at this point in the history
close tikv#6894

Signed-off-by: lhy1024 <[email protected]>
  • Loading branch information
lhy1024 authored Aug 7, 2023
1 parent 365e384 commit 1faac0c
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 12 deletions.
15 changes: 15 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,9 @@ func trimHTTPPrefix(str string) string {
}

func (c *client) LoadGlobalConfig(ctx context.Context, names []string, configPath string) ([]GlobalConfigItem, int64, error) {
ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr())
protoClient := c.getClient()
if protoClient == nil {
return nil, 0, errs.ErrClientGetProtoClient
Expand Down Expand Up @@ -1486,6 +1489,9 @@ func (c *client) StoreGlobalConfig(ctx context.Context, configPath string, items
for i, it := range items {
resArr[i] = &pdpb.GlobalConfigItem{Name: it.Name, Value: it.Value, Kind: it.EventType, Payload: it.PayLoad}
}
ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr())
protoClient := c.getClient()
if protoClient == nil {
return errs.ErrClientGetProtoClient
Expand All @@ -1501,6 +1507,9 @@ func (c *client) WatchGlobalConfig(ctx context.Context, configPath string, revis
// TODO: Add retry mechanism
// register watch components there
globalConfigWatcherCh := make(chan []GlobalConfigItem, 16)
ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr())
protoClient := c.getClient()
if protoClient == nil {
return nil, errs.ErrClientGetProtoClient
Expand Down Expand Up @@ -1547,6 +1556,9 @@ func (c *client) WatchGlobalConfig(ctx context.Context, configPath string, revis
}

func (c *client) GetExternalTimestamp(ctx context.Context) (uint64, error) {
ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr())
protoClient := c.getClient()
if protoClient == nil {
return 0, errs.ErrClientGetProtoClient
Expand All @@ -1565,6 +1577,9 @@ func (c *client) GetExternalTimestamp(ctx context.Context) (uint64, error) {
}

func (c *client) SetExternalTimestamp(ctx context.Context, timestamp uint64) error {
ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr())
protoClient := c.getClient()
if protoClient == nil {
return errs.ErrClientGetProtoClient
Expand Down
3 changes: 3 additions & 0 deletions client/gc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ func (c *client) WatchGCSafePointV2(ctx context.Context, revision int64) (chan [
Revision: revision,
}

ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()
ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr())
protoClient := c.getClient()
if protoClient == nil {
return nil, errs.ErrClientGetProtoClient
Expand Down
12 changes: 8 additions & 4 deletions pkg/replication/replication_mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,25 +340,29 @@ func (m *ModeManager) Run(ctx context.Context) {

go func() {
defer wg.Done()
ticker := time.NewTicker(tickInterval)
defer ticker.Stop()
for {
select {
case <-time.After(tickInterval):
case <-ticker.C:
m.tickUpdateState()
case <-ctx.Done():
return
}
m.tickUpdateState()
}
}()

go func() {
defer wg.Done()
ticker := time.NewTicker(replicateStateInterval)
defer ticker.Stop()
for {
select {
case <-time.After(replicateStateInterval):
case <-ticker.C:
m.tickReplicateStatus()
case <-ctx.Done():
return
}
m.tickReplicateStatus()
}
}()

Expand Down
9 changes: 6 additions & 3 deletions tests/integrations/client/gc_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,14 @@ func (suite *gcClientTestSuite) testClientWatchWithRevision(fromNewRevision bool
watchChan, err := suite.client.WatchGCSafePointV2(suite.server.Context(), startRevision)
suite.NoError(err)

timeout := time.After(time.Second)

timer := time.NewTimer(time.Second)
defer timer.Stop()
isFirstUpdate := true
runTest := false
for {
select {
case <-timeout:
case <-timer.C:
suite.True(runTest)
return
case res := <-watchChan:
for _, r := range res {
Expand All @@ -174,6 +176,7 @@ func (suite *gcClientTestSuite) testClientWatchWithRevision(fromNewRevision bool
continue
}
}
runTest = true
}
}
}
Expand Down
16 changes: 11 additions & 5 deletions tests/integrations/client/global_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,19 +278,20 @@ func (suite *globalConfigTestSuite) TestClientStore() {
}

func (suite *globalConfigTestSuite) TestClientWatchWithRevision() {
ctx := suite.server.Context()
defer func() {
_, err := suite.server.GetClient().Delete(suite.server.Context(), suite.GetEtcdPath("test"))
_, err := suite.server.GetClient().Delete(ctx, suite.GetEtcdPath("test"))
suite.NoError(err)

for i := 3; i < 9; i++ {
_, err := suite.server.GetClient().Delete(suite.server.Context(), suite.GetEtcdPath(strconv.Itoa(i)))
_, err := suite.server.GetClient().Delete(ctx, suite.GetEtcdPath(strconv.Itoa(i)))
suite.NoError(err)
}
}()
// Mock get revision by loading
r, err := suite.server.GetClient().Put(suite.server.Context(), suite.GetEtcdPath("test"), "test")
r, err := suite.server.GetClient().Put(ctx, suite.GetEtcdPath("test"), "test")
suite.NoError(err)
res, revision, err := suite.client.LoadGlobalConfig(suite.server.Context(), nil, globalConfigPath)
res, revision, err := suite.client.LoadGlobalConfig(ctx, nil, globalConfigPath)
suite.NoError(err)
suite.Len(res, 1)
suite.LessOrEqual(r.Header.GetRevision(), revision)
Expand All @@ -313,14 +314,19 @@ func (suite *globalConfigTestSuite) TestClientWatchWithRevision() {
_, err = suite.server.GetClient().Put(suite.server.Context(), suite.GetEtcdPath(strconv.Itoa(i)), strconv.Itoa(i))
suite.NoError(err)
}
timer := time.NewTimer(time.Second)
defer timer.Stop()
runTest := false
for {
select {
case <-time.After(time.Second):
case <-timer.C:
suite.True(runTest)
return
case res := <-configChan:
for _, r := range res {
suite.Equal(suite.GetEtcdPath(r.Value), r.Name)
}
runTest = true
}
}
}
Expand Down

0 comments on commit 1faac0c

Please sign in to comment.