From 073b353c682724a1b65b225ca4ef850bed1ca7bb Mon Sep 17 00:00:00 2001 From: Rueian Date: Wed, 27 Nov 2024 11:18:42 -0800 Subject: [PATCH] feat: allow setting negative values to disable keepalive and write timeout --- cluster.go | 10 ++++++++-- redis_test.go | 37 +++++++++++++++++++++++++++++++++++++ rueidis.go | 2 +- 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/cluster.go b/cluster.go index 455dada2..324507b5 100644 --- a/cluster.go +++ b/cluster.go @@ -144,8 +144,14 @@ func (s clusterslots) parse(tls bool) map[string]group { } func getClusterSlots(c conn, timeout time.Duration) clusterslots { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() + var ctx context.Context + var cancel context.CancelFunc + if timeout > 0 { + ctx, cancel = context.WithTimeout(context.Background(), timeout) + defer cancel() + } else { + ctx = context.Background() + } v := c.Version() if v < 8 { return clusterslots{reply: c.Do(ctx, cmds.SlotCmd), addr: c.Addr(), ver: v} diff --git a/redis_test.go b/redis_test.go index 4595ac6d..68d3f8bc 100644 --- a/redis_test.go +++ b/redis_test.go @@ -1015,3 +1015,40 @@ func TestKvrocksSingleClientIntegration(t *testing.T) { client.Close() } + +func TestNegativeConnWriteTimeout(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + client, err := NewClient(ClientOption{ + InitAddress: []string{"127.0.0.1:6379"}, + ConnWriteTimeout: -1, + }) + if err != nil { + t.Fatal(err) + } + client.Close() +} + +func TestNegativeKeepalive(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + client, err := NewClient(ClientOption{ + InitAddress: []string{"127.0.0.1:6379"}, + Dialer: net.Dialer{KeepAlive: -1}, + }) + if err != nil { + t.Fatal(err) + } + client.Close() +} + +func TestNegativeConnWriteTimeoutKeepalive(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + client, err := NewClient(ClientOption{ + InitAddress: []string{"127.0.0.1:6379"}, + Dialer: net.Dialer{KeepAlive: -1}, + ConnWriteTimeout: -1, + }) + if err != nil { + t.Fatal(err) + } + client.Close() +} diff --git a/rueidis.go b/rueidis.go index 1da3a178..2c6c9ace 100644 --- a/rueidis.go +++ b/rueidis.go @@ -371,7 +371,7 @@ func NewClient(option ClientOption) (client Client, err error) { option.Dialer.KeepAlive = DefaultTCPKeepAlive } if option.ConnWriteTimeout == 0 { - option.ConnWriteTimeout = option.Dialer.KeepAlive * 10 + option.ConnWriteTimeout = max(DefaultTCPKeepAlive, option.Dialer.KeepAlive) * 10 } if option.BlockingPipeline == 0 { option.BlockingPipeline = DefaultBlockingPipeline