Skip to content

Commit

Permalink
fix: deadlocks in Redis pub/sub
Browse files Browse the repository at this point in the history
  • Loading branch information
palkan committed Aug 28, 2024
1 parent 6249e65 commit 57ac522
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 90 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## master

- Fix potential deadlocks in Redis pub/sub on reconnect. ([@palkan][])

## 1.5.2 (2024-06-04)

- Add `?raw=1` option for EventSource connections to receive only data messages (no protocol messages). ([@palkan][])
Expand Down
8 changes: 6 additions & 2 deletions broadcast/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ func (s *RedisBroadcaster) runReader(done chan (error)) {
return
}

if s.reconnectAttempt > 0 {
s.log.Info("reconnected to Redis")
}

s.reconnectAttempt = 0

// First, create a consumer group for the stream
err = s.client.Do(context.Background(),
s.client.B().XgroupCreate().Key(s.config.Channel).Group(s.config.Group).Id("$").Mkstream().Build(),
Expand All @@ -153,8 +159,6 @@ func (s *RedisBroadcaster) runReader(done chan (error)) {
}
}

s.reconnectAttempt = 0

readBlockMilliseconds := s.config.StreamReadBlockMilliseconds
var lastClaimedAt int64

Expand Down
188 changes: 113 additions & 75 deletions pubsub/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,23 @@ import (
rconfig "github.com/anycable/anycable-go/redis"
"github.com/anycable/anycable-go/utils"
"github.com/redis/rueidis"
"golang.org/x/exp/maps"
)

type subscriptionState = int
type subscriptionCmd = int

const (
subscriptionPending subscriptionState = iota
subscriptionCreated
subscriptionPendingUnsubscribe
subscribeCmd subscriptionCmd = iota
unsubscribeCmd
)

// clientCommand is a request sent over commandsCh asking the pub/sub loop to
// subscribe to or unsubscribe from a single channel.
type clientCommand struct {
// cmd selects the operation to perform (subscribeCmd or unsubscribeCmd).
cmd subscriptionCmd
// id is the name of the channel the operation applies to.
id string
}

type subscriptionEntry struct {
id string
state subscriptionState
id string
}

type RedisSubscriber struct {
Expand All @@ -40,10 +44,16 @@ type RedisSubscriber struct {
subscriptions map[string]*subscriptionEntry
subMu sync.RWMutex

streamsCh chan (*subscriptionEntry)
commandsCh chan (*clientCommand)
shutdownCh chan struct{}

log *slog.Logger

// test-only
// TODO: refactor tests to not depend on internals
events map[string]subscriptionCmd
eventsMu sync.Mutex
trackingEvents bool
}

var _ Subscriber = (*RedisSubscriber)(nil)
Expand All @@ -57,13 +67,15 @@ func NewRedisSubscriber(node Handler, config *rconfig.RedisConfig, l *slog.Logge
}

return &RedisSubscriber{
node: node,
config: config,
clientOptions: options,
subscriptions: make(map[string]*subscriptionEntry),
log: l.With("context", "pubsub"),
streamsCh: make(chan *subscriptionEntry, 1024),
shutdownCh: make(chan struct{}),
node: node,
config: config,
clientOptions: options,
subscriptions: make(map[string]*subscriptionEntry),
log: l.With("context", "pubsub"),
commandsCh: make(chan *clientCommand, 2),
shutdownCh: make(chan struct{}),
trackingEvents: false,
events: make(map[string]subscriptionCmd),
}, nil
}

Expand All @@ -76,9 +88,12 @@ func (s *RedisSubscriber) Start(done chan (error)) error {
s.log.Info(fmt.Sprintf("Starting Redis pub/sub: %s", s.config.Hostname()))
}

go s.runPubSub(done)
// Add internal channel to subscriptions
s.subMu.Lock()
s.subscriptions[s.config.InternalChannel] = &subscriptionEntry{id: s.config.InternalChannel}
s.subMu.Unlock()

s.Subscribe(s.config.InternalChannel)
go s.runPubSub(done)

return nil
}
Expand Down Expand Up @@ -106,11 +121,11 @@ func (s *RedisSubscriber) IsMultiNode() bool {

func (s *RedisSubscriber) Subscribe(stream string) {
s.subMu.Lock()
s.subscriptions[stream] = &subscriptionEntry{state: subscriptionPending, id: stream}
s.subscriptions[stream] = &subscriptionEntry{id: stream}
entry := s.subscriptions[stream]
s.subMu.Unlock()

s.streamsCh <- entry
s.commandsCh <- &clientCommand{cmd: subscribeCmd, id: entry.id}
}

func (s *RedisSubscriber) Unsubscribe(stream string) {
Expand All @@ -120,11 +135,10 @@ func (s *RedisSubscriber) Unsubscribe(stream string) {
return
}

entry := s.subscriptions[stream]
entry.state = subscriptionPendingUnsubscribe

s.streamsCh <- entry
delete(s.subscriptions, stream)
s.subMu.Unlock()

s.commandsCh <- &clientCommand{cmd: unsubscribeCmd, id: stream}
}

func (s *RedisSubscriber) Broadcast(msg *common.StreamMessage) {
Expand Down Expand Up @@ -184,29 +198,21 @@ func (s *RedisSubscriber) runPubSub(done chan (error)) {
client, cancel := s.client.Dedicate()
defer cancel()

s.log.Debug("initialized pub/sub client")

wait := client.SetPubSubHooks(rueidis.PubSubHooks{
OnSubscription: func(m rueidis.PubSubSubscription) {
s.subMu.Lock()
defer s.subMu.Unlock()

if m.Kind == "subscribe" && m.Channel == s.config.InternalChannel {
if s.reconnectAttempt > 0 {
s.log.Info("reconnected to Redis")
s.log.Info("reconnected")
} else {
s.log.Info("connected")
}
s.reconnectAttempt = 0
}

if entry, ok := s.subscriptions[m.Channel]; ok {
if entry.state == subscriptionPending && m.Kind == "subscribe" {
s.log.With("channel", m.Channel).Debug("subscribed")
entry.state = subscriptionCreated
}

if entry.state == subscriptionPendingUnsubscribe && m.Kind == "unsubscribe" {
s.log.With("channel", m.Channel).Debug("unsubscribed")
delete(s.subscriptions, entry.id)
}
}
s.log.With("channel", m.Channel).Debug(m.Kind)
s.trackEvent(m.Kind, m.Channel)
},
OnMessage: func(m rueidis.PubSubMessage) {
msg, err := common.PubSubMessageFromJSON([]byte(m.Message))
Expand All @@ -227,6 +233,8 @@ func (s *RedisSubscriber) runPubSub(done chan (error)) {
},
})

s.resubscribe(client)

for {
select {
case err := <-wait:
Expand All @@ -240,32 +248,21 @@ func (s *RedisSubscriber) runPubSub(done chan (error)) {
case <-s.shutdownCh:
s.log.Debug("close pub/sub channel")
return
case entry := <-s.streamsCh:
case entry := <-s.commandsCh:
ctx := context.Background()

switch entry.state {
case subscriptionPending:
switch entry.cmd {
case subscribeCmd:
s.log.With("channel", entry.id).Debug("subscribing")
client.Do(ctx, client.B().Subscribe().Channel(entry.id).Build())
case subscriptionPendingUnsubscribe:
case unsubscribeCmd:
s.log.With("channel", entry.id).Debug("unsubscribing")
client.Do(ctx, client.B().Unsubscribe().Channel(entry.id).Build())
}
}
}
}

// subscriptionEntry looks up the entry registered for the given stream while
// holding the read lock on subMu. It returns nil when no entry exists.
//
// Callers must not already hold subMu: sync.RWMutex does not support
// recursive read locking.
func (s *RedisSubscriber) subscriptionEntry(stream string) *subscriptionEntry {
	s.subMu.RLock()
	found, exists := s.subscriptions[stream]
	s.subMu.RUnlock()

	if !exists {
		return nil
	}

	return found
}

func (s *RedisSubscriber) maybeReconnect(done chan (error)) {
if s.reconnectAttempt >= s.config.MaxReconnectAttempts {
done <- errors.New("failed to reconnect to Redis: attempts exceeded") //nolint:stylecheck
Expand All @@ -280,24 +277,6 @@ func (s *RedisSubscriber) maybeReconnect(done chan (error)) {
}
s.clientMu.RUnlock()

s.subMu.Lock()
toRemove := []string{}

for key, sub := range s.subscriptions {
if sub.state == subscriptionCreated {
sub.state = subscriptionPending
}

if sub.state == subscriptionPendingUnsubscribe {
toRemove = append(toRemove, key)
}
}

for _, key := range toRemove {
delete(s.subscriptions, key)
}
s.subMu.Unlock()

s.reconnectAttempt++

delay := utils.NextRetry(s.reconnectAttempt - 1)
Expand All @@ -308,14 +287,73 @@ func (s *RedisSubscriber) maybeReconnect(done chan (error)) {
s.log.Info("reconnecting to Redis...")

go s.runPubSub(done)
}

const batchSubscribeSize = 256

func (s *RedisSubscriber) resubscribe(client rueidis.DedicatedClient) {
s.subMu.RLock()
defer s.subMu.RUnlock()
channels := maps.Keys(s.subscriptions)
s.subMu.RUnlock()

batch := make([]string, 0, batchSubscribeSize)

for i, id := range channels {
if i > 0 && i%batchSubscribeSize == 0 {
err := batchSubscribe(client, batch)
if err != nil {
s.log.Error("failed to resubscribe", "error", err)
return
}
batch = batch[:0]
}

batch = append(batch, id)
}

for _, sub := range s.subscriptions {
if sub.state == subscriptionPending {
s.log.Debug("resubscribing to stream", "stream", sub.id)
s.streamsCh <- sub
if len(batch) > 0 {
err := batchSubscribe(client, batch)
if err != nil {
s.log.Error("failed to resubscribe", "error", err)
return
}
}
}

// batchSubscribe subscribes the dedicated client to all of the given channels
// with a single Subscribe command and returns the error reported for that
// command. An empty channel list is a no-op that returns nil.
func batchSubscribe(client rueidis.DedicatedClient, channels []string) error {
	if len(channels) == 0 {
		// Nothing to subscribe to.
		return nil
	}

	ctx := context.Background()
	subscribe := client.B().Subscribe().Channel(channels...).Build()
	result := client.Do(ctx, subscribe)

	return result.Error()
}

// trackEvent remembers the latest subscription event observed for a channel.
//
// Test-only: nothing is recorded unless trackingEvents is enabled. Only the
// "subscribe" and "unsubscribe" event kinds are stored; any other kind is
// ignored.
func (s *RedisSubscriber) trackEvent(event string, channel string) {
	if !s.trackingEvents {
		return
	}

	// Resolve the event kind first so unknown kinds never touch the map.
	var recorded subscriptionCmd

	switch event {
	case "subscribe":
		recorded = subscribeCmd
	case "unsubscribe":
		recorded = unsubscribeCmd
	default:
		return
	}

	s.eventsMu.Lock()
	s.events[channel] = recorded
	s.eventsMu.Unlock()
}

// getEvent returns the latest event recorded for the channel by trackEvent.
//
// Test-only. A channel with no recorded event is reported as unsubscribeCmd.
func (s *RedisSubscriber) getEvent(channel string) subscriptionCmd {
	s.eventsMu.Lock()
	recorded, seen := s.events[channel]
	s.eventsMu.Unlock()

	if seen {
		return recorded
	}

	return unsubscribeCmd
}
19 changes: 6 additions & 13 deletions pubsub/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func TestRedisCommon(t *testing.T) {

SharedSubscriberTests(t, func(handler *TestHandler) Subscriber {
sub, err := NewRedisSubscriber(handler, &config, slog.Default())
sub.trackingEvents = true

if err != nil {
panic(err)
Expand All @@ -97,6 +98,8 @@ func TestRedisReconnect(t *testing.T) {
subscriber, err := NewRedisSubscriber(handler, &config, slog.Default())
require.NoError(t, err)

subscriber.trackingEvents = true

done := make(chan error)

err = subscriber.Start(done)
Expand Down Expand Up @@ -153,24 +156,14 @@ func waitRedisSubscription(subscriber Subscriber, stream string) error {
}
}

s.subMu.RLock()
entry := s.subscriptionEntry(stream)
state := subscriptionPending
if entry != nil {
state = entry.state
}
s.subMu.RUnlock()
event := s.getEvent(stream)

if unsubscribing {
if entry == nil {
if event == unsubscribeCmd {
return nil
}
} else {
if entry == nil {
return fmt.Errorf("No pending subscription: %s", stream)
}

if state == subscriptionCreated {
if event == subscribeCmd {
return nil
}
}
Expand Down

0 comments on commit 57ac522

Please sign in to comment.