From 091b60a8fe803bb8f5930cf9d7e3de6a7aaf5a76 Mon Sep 17 00:00:00 2001 From: Kautilya Tripathi Date: Thu, 14 Nov 2024 10:13:52 +0530 Subject: [PATCH] backend: Fix panic of websockets Now each websocket message has a clear type that it needs to send. This also fixes websocket panics in various edge cases. Signed-off-by: Kautilya Tripathi --- backend/cmd/multiplexer.go | 278 ++++++++++++++++++++------------ backend/cmd/multiplexer_test.go | 9 +- 2 files changed, 181 insertions(+), 106 deletions(-) diff --git a/backend/cmd/multiplexer.go b/backend/cmd/multiplexer.go index 54c6fe51e9..2383b2f61d 100644 --- a/backend/cmd/multiplexer.go +++ b/backend/cmd/multiplexer.go @@ -68,6 +68,10 @@ type Connection struct { Done chan struct{} // mu is a mutex to synchronize access to the connection. mu sync.RWMutex + // writeMu is a mutex to synchronize access to the write operations. + writeMu sync.Mutex + // closed is a flag to indicate if the connection is closed. + closed bool } // Message represents a WebSocket message structure. @@ -81,7 +85,9 @@ type Message struct { // UserID is the ID of the user. UserID string `json:"userId"` // Data contains the message payload. - Data []byte `json:"data,omitempty"` + Data string `json:"data,omitempty"` + // Binary is a flag to indicate if the message is binary. + Binary bool `json:"binary,omitempty"` // Type is the type of the message. 
Type string `json:"type"` } @@ -116,41 +122,58 @@ func (c *Connection) updateStatus(state ConnectionState, err error) { c.mu.Lock() defer c.mu.Unlock() + if c.closed { + return + } + c.Status.State = state c.Status.LastMsg = time.Now() + c.Status.Error = "" if err != nil { c.Status.Error = err.Error() - } else { - c.Status.Error = "" } - if c.Client != nil { - statusData := struct { - State string `json:"state"` - Error string `json:"error"` - }{ - State: string(state), - Error: c.Status.Error, - } + if c.Client == nil { + return + } - jsonData, jsonErr := json.Marshal(statusData) - if jsonErr != nil { - logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message") + c.writeMu.Lock() + defer c.writeMu.Unlock() - return - } + // Check if connection is closed before writing + if c.closed { + return + } - statusMsg := Message{ - ClusterID: c.ClusterID, - Path: c.Path, - Data: jsonData, - } + statusData := struct { + State string `json:"state"` + Error string `json:"error"` + }{ + State: string(state), + Error: c.Status.Error, + } - err := c.Client.WriteJSON(statusMsg) - if err != nil { + jsonData, jsonErr := json.Marshal(statusData) + if jsonErr != nil { + logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message") + + return + } + + statusMsg := Message{ + ClusterID: c.ClusterID, + Path: c.Path, + Data: string(jsonData), + Type: "STATUS", + } + + if err := c.Client.WriteJSON(statusMsg); err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, err, "writing status message to client") } + + c.closed = true } } @@ -190,7 +213,8 @@ func (m *Multiplexer) establishClusterConnection( connection.updateStatus(StateConnected, nil) m.mutex.Lock() - m.connections[clusterID+path] = connection + connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID) + 
m.connections[connKey] = connection m.mutex.Unlock() go m.monitorConnection(connection) @@ -334,7 +358,7 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque } // Check if it's a close message - if msg.Data != nil && len(msg.Data) > 0 && string(msg.Data) == "close" { + if msg.Type == "CLOSE" { err := m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID) if err != nil { logger.Log( @@ -355,8 +379,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque continue } - if len(msg.Data) > 0 && conn.Status.State == StateConnected { - err = m.writeMessageToCluster(conn, msg.Data) + if msg.Type == "REQUEST" && conn.Status.State == StateConnected { + err = m.writeMessageToCluster(conn, []byte(msg.Data)) if err != nil { continue } @@ -458,100 +482,149 @@ func (m *Multiplexer) writeMessageToCluster(conn *Connection, data []byte) error // handleClusterMessages handles messages from a cluster connection. func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websocket.Conn) { - defer func() { - conn.updateStatus(StateClosed, nil) - conn.WSConn.Close() - }() + defer m.cleanupConnection(conn) + + var lastResourceVersion string for { select { case <-conn.Done: return default: - if err := m.processClusterMessage(conn, clientConn); err != nil { + if err := m.processClusterMessage(conn, clientConn, &lastResourceVersion); err != nil { return } } } } -// processClusterMessage processes a message from a cluster connection. -func (m *Multiplexer) processClusterMessage(conn *Connection, clientConn *websocket.Conn) error { +// processClusterMessage processes a single message from the cluster. 
+func (m *Multiplexer) processClusterMessage( + conn *Connection, + clientConn *websocket.Conn, + lastResourceVersion *string, +) error { messageType, message, err := conn.WSConn.ReadMessage() if err != nil { - m.handleReadError(conn, err) + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + logger.Log(logger.LevelError, + map[string]string{ + "clusterID": conn.ClusterID, + "userID": conn.UserID, + }, + err, + "reading cluster message", + ) + } return err } - wrapperMsg := m.createWrapperMessage(conn, messageType, message) + if err := m.checkResourceVersion(message, conn, clientConn, lastResourceVersion); err != nil { + return err + } - if err := clientConn.WriteJSON(wrapperMsg); err != nil { - m.handleWriteError(conn, err) + return m.sendDataMessage(conn, clientConn, messageType, message) +} - return err +// checkResourceVersion checks and handles resource version changes. +func (m *Multiplexer) checkResourceVersion( + message []byte, + conn *Connection, + clientConn *websocket.Conn, + lastResourceVersion *string, +) error { + var obj map[string]interface{} + if err := json.Unmarshal(message, &obj); err != nil { + return nil // Ignore unmarshalling errors for resource version check } - conn.mu.Lock() - conn.Status.LastMsg = time.Now() - conn.mu.Unlock() + if metadata, ok := obj["metadata"].(map[string]interface{}); ok { + if rv, ok := metadata["resourceVersion"].(string); ok { + if *lastResourceVersion != "" && rv != *lastResourceVersion { + return m.sendCompleteMessage(conn, clientConn) + } + + *lastResourceVersion = rv + } + } return nil } -// createWrapperMessage creates a wrapper message for a cluster connection. 
-func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) struct { - ClusterID string `json:"clusterId"` - Path string `json:"path"` - Query string `json:"query"` - UserID string `json:"userId"` - Data string `json:"data"` - Binary bool `json:"binary"` -} { - wrapperMsg := struct { - ClusterID string `json:"clusterId"` - Path string `json:"path"` - Query string `json:"query"` - UserID string `json:"userId"` - Data string `json:"data"` - Binary bool `json:"binary"` - }{ +// sendCompleteMessage sends a COMPLETE message to the client. +func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *websocket.Conn) error { + completeMsg := Message{ ClusterID: conn.ClusterID, Path: conn.Path, Query: conn.Query, UserID: conn.UserID, - Binary: messageType == websocket.BinaryMessage, + Type: "COMPLETE", } - if messageType == websocket.BinaryMessage { - wrapperMsg.Data = base64.StdEncoding.EncodeToString(message) - } else { - wrapperMsg.Data = string(message) + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + return clientConn.WriteJSON(completeMsg) +} + +// sendDataMessage sends the actual data message to the client. +func (m *Multiplexer) sendDataMessage( + conn *Connection, + clientConn *websocket.Conn, + messageType int, + message []byte, +) error { + dataMsg := m.createWrapperMessage(conn, messageType, message) + + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + if err := clientConn.WriteJSON(dataMsg); err != nil { + return err } - return wrapperMsg + conn.mu.Lock() + conn.Status.LastMsg = time.Now() + conn.mu.Unlock() + + return nil } -// handleReadError handles errors that occur when reading a message from a cluster connection. 
-func (m *Multiplexer) handleReadError(conn *Connection, err error) { - conn.updateStatus(StateError, err) - logger.Log( - logger.LevelError, - map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, - err, - "reading message from cluster", - ) +// cleanupConnection performs cleanup for a connection. +func (m *Multiplexer) cleanupConnection(conn *Connection) { + conn.mu.Lock() + conn.closed = true + conn.mu.Unlock() + + if conn.WSConn != nil { + conn.WSConn.Close() + } + + m.mutex.Lock() + connKey := fmt.Sprintf("%s:%s:%s", conn.ClusterID, conn.Path, conn.UserID) + delete(m.connections, connKey) + m.mutex.Unlock() } -// handleWriteError handles errors that occur when writing a message to a client connection. -func (m *Multiplexer) handleWriteError(conn *Connection, err error) { - conn.updateStatus(StateError, err) - logger.Log( - logger.LevelError, - map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, - err, - "writing message to client", - ) +// createWrapperMessage creates a wrapper message for a cluster connection. +func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) Message { + var data string + if messageType == websocket.BinaryMessage { + data = base64.StdEncoding.EncodeToString(message) + } else { + data = string(message) + } + + return Message{ + ClusterID: conn.ClusterID, + Path: conn.Path, + Query: conn.Query, + UserID: conn.UserID, + Data: data, + Binary: messageType == websocket.BinaryMessage, + Type: "DATA", + } } // cleanupConnections closes and removes all connections. @@ -587,37 +660,42 @@ func (m *Multiplexer) getClusterConfig(clusterID string) (*rest.Config, error) { } // CloseConnection closes a specific connection based on its identifier. 
+// +//nolint:unparam func (m *Multiplexer) CloseConnection(clusterID, path, userID string) error { connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID) m.mutex.Lock() - defer m.mutex.Unlock() conn, exists := m.connections[connKey] if !exists { - return fmt.Errorf("connection not found for key: %s", connKey) + m.mutex.Unlock() + // Don't log error for non-existent connections during cleanup + return nil } - // Signal the connection to close - close(conn.Done) + // Mark as closed before releasing the lock + conn.mu.Lock() + if conn.closed { + conn.mu.Unlock() + m.mutex.Unlock() + logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, nil, "closing connection") - // Close the WebSocket connection - if conn.WSConn != nil { - if err := conn.WSConn.Close(); err != nil { - logger.Log( - logger.LevelError, - map[string]string{"clusterID": clusterID, "userID": userID}, - err, - "closing WebSocket connection", - ) - } + return nil } - // Update the connection status - conn.updateStatus(StateClosed, nil) + conn.closed = true + conn.mu.Unlock() - // Remove the connection from the map delete(m.connections, connKey) + m.mutex.Unlock() + + // Close the Done channel and connections after removing from map + close(conn.Done) + + if conn.WSConn != nil { + conn.WSConn.Close() + } return nil } diff --git a/backend/cmd/multiplexer_test.go b/backend/cmd/multiplexer_test.go index 058e01377b..1772091e8a 100644 --- a/backend/cmd/multiplexer_test.go +++ b/backend/cmd/multiplexer_test.go @@ -225,7 +225,7 @@ func TestHandleClusterMessages(t *testing.T) { t.Fatal("Test timed out") } - assert.Equal(t, StateClosed, conn.Status.State) + assert.Equal(t, StateConnecting, conn.Status.State) } func TestCleanupConnections(t *testing.T) { @@ -295,11 +295,8 @@ func TestCloseConnection(t *testing.T) { err := m.CloseConnection("test-cluster", "/api/v1/pods", "test-user") assert.NoError(t, err) assert.Empty(t, m.connections) - assert.Equal(t, StateClosed, conn.Status.State) - 
- // Test closing a non-existent connection - err = m.CloseConnection("non-existent", "/api/v1/pods", "test-user") - assert.Error(t, err) + // It will reconnect to the cluster + assert.Equal(t, StateConnecting, conn.Status.State) } func TestCreateWrapperMessage(t *testing.T) {