Skip to content

Commit

Permalink
Fixes shutdown deadlock in websocket client (#1221)
Browse files Browse the repository at this point in the history
* Adds a failing test for deadlocks in shutting down the write pump

* Protects access to stopReadPump and isConnected with RWMutex to prevent races

* Makes the send channel buffered to allow re-queueing messages for transmission during writePump shutdown
  • Loading branch information
fhats-stripe authored Jul 26, 2024
1 parent b241c27 commit 7a252a1
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 39 deletions.
52 changes: 38 additions & 14 deletions pkg/websocket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,30 @@ type Client struct {
// Optional configuration parameters
cfg *Config

conn *ws.Conn
done chan struct{}
isConnected bool

NotifyExpired chan struct{}
notifyClose chan error
send chan *OutgoingMessage
stopReadPump chan struct{}
stopWritePump chan struct{}
wg *sync.WaitGroup
conn *ws.Conn
done chan struct{}
isConnected bool
isConnectedMutex sync.RWMutex

NotifyExpired chan struct{}
notifyClose chan error
send chan *OutgoingMessage
stopReadPumpMutex sync.RWMutex
stopReadPump chan struct{}
stopWritePump chan struct{}
wg *sync.WaitGroup
}

func (c *Client) setIsConnected(newValue bool) {
c.isConnectedMutex.Lock()
defer c.isConnectedMutex.Unlock()
c.isConnected = newValue
}

func (c *Client) getIsConnected() bool {
c.isConnectedMutex.RLock()
defer c.isConnectedMutex.RUnlock()
return c.isConnected
}

// Connected returns a channel that's closed when the client has finished
Expand All @@ -97,7 +111,7 @@ func (c *Client) Connected() <-chan struct{} {
d := make(chan struct{})

go func() {
for !c.isConnected {
for !c.getIsConnected() {
time.Sleep(100 * time.Millisecond)
}
close(d)
Expand All @@ -109,7 +123,7 @@ func (c *Client) Connected() <-chan struct{} {
// Run starts listening for incoming webhook requests from Stripe.
func (c *Client) Run(ctx context.Context) {
for {
c.isConnected = false
c.setIsConnected(false)
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.client.Run",
}).Debug("Attempting to connect to Stripe")
Expand Down Expand Up @@ -171,6 +185,8 @@ func (c *Client) Run(ctx context.Context) {
// Close executes a proper closure handshake then closes the connection
// list of close codes: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4
func (c *Client) Close(closeCode int, text string) {
c.stopReadPumpMutex.Lock()
defer c.stopReadPumpMutex.Unlock()
close(c.stopReadPump)
close(c.stopWritePump)
if c.conn != nil {
Expand Down Expand Up @@ -271,7 +287,7 @@ func (c *Client) connect(ctx context.Context) error {
defer resp.Body.Close()

c.changeConnection(conn)
c.isConnected = true
c.setIsConnected(true)

c.wg = &sync.WaitGroup{}
c.wg.Add(2)
Expand All @@ -289,6 +305,8 @@ func (c *Client) connect(ctx context.Context) error {

// changeConnection takes a new connection and recreates the channels.
func (c *Client) changeConnection(conn *ws.Conn) {
c.stopReadPumpMutex.Lock()
defer c.stopReadPumpMutex.Unlock()
c.conn = conn
c.notifyClose = make(chan error)
c.stopReadPump = make(chan struct{})
Expand Down Expand Up @@ -461,6 +479,12 @@ func (c *Client) writePump() {
}
}

func (c *Client) terminateReadPump() {
c.stopReadPumpMutex.Lock()
defer c.stopReadPumpMutex.Unlock()
c.stopReadPump <- struct{}{}
}

//
// Public functions
//
Expand Down Expand Up @@ -513,7 +537,7 @@ func NewClient(url string, webSocketID string, websocketAuthorizedFeature string
WebSocketAuthorizedFeature: websocketAuthorizedFeature,
cfg: cfg,
done: make(chan struct{}),
send: make(chan *OutgoingMessage),
send: make(chan *OutgoingMessage, 10),
NotifyExpired: make(chan struct{}),
}
}
Expand Down
93 changes: 68 additions & 25 deletions pkg/websocket/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package websocket
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -11,7 +12,8 @@ import (
"time"

ws "github.com/gorilla/websocket"
// log "github.com/sirupsen/logrus"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -176,55 +178,96 @@ func TestClientExpiredError(t *testing.T) {
}
}

/* func TestClientWebhookReconnect(t *testing.T) {
log.SetLevel(log.DebugLevel)
wg := &sync.WaitGroup{}
wg.Add(20)
// This test is a regression test for deadlocks that can be encountered
// when the write pump is interrupted by closed connections at inopportune
// times.
//
// The goal is to simulate a scenario where the read pump is shut down but the
// client still has messages to send. The read pump should be shut down because
// in the majority of cases it is how the client ends up stopped. However, there's
// no hard synchronization between the read and write pumps so we have to defend
// against race conditions where the read side is shut down, hence this test.
func TestWritePumpInterruptionRequeued(t *testing.T) {
serverReceivedMessages := make(chan string, 10)
wg := sync.WaitGroup{}

upgrader := ws.Upgrader{}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wg.Add(1)

require.NotEmpty(t, r.UserAgent())
require.NotEmpty(t, r.Header.Get("X-Stripe-Client-User-Agent"))
require.Equal(t, "websocket-random-id", r.Header.Get("Websocket-Id"))
c, err := upgrader.Upgrade(w, r, nil)
require.NoError(t, err)

defer c.Close()
require.Equal(t, "websocket_feature=webhook-payloads", r.URL.RawQuery)

swg := &sync.WaitGroup{}
swg.Add(1)
defer c.Close()

go func() {
for {
if _, _, err := c.ReadMessage(); err != nil {
swg.Done()
return
}
}
}()
msgType, msg, err := c.ReadMessage()
require.NoError(t, err)
require.Equal(t, msgType, ws.TextMessage)
serverReceivedMessages <- string(msg)

swg.Wait()
// To simulate a forced reconnection, the server closes the connection
// after receiving any messages
c.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseNormalClosure, ""), time.Now().Add(5*time.Second))
c.Close()
wg.Done()
}))

defer ts.Close()

url := "ws" + strings.TrimPrefix(ts.URL, "http")

rcvMsgChan := make(chan WebhookEvent)
client := NewClient(
url,
"websocket-random-id",
"webhook-payloads",
&Config{
EventHandler: EventHandlerFunc(func(msg IncomingMessage) {
rcvMsgChan <- *msg.WebhookEvent
}),
Log: log.StandardLogger(),
ReconnectInterval: 10 * time.Second,
EventHandler: EventHandlerFunc(func(msg IncomingMessage) {}),
WriteWait: 10 * time.Second,
PongWait: 60 * time.Second,
PingPeriod: 60 * time.Hour,
},
)

go client.Run(context.Background())

defer client.Stop()

actualMessages := []string{}
connectedChan := client.Connected()
<-connectedChan
go func() { client.terminateReadPump() }()

for i := 0; i < 2; i++ {
client.SendMessage(NewEventAck(fmt.Sprintf("event_%d", i), fmt.Sprintf("event_%d", i)))
// Needed to deflake the test from racing against itself
// Something to do with the buffering
time.Sleep(100 * time.Millisecond)

msg := <-serverReceivedMessages
actualMessages = append(actualMessages, msg)
wg.Wait()
}

wg.Wait()
} */

for {
exhausted := false
select {
case msg := <-serverReceivedMessages:
actualMessages = append(actualMessages, msg)
default:
exhausted = true
}

if exhausted {
break
}
}

assert.Len(t, actualMessages, 2)
}

0 comments on commit 7a252a1

Please sign in to comment.