diff --git a/mux.go b/mux.go index 08e36069..b0e0203f 100644 --- a/mux.go +++ b/mux.go @@ -204,8 +204,10 @@ func (m *mux) DoMultiStream(ctx context.Context, multi ...Completed) MultiRedisR } func (m *mux) Do(ctx context.Context, cmd Completed) (resp RedisResult) { - if (m.usePool && !cmd.NoReply()) || cmd.IsBlock() { - resp = m.blocking(ctx, cmd) + if m.usePool && !cmd.NoReply() { + resp = m.blocking(m.spool, ctx, cmd) + } else if cmd.IsBlock() { + resp = m.blocking(m.dpool, ctx, cmd) } else { resp = m.pipeline(ctx, cmd) } @@ -213,37 +215,38 @@ func (m *mux) Do(ctx context.Context, cmd Completed) (resp RedisResult) { } func (m *mux) DoMulti(ctx context.Context, multi ...Completed) (resp *redisresults) { - if m.usePool || (len(multi) >= m.maxm && m.maxm > 0) { - goto block // use a dedicated connection if the pipeline is too large - } for _, cmd := range multi { + if cmd.NoReply() { + return m.pipelineMulti(ctx, multi) + } if cmd.IsBlock() { cmds.ToBlock(&multi[0]) // mark the first cmd as block if one of them is block to shortcut later check. goto block } } + if m.usePool || (len(multi) >= m.maxm && m.maxm > 0) { + goto block // use a dedicated connection if the pipeline is too large + } return m.pipelineMulti(ctx, multi) block: - for _, cmd := range multi { - if cmd.NoReply() { - return m.pipelineMulti(ctx, multi) - } + if m.usePool { + return m.blockingMulti(m.spool, ctx, multi) } - return m.blockingMulti(ctx, multi) + return m.blockingMulti(m.dpool, ctx, multi) } -func (m *mux) blocking(ctx context.Context, cmd Completed) (resp RedisResult) { - wire := m.spool.Acquire() +func (m *mux) blocking(pool *pool, ctx context.Context, cmd Completed) (resp RedisResult) { + wire := pool.Acquire() resp = wire.Do(ctx, cmd) if resp.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded) wire.Close() } - m.spool.Store(wire) + pool.Store(wire) return resp } -func (m *mux) blockingMulti(ctx context.Context, cmd []Completed) (resp *redisresults) { - wire := m.spool.Acquire() +func (m *mux) blockingMulti(pool *pool, ctx context.Context, cmd []Completed) (resp *redisresults) { + wire := pool.Acquire() resp = wire.DoMulti(ctx, cmd...) for _, res := range resp.s { if res.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded) @@ -251,7 +254,7 @@ func (m *mux) blockingMulti(ctx context.Context, cmd []Completed) (resp *redisre break } } - m.spool.Store(wire) + pool.Store(wire) return resp } diff --git a/mux_test.go b/mux_test.go index 0eb2ad72..8179f068 100644 --- a/mux_test.go +++ b/mux_test.go @@ -177,7 +177,7 @@ func TestMuxReuseWire(t *testing.T) { m.Close() }) - t.Run("reuse blocking pool", func(t *testing.T) { + t.Run("reuse blocking (dpool) pool", func(t *testing.T) { blocking := make(chan struct{}) response := make(chan RedisResult) m, checkClean := setupMux([]*mockWire{ @@ -202,6 +202,57 @@ func TestMuxReuseWire(t *testing.T) { t.Fatalf("unexpected dial error %v", err) } + wire1 := m.dpool.Acquire() + + go func() { + // this should use the second wire + if val, err := m.Do(context.Background(), cmds.NewBlockingCompleted([]string{"PING"})).ToString(); err != nil { + t.Errorf("unexpected error %v", err) + } else if val != "BLOCK_RESPONSE" { + t.Errorf("unexpected response %v", val) + } + close(blocking) + }() + <-blocking + + m.dpool.Store(wire1) + // this should use the first wire + if val, err := m.Do(context.Background(), cmds.NewBlockingCompleted([]string{"PING"})).ToString(); err != nil { + t.Fatalf("unexpected error %v", err) + } else if val != "ACQUIRED" { + t.Fatalf("unexpected response %v", val) + } + + response <- newResult(RedisMessage{typ: '+', string: "BLOCK_RESPONSE"}, nil) + <-blocking + }) + + t.Run("reuse blocking (spool) pool", func(t *testing.T) { + blocking := make(chan struct{}) + response := make(chan RedisResult) + m, checkClean := setupMux([]*mockWire{ + { + // leave first wire for pipeline calls + }, + { + DoFn: func(cmd Completed) RedisResult { + return newResult(RedisMessage{typ: '+', string: "ACQUIRED"}, nil) + }, + }, + { + DoFn: func(cmd Completed) RedisResult { + blocking <- struct{}{} + return <-response + }, + }, + }) + m.usePool = true // switch to spool + defer checkClean(t) + defer m.Close() + if err := m.Dial(); err != nil { + t.Fatalf("unexpected dial error %v", err) + } + wire1 := m.spool.Acquire() go func() { @@ -227,6 +278,107 @@ func TestMuxReuseWire(t *testing.T) { <-blocking }) + t.Run("reuse blocking (dpool) pool DoMulti", func(t *testing.T) { + blocking := make(chan struct{}) + response := make(chan RedisResult) + m, checkClean := setupMux([]*mockWire{ + { + // leave first wire for pipeline calls + }, + { + DoMultiFn: func(cmd ...Completed) *redisresults { + return &redisresults{s: []RedisResult{newResult(RedisMessage{typ: '+', string: "ACQUIRED"}, nil)}} + }, + }, + { + DoMultiFn: func(cmd ...Completed) *redisresults { + blocking <- struct{}{} + return &redisresults{s: []RedisResult{<-response}} + }, + }, + }) + m.usePool = true // switch to spool + defer checkClean(t) + defer m.Close() + if err := m.Dial(); err != nil { + t.Fatalf("unexpected dial error %v", err) + } + + wire1 := m.spool.Acquire() + + go func() { + // this should use the second wire + if val, err := m.DoMulti(context.Background(), cmds.NewBlockingCompleted([]string{"PING"})).s[0].ToString(); err != nil { + t.Errorf("unexpected error %v", err) + } else if val != "BLOCK_RESPONSE" { + t.Errorf("unexpected response %v", val) + } + close(blocking) + }() + <-blocking + + m.spool.Store(wire1) + // this should use the first wire + if val, err := m.DoMulti(context.Background(), cmds.NewBlockingCompleted([]string{"PING"})).s[0].ToString(); err != nil { + t.Fatalf("unexpected error %v", err) + } else if val != "ACQUIRED" { + t.Fatalf("unexpected response %v", val) + } + + response <- newResult(RedisMessage{typ: '+', string: "BLOCK_RESPONSE"}, nil) + <-blocking + }) + + t.Run("reuse blocking (spool) pool DoMulti", func(t *testing.T) { + blocking := make(chan struct{}) + response := make(chan RedisResult) + m, checkClean := setupMux([]*mockWire{ + { + // leave first wire for pipeline calls + }, + { + DoMultiFn: func(cmd ...Completed) *redisresults { + return &redisresults{s: []RedisResult{newResult(RedisMessage{typ: '+', string: "ACQUIRED"}, nil)}} + }, + }, + { + DoMultiFn: func(cmd ...Completed) *redisresults { + blocking <- struct{}{} + return &redisresults{s: []RedisResult{<-response}} + }, + }, + }) + defer checkClean(t) + defer m.Close() + if err := m.Dial(); err != nil { + t.Fatalf("unexpected dial error %v", err) + } + + wire1 := m.dpool.Acquire() + + go func() { + // this should use the second wire + if val, err := m.DoMulti(context.Background(), cmds.NewBlockingCompleted([]string{"PING"})).s[0].ToString(); err != nil { + t.Errorf("unexpected error %v", err) + } else if val != "BLOCK_RESPONSE" { + t.Errorf("unexpected response %v", val) + } + close(blocking) + }() + <-blocking + + m.dpool.Store(wire1) + // this should use the first wire + if val, err := m.DoMulti(context.Background(), cmds.NewBlockingCompleted([]string{"PING"})).s[0].ToString(); err != nil { + t.Fatalf("unexpected error %v", err) + } else if val != "ACQUIRED" { + t.Fatalf("unexpected response %v", val) + } + + response <- newResult(RedisMessage{typ: '+', string: "BLOCK_RESPONSE"}, nil) + <-blocking + }) + t.Run("unsubscribe blocking pool", func(t *testing.T) { cleaned := false m, checkClean := setupMux([]*mockWire{