diff --git a/pipe.go b/pipe.go index a36776ac..b7082726 100644 --- a/pipe.go +++ b/pipe.go @@ -590,7 +590,7 @@ func (p *pipe) backgroundPing() { go func() { ch <- p.Do(context.Background(), cmds.PingCmd).NonRedisError() }() select { case <-tm.C: - err = context.DeadlineExceeded + err = os.ErrDeadlineExceeded case err = <-ch: tm.Stop() } @@ -1150,6 +1150,7 @@ func (p *pipe) syncDo(dl time.Time, dlOk bool, cmd Completed) (resp RedisResult) defaultDeadline := time.Now().Add(p.timeout) if dl.After(defaultDeadline) { dl = defaultDeadline + dlOk = false } } p.conn.SetDeadline(dl) @@ -1165,7 +1166,7 @@ func (p *pipe) syncDo(dl time.Time, dlOk bool, cmd Completed) (resp RedisResult) msg, err = syncRead(p.r) } if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { + if dlOk && errors.Is(err, os.ErrDeadlineExceeded) { err = context.DeadlineExceeded } p.error.CompareAndSwap(nil, &errs{error: err}) @@ -1187,6 +1188,7 @@ func (p *pipe) syncDoMulti(dl time.Time, dlOk bool, resp []RedisResult, multi [] defaultDeadline := time.Now().Add(p.timeout) if dl.After(defaultDeadline) { dl = defaultDeadline + dlOk = false } } p.conn.SetDeadline(dl) @@ -1218,7 +1220,7 @@ process: } return abort: - if errors.Is(err, os.ErrDeadlineExceeded) { + if dlOk && errors.Is(err, os.ErrDeadlineExceeded) { err = context.DeadlineExceeded } p.error.CompareAndSwap(nil, &errs{error: err}) diff --git a/pipe_test.go b/pipe_test.go index 46b093dc..3d6c14f6 100644 --- a/pipe_test.go +++ b/pipe_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" + "os" "runtime" "strconv" "strings" @@ -3402,7 +3403,7 @@ func TestExitOnRingFullAndPingTimout(t *testing.T) { // fill the ring for i := 0; i < len(p.queue.(*ring).store); i++ { go func() { - if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).Error(); err != context.DeadlineExceeded { + if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).Error(); !errors.Is(err, os.ErrDeadlineExceeded) { t.Errorf("unexpected result, expected context.DeadlineExceeded, got %v", err) } }() @@ -3412,7 +3413,7 @@ func TestExitOnRingFullAndPingTimout(t *testing.T) { mock.Expect("GET", "a") } - if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).Error(); err != context.DeadlineExceeded { + if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).Error(); !errors.Is(err, os.ErrDeadlineExceeded) { t.Errorf("unexpected result, expected context.DeadlineExceeded, got %v", err) } } @@ -3944,7 +3945,7 @@ func TestSyncModeSwitchingWithDeadlineExceed_Do(t *testing.T) { for i := 0; i < 10; i++ { wg.Add(1) go func() { - if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, context.DeadlineExceeded) { + if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, os.ErrDeadlineExceeded) { t.Errorf("unexpected err %v", err) } wg.Done() @@ -3970,7 +3971,7 @@ func TestSyncModeSwitchingWithDeadlineExceed_DoMulti(t *testing.T) { for i := 0; i < 10; i++ { wg.Add(1) go func() { - if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, context.DeadlineExceeded) { + if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, os.ErrDeadlineExceeded) { t.Errorf("unexpected err %v", err) } wg.Done() @@ -3984,9 +3985,9 @@ func TestSyncModeSwitchingWithDeadlineExceed_DoMulti(t *testing.T) { p.Close() } -func TestOngoingDeadlineContextInSyncMode_Do(t *testing.T) { +func TestOngoingDeadlineShortContextInSyncMode_Do(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) - p, _, _, closeConn := setup(t, ClientOption{}) + p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 1 * time.Second}) defer closeConn() ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second/2)) @@ -3998,12 +3999,26 @@ func TestOngoingDeadlineContextInSyncMode_Do(t *testing.T) { p.Close() } +func TestOngoingDeadlineLongContextInSyncMode_Do(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 1 * time.Second / 4}) + defer closeConn() + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second/2)) + defer cancel() + + if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("unexpected err %v", err) + } + p.Close() +} + func TestWriteDeadlineInSyncMode_Do(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 1 * time.Second / 2, Dialer: net.Dialer{KeepAlive: time.Second / 3}}) defer closeConn() - if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, context.DeadlineExceeded) { + if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) { t.Fatalf("unexpected err %v", err) } p.Close() @@ -4018,7 +4033,7 @@ func TestWriteDeadlineIsShorterThanContextDeadlineInSyncMode_Do(t *testing.T) { defer cancel() startTime := time.Now() - if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, context.DeadlineExceeded) { + if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) { t.Fatalf("unexpected err %v", err) } @@ -4031,7 +4046,7 @@ func TestWriteDeadlineIsShorterThanContextDeadlineInSyncMode_Do(t *testing.T) { func TestWriteDeadlineIsNoShorterThanContextDeadlineInSyncMode_DoBlocked(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) - p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 5 * time.Millisecond, Dialer: net.Dialer{KeepAlive: time.Second}}) + p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 5 * time.Second, Dialer: net.Dialer{KeepAlive: time.Second}}) defer closeConn() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -4049,9 +4064,9 @@ func TestWriteDeadlineIsNoShorterThanContextDeadlineInSyncMode_DoBlocked(t *test p.Close() } -func TestOngoingDeadlineContextInSyncMode_DoMulti(t *testing.T) { +func TestOngoingDeadlineShortContextInSyncMode_DoMulti(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) - p, _, _, closeConn := setup(t, ClientOption{}) + p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second}) defer closeConn() ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second/2)) @@ -4063,12 +4078,26 @@ func TestOngoingDeadlineContextInSyncMode_DoMulti(t *testing.T) { p.Close() } +func TestOngoingDeadlineLongContextInSyncMode_DoMulti(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second / 4}) + defer closeConn() + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second/2)) + defer cancel() + + if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("unexpected err %v", err) + } + p.Close() +} + func TestWriteDeadlineInSyncMode_DoMulti(t *testing.T) { defer ShouldNotLeaked(SetupLeakDetection()) p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second / 2, Dialer: net.Dialer{KeepAlive: time.Second / 3}}) defer closeConn() - if err := p.DoMulti(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, context.DeadlineExceeded) { + if err := p.DoMulti(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) { t.Fatalf("unexpected err %v", err) } p.Close() @@ -4082,6 +4111,26 @@ func TestWriteDeadlineIsShorterThanContextDeadlineInSyncMode_DoMulti(t *testing. ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() + startTime := time.Now() + if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("unexpected err %v", err) + } + + if time.Since(startTime) >= time.Second { + t.Fatalf("unexpected time %v", time.Since(startTime)) + } + + p.Close() +} + +func TestWriteDeadlineIsNoShorterThanContextDeadlineInSyncMode_DoMulti(t *testing.T) { + defer ShouldNotLeaked(SetupLeakDetection()) + p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second, Dialer: net.Dialer{KeepAlive: time.Second}}) + defer closeConn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second/2) + defer cancel() + startTime := time.Now() if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].NonRedisError(); !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("unexpected err %v", err) @@ -4170,7 +4219,7 @@ func TestOngoingWriteTimeoutInPipelineMode_Do(t *testing.T) { for i := 0; i < 5; i++ { go func() { _, err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).ToString() - if errors.Is(err, context.DeadlineExceeded) { + if errors.Is(err, os.ErrDeadlineExceeded) { atomic.AddInt32(&timeout, 1) } else { t.Errorf("unexpected err %v", err) @@ -4244,7 +4293,7 @@ func TestOngoingWriteTimeoutInPipelineMode_DoMulti(t *testing.T) { for i := 0; i < 5; i++ { go func() { _, err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"})).s[0].ToString() - if errors.Is(err, context.DeadlineExceeded) { + if errors.Is(err, os.ErrDeadlineExceeded) { atomic.AddInt32(&timeout, 1) } else { t.Errorf("unexpected err %v", err) diff --git a/redis_test.go b/redis_test.go index 68647e52..4595ac6d 100644 --- a/redis_test.go +++ b/redis_test.go @@ -3,6 +3,7 @@ package rueidis import ( "bytes" "context" + "errors" "math/rand" "net" "os" @@ -167,7 +168,7 @@ func testSETGET(t *testing.T, client Client, csc bool) { ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() val, err := client.Do(ctx, client.B().Get().Key(key).Build()).ToString() - if err != context.DeadlineExceeded { + if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, os.ErrDeadlineExceeded) { if v, ok := kvs[key]; !((ok && val == v) || (!ok && IsRedisNil(err))) { t.Errorf("unexpected get response %v %v %v", val, err, ok) } @@ -318,7 +319,7 @@ func testMultiSETGET(t *testing.T, client Client, csc bool) { defer cancel() for j, resp := range client.DoMulti(ctx, commands...) { val, err := resp.ToString() - if err != context.DeadlineExceeded { + if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, os.ErrDeadlineExceeded) { if v, ok := kvs[cmdkeys[j]]; !((ok && val == v) || (!ok && IsRedisNil(err))) { t.Fatalf("unexpected get response %v %v %v", val, err, ok) }