diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 07f76c93fba..fcbff1b6299 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -21,7 +21,9 @@ type Controller struct { config Config conn WebsocketConnection - communicationChannel chan interface{} // Channel for sending messages to the client. + // data channel which data providers write messages to. + // writer routine reads from this channel and writes messages to connection + multiplexedStream chan interface{} dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] dataProviderFactory dp.DataProviderFactory @@ -34,12 +36,12 @@ func NewWebSocketController( dataProviderFactory dp.DataProviderFactory, ) *Controller { return &Controller{ - logger: logger.With().Str("component", "websocket-controller").Logger(), - config: config, - conn: conn, - communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? - dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), - dataProviderFactory: dataProviderFactory, + logger: logger.With().Str("component", "websocket-controller").Logger(), + config: config, + conn: conn, + multiplexedStream: make(chan interface{}), + dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), + dataProviderFactory: dataProviderFactory, } } @@ -49,37 +51,28 @@ func NewWebSocketController( // Parameters: // - ctx: The context for controlling cancellation and timeouts. func (c *Controller) HandleConnection(ctx context.Context) { + defer c.shutdownConnection() - // configuring the connection with appropriate read/write deadlines and handlers. err := c.configureKeepalive() if err != nil { - // TODO: add error handling here c.logger.Error().Err(err).Msg("error configuring connection") - c.shutdownConnection() return } - //TODO: spin up a response limit tracker routine - - // for track all goroutines and error handling g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { - return c.readMessagesFromClient(gCtx) + return c.readMessages(gCtx) }) - g.Go(func() error { return c.keepalive(gCtx) }) - g.Go(func() error { - return c.writeMessagesToClient(gCtx) + return c.writeMessages(gCtx) }) if err = g.Wait(); err != nil { - //TODO: add error handling here c.logger.Error().Err(err).Msg("error detected in one of the goroutines") - c.shutdownConnection() } } @@ -103,6 +96,7 @@ func (c *Controller) configureKeepalive() error { if err := c.conn.SetReadDeadline(time.Now().Add(PongWait)); err != nil { return fmt.Errorf("failed to set the initial read deadline: %w", err) } + // Establish a Pong handler which sets the handler for pong messages received from the peer. c.conn.SetPongHandler(func(string) error { return c.conn.SetReadDeadline(time.Now().Add(PongWait)) @@ -111,73 +105,63 @@ func (c *Controller) configureKeepalive() error { return nil } -// writeMessagesToClient reads a messages from communication channel and passes them on to a client WebSocket connection. +// writeMessages reads a messages from communication channel and passes them on to a client WebSocket connection. // The communication channel is filled by data providers. Besides, the response limit tracker is involved in // write message regulation // // Expected errors during normal operation: // - context.Canceled if the client disconnected -func (c *Controller) writeMessagesToClient(ctx context.Context) error { +func (c *Controller) writeMessages(ctx context.Context) error { for { select { case <-ctx.Done(): return ctx.Err() - case msg, ok := <-c.communicationChannel: + case msg, ok := <-c.multiplexedStream: if !ok { - err := fmt.Errorf("communication channel closed, no error occurred") - return err + return nil } - // TODO: handle 'response per second' limits - // Specifies a timeout for the write operation. If the write - // isn't completed within this duration, it fails with a timeout error. - // SetWriteDeadline ensures the write operation does not block indefinitely - // if the client is slow or unresponsive. This prevents resource exhaustion - // and allows the server to gracefully handle timeouts for delayed writes. if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { - c.logger.Error().Err(err).Msg("failed to set the write deadline") - return err + return fmt.Errorf("failed to set the write deadline: %w", err) } + err := c.conn.WriteJSON(msg) if err != nil { - c.logger.Error().Err(err).Msg("error writing to connection") - return err + if IsCloseError(err) { + return nil + } + c.logger.Error().Err(err).Msg("failed to write msg to connection") } } } } -// readMessagesFromClient continuously reads messages from a client WebSocket connection, +// readMessages continuously reads messages from a client WebSocket connection, // processes each message, and handles actions based on the message type. // // Expected errors during normal operation: // - context.Canceled if the client disconnected -func (c *Controller) readMessagesFromClient(ctx context.Context) error { +func (c *Controller) readMessages(ctx context.Context) error { for { - select { - case <-ctx.Done(): - c.logger.Info().Msg("context canceled, stopping read message loop") - return ctx.Err() - default: - msg, err := c.readMessage() - if err != nil { - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) { - return nil - } - c.logger.Warn().Err(err).Msg("error reading message from client") - return fmt.Errorf("failed to read message from client: %w", err) + msg, err := c.readMessage() + if err != nil { + if IsCloseError(err) { + return nil } - baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg) - if err != nil { - c.logger.Debug().Err(err).Msg("error parsing and validating client message") - return fmt.Errorf("failed to parse and validate client message: %w", err) - } + c.logger.Error().Err(err).Msg("error reading message") + continue + } - if err := c.handleAction(ctx, validatedMsg); err != nil { - c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action") - return fmt.Errorf("failed to handle message action: %w", err) - } + validatedMsg, err := c.parseAndValidateMessage(msg) + if err != nil { + c.logger.Error().Err(err).Msg("failed to parse message") + continue + } + + if err := c.handleAction(ctx, validatedMsg); err != nil { + c.logger.Error().Err(err).Msg("failed to handle action") + continue } } } @@ -190,10 +174,10 @@ func (c *Controller) readMessage() (json.RawMessage, error) { return message, nil } -func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.BaseMessageRequest, interface{}, error) { +func (c *Controller) parseAndValidateMessage(message json.RawMessage) (interface{}, error) { var baseMsg models.BaseMessageRequest if err := json.Unmarshal(message, &baseMsg); err != nil { - return models.BaseMessageRequest{}, nil, fmt.Errorf("error unmarshalling base message: %w", err) + return nil, fmt.Errorf("error unmarshalling base message: %w", err) } var validatedMsg interface{} @@ -201,7 +185,7 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.Ba case "subscribe": var subscribeMsg models.SubscribeMessageRequest if err := json.Unmarshal(message, &subscribeMsg); err != nil { - return baseMsg, nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) + return nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) } //TODO: add validation logic for `topic` field validatedMsg = subscribeMsg @@ -209,23 +193,23 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.Ba case "unsubscribe": var unsubscribeMsg models.UnsubscribeMessageRequest if err := json.Unmarshal(message, &unsubscribeMsg); err != nil { - return baseMsg, nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err) + return nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err) } validatedMsg = unsubscribeMsg case "list_subscriptions": var listMsg models.ListSubscriptionsMessageRequest if err := json.Unmarshal(message, &listMsg); err != nil { - return baseMsg, nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err) + return nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err) } validatedMsg = listMsg default: c.logger.Debug().Str("action", baseMsg.Action).Msg("unknown action type") - return baseMsg, nil, fmt.Errorf("unknown action type: %s", baseMsg.Action) + return nil, fmt.Errorf("unknown action type: %s", baseMsg.Action) } - return baseMsg, validatedMsg, nil + return validatedMsg, nil } func (c *Controller) handleAction(ctx context.Context, message interface{}) error { @@ -243,7 +227,7 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro } func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { - dp, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.communicationChannel) + dp, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { // TODO: handle error here c.logger.Error().Err(err).Msgf("error while creating data provider for topic: %s", msg.Topic) @@ -252,7 +236,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.dataProviders.Add(dp.ID(), dp) //TODO: return OK response to client - c.communicationChannel <- msg + c.multiplexedStream <- msg go func() { err := dp.Run() @@ -268,7 +252,7 @@ func (c *Controller) handleUnsubscribe(_ context.Context, msg models.Unsubscribe if err != nil { c.logger.Debug().Err(err).Msg("error parsing message ID") //TODO: return an error response to client - c.communicationChannel <- err + c.multiplexedStream <- err return } @@ -288,7 +272,6 @@ func (c *Controller) shutdownConnection() { if err := c.conn.Close(); err != nil { c.logger.Error().Err(err).Msg("error closing connection") } - // TODO: safe closing communicationChannel will be included as a part of PR #6642 }() err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { @@ -326,3 +309,7 @@ func (c *Controller) keepalive(ctx context.Context) error { } } } + +func IsCloseError(err error) bool { + return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure, websocket.CloseGoingAway) +}