Skip to content

Commit

Permalink
Handle different data flows in controller
Browse files Browse the repository at this point in the history
  • Loading branch information
illia-malachyn committed Dec 11, 2024
1 parent 48c8e4e commit 665cdb0
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 85 deletions.
193 changes: 141 additions & 52 deletions engine/access/rest/websockets/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package websockets
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"time"

"github.com/google/uuid"
Expand All @@ -27,6 +29,7 @@ type Controller struct {

dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider]
dataProviderFactory dp.DataProviderFactory
dataProvidersGroup *sync.WaitGroup
}

func NewWebSocketController(
Expand All @@ -42,6 +45,7 @@ func NewWebSocketController(
multiplexedStream: make(chan interface{}),
dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](),
dataProviderFactory: dataProviderFactory,
dataProvidersGroup: &sync.WaitGroup{},
}
}

Expand Down Expand Up @@ -72,6 +76,10 @@ func (c *Controller) HandleConnection(ctx context.Context) {
})

if err = g.Wait(); err != nil {
if errors.Is(err, websocket.ErrCloseSent) {
return
}

c.logger.Error().Err(err).Msg("error detected in one of the goroutines")
}
}
Expand Down Expand Up @@ -116,21 +124,17 @@ func (c *Controller) writeMessages(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case msg, ok := <-c.multiplexedStream:
case message, ok := <-c.multiplexedStream:
if !ok {
return nil
return fmt.Errorf("multiplexed stream closed")
}

if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil {
return fmt.Errorf("failed to set the write deadline: %w", err)
}

err := c.conn.WriteJSON(msg)
if err != nil {
if IsCloseError(err) {
return nil
}
c.logger.Error().Err(err).Msg("failed to write msg to connection")
if err := c.conn.WriteJSON(message); err != nil {
return err
}
}
}
Expand All @@ -143,37 +147,32 @@ func (c *Controller) writeMessages(ctx context.Context) error {
// - context.Canceled if the client disconnected
func (c *Controller) readMessages(ctx context.Context) error {
for {
msg, err := c.readMessage()
if err != nil {
if IsCloseError(err) {
return nil
var message json.RawMessage
if err := c.conn.ReadJSON(&message); err != nil {
if errors.Is(err, websocket.ErrCloseSent) {
return err
}

c.writeBaseErrorResponse(ctx, err, "")
c.logger.Error().Err(err).Msg("error reading message")
continue
}

validatedMsg, err := c.parseAndValidateMessage(msg)
validatedMsg, err := c.parseAndValidateMessage(message)
if err != nil {
c.writeBaseErrorResponse(ctx, err, "")
c.logger.Error().Err(err).Msg("failed to parse message")
continue
}

if err := c.handleAction(ctx, validatedMsg); err != nil {
if err = c.handleAction(ctx, validatedMsg); err != nil {
c.writeBaseErrorResponse(ctx, err, "")
c.logger.Error().Err(err).Msg("failed to handle action")
continue
}
}
}

func (c *Controller) readMessage() (json.RawMessage, error) {
var message json.RawMessage
if err := c.conn.ReadJSON(&message); err != nil {
return nil, fmt.Errorf("error reading JSON from client: %w", err)
}
return message, nil
}

func (c *Controller) parseAndValidateMessage(message json.RawMessage) (interface{}, error) {
var baseMsg models.BaseMessageRequest
if err := json.Unmarshal(message, &baseMsg); err != nil {
Expand All @@ -187,7 +186,6 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (interface
if err := json.Unmarshal(message, &subscribeMsg); err != nil {
return nil, fmt.Errorf("error unmarshalling subscribe message: %w", err)
}
//TODO: add validation logic for `topic` field
validatedMsg = subscribeMsg

case "unsubscribe":
Expand Down Expand Up @@ -219,70 +217,120 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro
case models.UnsubscribeMessageRequest:
c.handleUnsubscribe(ctx, msg)
case models.ListSubscriptionsMessageRequest:
c.handleListSubscriptions(ctx, msg)
c.handleListSubscriptions(ctx)
default:
return fmt.Errorf("unknown message type: %T", msg)
}
return nil
}

func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) {
dp, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream)
// register new provider
provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream)
if err != nil {
// TODO: handle error here
c.logger.Error().Err(err).Msgf("error while creating data provider for topic: %s", msg.Topic)
c.writeBaseErrorResponse(ctx, err, "subscribe")
c.logger.Error().Err(err).Msg("error creating data provider")
return
}

c.dataProviders.Add(dp.ID(), dp)

//TODO: return OK response to client
c.multiplexedStream <- msg
c.dataProviders.Add(provider.ID(), provider)
c.writeSubscribeOkResponse(ctx, provider.ID())

// run provider
c.dataProvidersGroup.Add(1)
go func() {
err := dp.Run()
err = provider.Run()
if err != nil {
//TODO: Log or handle the error from Run
c.writeBaseErrorResponse(ctx, err, "")
c.logger.Error().Err(err).Msgf("error while running data provider for topic: %s", msg.Topic)
}

c.dataProvidersGroup.Done()
c.dataProviders.Remove(provider.ID())
}()
}

func (c *Controller) handleUnsubscribe(_ context.Context, msg models.UnsubscribeMessageRequest) {
func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) {
id, err := uuid.Parse(msg.ID)
if err != nil {
c.writeBaseErrorResponse(ctx, err, "unsubscribe")
c.logger.Debug().Err(err).Msg("error parsing message ID")
//TODO: return an error response to client
c.multiplexedStream <- err
return
}

dp, ok := c.dataProviders.Get(id)
if ok {
dp.Close()
c.dataProviders.Remove(id)
provider, ok := c.dataProviders.Get(id)
if !ok {
c.writeBaseErrorResponse(ctx, err, "unsubscribe")
c.logger.Debug().Err(err).Msg("no active subscription with such ID found")
return
}

err = provider.Close()
if err != nil {
c.writeBaseErrorResponse(ctx, err, "unsubscribe")
return
}

c.dataProviders.Remove(id)
c.writeUnsubscribeOkResponse(ctx, id)
}

func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) {
//TODO: return a response to client
func (c *Controller) handleListSubscriptions(ctx context.Context) {
var subs []*models.SubscriptionEntry

err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error {
subs = append(subs, &models.SubscriptionEntry{
ID: id.String(),
Topic: provider.Topic(),
})
return nil
})

if err != nil {
c.writeBaseErrorResponse(ctx, err, "list_subscriptions")
c.logger.Debug().Err(err).Msg("error listing subscriptions")
return
}

resp := models.ListSubscriptionsMessageResponse{
BaseMessageResponse: models.BaseMessageResponse{
Success: true,
},
Subscriptions: subs,
}
c.writeResponse(ctx, resp)
}

func (c *Controller) shutdownConnection() {
defer func() {
if err := c.conn.Close(); err != nil {
c.logger.Error().Err(err).Msg("error closing connection")
err := c.conn.Close()
if err != nil {
c.logger.Error().Err(err).Msg("error closing connection")
}

err = c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error {
//TODO: why did i think it's a good idea to return error in Close()? it's messy now
err = dp.Close()
if err != nil {
c.logger.Error().Err(err).Msg("error closing data provider")
}
}()

err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error {
dp.Close()
return nil
})

if err != nil {
c.logger.Error().Err(err).Msg("error closing data provider")
}

c.dataProviders.Clear()

// drain the channel as some providers may still send data to it during shutdown
go func() {
for range c.multiplexedStream {
}
}()

c.dataProvidersGroup.Wait()
close(c.multiplexedStream)
}

// keepalive sends a ping message periodically to keep the WebSocket connection alive
Expand All @@ -301,15 +349,56 @@ func (c *Controller) keepalive(ctx context.Context) error {
case <-pingTicker.C:
err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait))
if err != nil {
// Log error and exit the loop on failure
c.logger.Debug().Err(err).Msg("failed to send ping")
if errors.Is(err, websocket.ErrCloseSent) {
return err
}

c.writeBaseErrorResponse(ctx, err, "")
c.logger.Debug().Err(err).Msg("failed to send ping")
return fmt.Errorf("failed to write ping message: %w", err)
}
}
}
}

func IsCloseError(err error) bool {
return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure, websocket.CloseGoingAway)
func (c *Controller) writeBaseErrorResponse(ctx context.Context, err error, action string) {
request := models.BaseMessageResponse{
Action: action,
Success: false,
ErrorMessage: err.Error(),
}

c.writeResponse(ctx, request)
}

func (c *Controller) writeSubscribeOkResponse(ctx context.Context, id uuid.UUID) {
request := models.SubscribeMessageResponse{
BaseMessageResponse: models.BaseMessageResponse{
Action: "subscribe",
Success: true,
},
ID: id.String(),
}

c.writeResponse(ctx, request)
}

func (c *Controller) writeUnsubscribeOkResponse(ctx context.Context, id uuid.UUID) {
request := models.UnsubscribeMessageResponse{
BaseMessageResponse: models.BaseMessageResponse{
Action: "unsubscribe",
Success: true,
},
ID: id.String(),
}

c.writeResponse(ctx, request)
}

func (c *Controller) writeResponse(ctx context.Context, response interface{}) {
select {
case <-ctx.Done():
return
case c.multiplexedStream <- response:
}
}
Loading

0 comments on commit 665cdb0

Please sign in to comment.