Skip to content

Commit

Permalink
Merge pull request #361 from Exca-DK/cluster-shards
Browse files Browse the repository at this point in the history
feat: use CLUSTER SHARDS for redis >= 7
  • Loading branch information
rueian authored Sep 10, 2023
2 parents 73da7f7 + b387cae commit 256b4a7
Show file tree
Hide file tree
Showing 12 changed files with 861 additions and 44 deletions.
8 changes: 8 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type mockConn struct {
DoMultiCacheFn func(multi ...CacheableTTL) *redisresults
ReceiveFn func(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error
InfoFn func() map[string]RedisMessage
VersionFn func() int
ErrorFn func() error
CloseFn func()
DialFn func() error
Expand Down Expand Up @@ -138,6 +139,13 @@ func (m *mockConn) Info() map[string]RedisMessage {
return nil
}

// Version returns the value produced by VersionFn when the hook is set;
// without a hook it reports 0 as the mocked server version.
func (m *mockConn) Version() int {
	if fn := m.VersionFn; fn != nil {
		return fn()
	}
	return 0
}

func (m *mockConn) Error() error {
if m.ErrorFn != nil {
return m.ErrorFn()
Expand Down
108 changes: 101 additions & 7 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,26 +143,44 @@ type clusterslots struct {
}

func (c *clusterClient) _refresh() (err error) {
var reply RedisMessage
var addr string

type versionedResult struct {
version int
reply RedisResult
addr string
}
var (
reply RedisMessage
addr string
version int
)
c.mu.RLock()
results := make(chan clusterslots, len(c.conns))
results := make(chan versionedResult, len(c.conns))
pending := make([]conn, 0, len(c.conns))
for _, cc := range c.conns {
pending = append(pending, cc)
}
c.mu.RUnlock()

for i := 0; i < cap(results); i++ {
if i&3 == 0 { // batch CLUSTER SLOTS for every 4 connections
if i&3 == 0 { // batch CLUSTER SLOTS/CLUSTER SHARDS for every 4 connections
for j := i; j < i+4 && j < len(pending); j++ {
go func(c conn) {
results <- clusterslots{reply: c.Do(context.Background(), cmds.SlotCmd), addr: c.Addr()}
var reply RedisResult
if c.Version() < 7 {
reply = c.Do(context.Background(), cmds.SlotCmd)
} else {
reply = c.Do(context.Background(), cmds.ShardsCmd)
}
results <- versionedResult{
version: c.Version(),
reply: reply,
addr: c.Addr(),
}
}(pending[j])
}
}
r := <-results
version = r.version
addr = r.addr
reply, err = r.reply.ToMessage()
if len(reply.values) != 0 {
Expand All @@ -175,7 +193,12 @@ func (c *clusterClient) _refresh() (err error) {
return err
}

groups := parseSlots(reply, addr)
var groups map[string]group
if version < 7 {
groups = parseSlots(reply, addr)
} else {
groups = parseShards(reply, addr, c.opt.TLSConfig != nil)
}

conns := make(map[string]conn, len(groups))
for _, g := range groups {
Expand Down Expand Up @@ -289,6 +312,77 @@ func parseSlots(slots RedisMessage, defaultAddr string) map[string]group {
return groups
}

// parseShards - map redis shards for each redis nodes/addresses
// defaultAddr is needed in case the node does not know its own IP
// parseShards builds the master-keyed slot groups from a CLUSTER SHARDS reply.
// defaultAddr substitutes for a node that does not report its own endpoint,
// and tls selects the advertised tls-port over the plain port when present.
// Nodes with an unknown ("?") endpoint are dropped; shards whose master
// endpoint cannot be resolved are skipped entirely.
func parseShards(shards RedisMessage, defaultAddr string, tls bool) map[string]group {
	// endpointOf resolves one node dictionary to a host:port address,
	// returning "" when the node's address is unknown.
	endpointOf := func(node map[string]RedisMessage) string {
		switch host := node["endpoint"].string; host {
		case "":
			return defaultAddr // node does not know its own IP
		case "?":
			return "" // address unknown; caller must skip this node
		default:
			port := node["port"].integer
			if tls && node["tls-port"].integer > 0 {
				port = node["tls-port"].integer
			}
			return net.JoinHostPort(host, strconv.FormatInt(port, 10))
		}
	}

	groups := make(map[string]group, len(shards.values))
	for _, shard := range shards.values {
		fields, _ := shard.ToMap()
		nodes := fields["nodes"].values

		// Locate the first node advertising the master role.
		master, masterAt := "", 0
		for i, n := range nodes {
			dict, _ := n.ToMap()
			if dict["role"].string == "master" {
				master = endpointOf(dict)
				masterAt = i
				break
			}
		}
		if master == "" {
			continue // no resolvable master for this shard
		}

		g, seen := groups[master]
		if !seen {
			// First shard for this master: record the master followed by
			// every replica with a known address.
			g.slots = make([][2]int64, 0)
			g.nodes = make([]string, 0, len(nodes))
			g.nodes = append(g.nodes, master)
			for i, n := range nodes {
				if i == masterAt {
					continue
				}
				dict, _ := n.ToMap()
				if addr := endpointOf(dict); addr != "" {
					g.nodes = append(g.nodes, addr)
				}
			}
		}

		// "slots" is a flat array of [start, end] pairs.
		ranges, _ := fields["slots"].ToArray()
		for i := 0; i+1 < len(ranges); i += 2 {
			lo, _ := ranges[i].AsInt64()
			hi, _ := ranges[i+1].AsInt64()
			g.slots = append(g.slots, [2]int64{lo, hi})
		}
		groups[master] = g
	}
	return groups
}

func (c *clusterClient) _pick(slot uint16) (p conn) {
c.mu.RLock()
if slot == cmds.InitSlot {
Expand Down
Loading

0 comments on commit 256b4a7

Please sign in to comment.