Skip to content

Commit

Permalink
fix: hide os.ErrDeadlineExceeded with context.DeadlineExceeded only w…
Browse files Browse the repository at this point in the history
…hen ConnWriteTimeout is shorter (#672)

Signed-off-by: Rueian <[email protected]>
Co-authored-by: castaneai <[email protected]>
  • Loading branch information
rueian and castaneai authored Nov 15, 2024
1 parent e84fc9e commit 1c5e799
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 19 deletions.
8 changes: 5 additions & 3 deletions pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ func (p *pipe) backgroundPing() {
go func() { ch <- p.Do(context.Background(), cmds.PingCmd).NonRedisError() }()
select {
case <-tm.C:
err = context.DeadlineExceeded
err = os.ErrDeadlineExceeded
case err = <-ch:
tm.Stop()
}
Expand Down Expand Up @@ -1150,6 +1150,7 @@ func (p *pipe) syncDo(dl time.Time, dlOk bool, cmd Completed) (resp RedisResult)
defaultDeadline := time.Now().Add(p.timeout)
if dl.After(defaultDeadline) {
dl = defaultDeadline
dlOk = false
}
}
p.conn.SetDeadline(dl)
Expand All @@ -1165,7 +1166,7 @@ func (p *pipe) syncDo(dl time.Time, dlOk bool, cmd Completed) (resp RedisResult)
msg, err = syncRead(p.r)
}
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
if dlOk && errors.Is(err, os.ErrDeadlineExceeded) {
err = context.DeadlineExceeded
}
p.error.CompareAndSwap(nil, &errs{error: err})
Expand All @@ -1187,6 +1188,7 @@ func (p *pipe) syncDoMulti(dl time.Time, dlOk bool, resp []RedisResult, multi []
defaultDeadline := time.Now().Add(p.timeout)
if dl.After(defaultDeadline) {
dl = defaultDeadline
dlOk = false
}
}
p.conn.SetDeadline(dl)
Expand Down Expand Up @@ -1218,7 +1220,7 @@ process:
}
return
abort:
if errors.Is(err, os.ErrDeadlineExceeded) {
if dlOk && errors.Is(err, os.ErrDeadlineExceeded) {
err = context.DeadlineExceeded
}
p.error.CompareAndSwap(nil, &errs{error: err})
Expand Down
77 changes: 63 additions & 14 deletions pipe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net"
"os"
"runtime"
"strconv"
"strings"
Expand Down Expand Up @@ -3402,7 +3403,7 @@ func TestExitOnRingFullAndPingTimout(t *testing.T) {
// fill the ring
for i := 0; i < len(p.queue.(*ring).store); i++ {
go func() {
if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).Error(); err != context.DeadlineExceeded {
if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).Error(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("unexpected result, expected context.DeadlineExceeded, got %v", err)
}
}()
Expand All @@ -3412,7 +3413,7 @@ func TestExitOnRingFullAndPingTimout(t *testing.T) {
mock.Expect("GET", "a")
}

if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).Error(); err != context.DeadlineExceeded {
if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).Error(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("unexpected result, expected context.DeadlineExceeded, got %v", err)
}
}
Expand Down Expand Up @@ -3944,7 +3945,7 @@ func TestSyncModeSwitchingWithDeadlineExceed_Do(t *testing.T) {
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, context.DeadlineExceeded) {
if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("unexpected err %v", err)
}
wg.Done()
Expand All @@ -3970,7 +3971,7 @@ func TestSyncModeSwitchingWithDeadlineExceed_DoMulti(t *testing.T) {
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, context.DeadlineExceeded) {
if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("unexpected err %v", err)
}
wg.Done()
Expand All @@ -3984,9 +3985,9 @@ func TestSyncModeSwitchingWithDeadlineExceed_DoMulti(t *testing.T) {
p.Close()
}

func TestOngoingDeadlineContextInSyncMode_Do(t *testing.T) {
func TestOngoingDeadlineShortContextInSyncMode_Do(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
p, _, _, closeConn := setup(t, ClientOption{})
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 1 * time.Second})
defer closeConn()

ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second/2))
Expand All @@ -3998,12 +3999,26 @@ func TestOngoingDeadlineContextInSyncMode_Do(t *testing.T) {
p.Close()
}

func TestOngoingDeadlineLongContextInSyncMode_Do(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 1 * time.Second / 4})
defer closeConn()

ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second/2))
defer cancel()

if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("unexpected err %v", err)
}
p.Close()
}

func TestWriteDeadlineInSyncMode_Do(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 1 * time.Second / 2, Dialer: net.Dialer{KeepAlive: time.Second / 3}})
defer closeConn()

if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, context.DeadlineExceeded) {
if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("unexpected err %v", err)
}
p.Close()
Expand All @@ -4018,7 +4033,7 @@ func TestWriteDeadlineIsShorterThanContextDeadlineInSyncMode_Do(t *testing.T) {
defer cancel()

startTime := time.Now()
if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, context.DeadlineExceeded) {
if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("unexpected err %v", err)
}

Expand All @@ -4031,7 +4046,7 @@ func TestWriteDeadlineIsShorterThanContextDeadlineInSyncMode_Do(t *testing.T) {

func TestWriteDeadlineIsNoShorterThanContextDeadlineInSyncMode_DoBlocked(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 5 * time.Millisecond, Dialer: net.Dialer{KeepAlive: time.Second}})
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 5 * time.Second, Dialer: net.Dialer{KeepAlive: time.Second}})
defer closeConn()

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
Expand All @@ -4049,9 +4064,9 @@ func TestWriteDeadlineIsNoShorterThanContextDeadlineInSyncMode_DoBlocked(t *test
p.Close()
}

func TestOngoingDeadlineContextInSyncMode_DoMulti(t *testing.T) {
func TestOngoingDeadlineShortContextInSyncMode_DoMulti(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
p, _, _, closeConn := setup(t, ClientOption{})
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second})
defer closeConn()

ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second/2))
Expand All @@ -4063,12 +4078,26 @@ func TestOngoingDeadlineContextInSyncMode_DoMulti(t *testing.T) {
p.Close()
}

func TestOngoingDeadlineLongContextInSyncMode_DoMulti(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second / 4})
defer closeConn()

ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second/2))
defer cancel()

if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("unexpected err %v", err)
}
p.Close()
}

func TestWriteDeadlineInSyncMode_DoMulti(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second / 2, Dialer: net.Dialer{KeepAlive: time.Second / 3}})
defer closeConn()

if err := p.DoMulti(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, context.DeadlineExceeded) {
if err := p.DoMulti(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("unexpected err %v", err)
}
p.Close()
Expand All @@ -4082,6 +4111,26 @@ func TestWriteDeadlineIsShorterThanContextDeadlineInSyncMode_DoMulti(t *testing.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

startTime := time.Now()
if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("unexpected err %v", err)
}

if time.Since(startTime) >= time.Second {
t.Fatalf("unexpected time %v", time.Since(startTime))
}

p.Close()
}

func TestWriteDeadlineIsNoShorterThanContextDeadlineInSyncMode_DoMulti(t *testing.T) {
defer ShouldNotLeaked(SetupLeakDetection())
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second, Dialer: net.Dialer{KeepAlive: time.Second}})
defer closeConn()

ctx, cancel := context.WithTimeout(context.Background(), time.Second/2)
defer cancel()

startTime := time.Now()
if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("unexpected err %v", err)
Expand Down Expand Up @@ -4170,7 +4219,7 @@ func TestOngoingWriteTimeoutInPipelineMode_Do(t *testing.T) {
for i := 0; i < 5; i++ {
go func() {
_, err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).ToString()
if errors.Is(err, context.DeadlineExceeded) {
if errors.Is(err, os.ErrDeadlineExceeded) {
atomic.AddInt32(&timeout, 1)
} else {
t.Errorf("unexpected err %v", err)
Expand Down Expand Up @@ -4244,7 +4293,7 @@ func TestOngoingWriteTimeoutInPipelineMode_DoMulti(t *testing.T) {
for i := 0; i < 5; i++ {
go func() {
_, err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].ToString()
if errors.Is(err, context.DeadlineExceeded) {
if errors.Is(err, os.ErrDeadlineExceeded) {
atomic.AddInt32(&timeout, 1)
} else {
t.Errorf("unexpected err %v", err)
Expand Down
5 changes: 3 additions & 2 deletions redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rueidis
import (
"bytes"
"context"
"errors"
"math/rand"
"net"
"os"
Expand Down Expand Up @@ -167,7 +168,7 @@ func testSETGET(t *testing.T, client Client, csc bool) {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
val, err := client.Do(ctx, client.B().Get().Key(key).Build()).ToString()
if err != context.DeadlineExceeded {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, os.ErrDeadlineExceeded) {
if v, ok := kvs[key]; !((ok && val == v) || (!ok && IsRedisNil(err))) {
t.Errorf("unexpected get response %v %v %v", val, err, ok)
}
Expand Down Expand Up @@ -318,7 +319,7 @@ func testMultiSETGET(t *testing.T, client Client, csc bool) {
defer cancel()
for j, resp := range client.DoMulti(ctx, commands...) {
val, err := resp.ToString()
if err != context.DeadlineExceeded {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, os.ErrDeadlineExceeded) {
if v, ok := kvs[cmdkeys[j]]; !((ok && val == v) || (!ok && IsRedisNil(err))) {
t.Fatalf("unexpected get response %v %v %v", val, err, ok)
}
Expand Down

0 comments on commit 1c5e799

Please sign in to comment.