Skip to content

Commit

Permalink
fix: websocket disconnect and goroutine leak (#220)
Browse files: browse the repository at this point in the history
  • Loading branch information
hunjixin authored Jul 15, 2024
1 parent 23bfe47 commit cca7510
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 63 deletions.
110 changes: 54 additions & 56 deletions pkg/http/websocket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"context"
"sync"
"time"

"github.com/gorilla/websocket"
Expand All @@ -11,69 +12,66 @@ import (
// ConnectWebSocket establishes a new WebSocket connection
func ConnectWebSocket(
url string,
messageChan chan []byte,
ctx context.Context,
) *websocket.Conn {
closed := false

var conn *websocket.Conn

// if we ever get a cancellation from the context, try to close the connection
go func() {
<-ctx.Done()
closed = true
if conn != nil {
conn.Close()
) chan []byte {
connectFactory := func() *websocket.Conn {
for {
log.Debug().Msgf("WebSocket connection connecting: %s", url)
conn, _, err := websocket.DefaultDialer.Dial(url, nil)
if err != nil {
log.Error().Msgf("WebSocket connection failed: %s\nReconnecting in 2 seconds...", err)
time.Sleep(2 * time.Second)
continue
}
conn.SetPongHandler(nil)
return conn
}
}()
}

pingInterval := time.NewTicker(time.Second * 5)
connLk := &sync.Mutex{}
responseCh := make(chan []byte)
errCh := make(chan error)

// retry connecting until we get a connection
for {
var err error
log.Debug().Msgf("WebSocket connection connecting: %s", url)
conn, _, err = websocket.DefaultDialer.Dial(url, nil)
if err != nil {
log.Error().Msgf("WebSocket connection failed: %s\nReconnecting in 2 seconds...", err)
if closed {
break
readMessage := func(conn *websocket.Conn) {
for {
messageType, p, err := conn.ReadMessage()
if err != nil {
errCh <- err
return
}
if messageType == websocket.TextMessage {
log.Debug().
Str("action", "ws READ").
Str("payload", string(p)).
Msgf("")
responseCh <- p
}
time.Sleep(2 * time.Second)
continue
}
break
}

// now that we have a connection, if we haven't been closed yet, forever
// read from the connection and send messages down the channel, unless we
// fail a read in which case we try to reconnect
if !closed {
go func() {
for {
messageType, p, err := conn.ReadMessage()
if err != nil {
if closed {
return
}
log.Error().Msgf("Read error: %s\nReconnecting in 2 seconds...", err)
time.Sleep(2 * time.Second)
conn = ConnectWebSocket(url, messageChan, ctx)
// exit this goroutine now, another one will be spawned if
// the recursive call to ConnectWebSocket succeeds. Not
// exiting this goroutine here will cause goroutines to pile
// up forever concurrently calling conn.ReadMessage(), which
// is not thread-safe.
return
}
if messageType == websocket.TextMessage {
log.Debug().
Str("action", "ws READ").
Str("payload", string(p)).
Msgf("")
messageChan <- p
conn := connectFactory()
go readMessage(conn)
go func() {
for {
select {
case <-pingInterval.C:
connLk.Lock()
log.Trace().Msg("send ping message")
if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
log.Err(err).Msg("sending ping message")
connLk.Unlock()
continue
}
connLk.Unlock()
case err := <-errCh:
log.Err(err).Msg("websocket error")
connLk.Lock()
conn = connectFactory()
connLk.Unlock()
go readMessage(conn)
}
}()
}

return conn
}
}()
return responseCh
}
2 changes: 2 additions & 0 deletions pkg/http/websocket_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ func StartWebSocketServer(
log.Error().Msgf("Error upgrading websocket: %s", err.Error())
return
}

conn.SetPingHandler(nil)
params := r.URL.Query()
connParams := WSConnectionParams{
ID: params.Get("ID"),
Expand Down
11 changes: 4 additions & 7 deletions pkg/solver/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ func NewSolverClient(

// connect the websocket to the solver server
func (client *SolverClient) Start(ctx context.Context, cm *system.CleanupManager) error {
websocketEventChannel := make(chan []byte)

websocketURL := fmt.Sprintf("%s%s%s%s%s", http.WEBSOCKET_SUB_PATH, "?&Type=", client.options.Type, "&ID=", client.options.PublicAddress)
websocketEventChannel := http.ConnectWebSocket(http.WebsocketURL(client.options, websocketURL), ctx)
go func() {
for {
select {
Expand All @@ -49,12 +51,7 @@ func (client *SolverClient) Start(ctx context.Context, cm *system.CleanupManager
}
}
}()
websocketURL := fmt.Sprintf("%s%s%s%s%s", http.WEBSOCKET_SUB_PATH, "?&Type=", client.options.Type, "&ID=", client.options.PublicAddress)
http.ConnectWebSocket(
http.WebsocketURL(client.options, websocketURL),
websocketEventChannel,
ctx,
)

return nil
}

Expand Down

0 comments on commit cca7510

Please sign in to comment.