diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 249ccc880cd..fb88106d86b 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -18,6 +18,10 @@ import ( "github.com/onflow/flow-go/utils/concurrentmap" ) +var ( + ErrUnmarshalMessage = errors.New("failed to unmarshal message") +) + type Controller struct { logger zerolog.Logger config Config @@ -260,7 +264,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri provider, ok := c.dataProviders.Get(id) if !ok { - c.writeBaseErrorResponse(ctx, err, "unsubscribe") + c.writeBaseErrorResponse(ctx, fmt.Errorf("could not find data provider with such id"), "unsubscribe") c.logger.Debug().Err(err).Msg("no active subscription with such ID found") return } @@ -295,6 +299,7 @@ func (c *Controller) handleListSubscriptions(ctx context.Context) { resp := models.ListSubscriptionsMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ Success: true, + Action: "list_subscriptions", }, Subscriptions: subs, } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 0b54b42c36c..002d7e28fb8 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -3,6 +3,7 @@ package websockets import ( "context" "encoding/json" + "fmt" "testing" "time" @@ -49,18 +50,25 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) dataProvider. On("Run"). Run(func(args mock.Arguments) {}). Return(nil). Once() - subscribeRequest := models.SubscribeMessageRequest{ + request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, Topic: "blocks", Arguments: nil, } - subscribeRequestJson, err := json.Marshal(subscribeRequest) + requestJson, err := json.Marshal(request) require.NoError(t, err) // Simulate receiving the subscription request from the client @@ -69,38 +77,391 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Run(func(args mock.Arguments) { msg, ok := args.Get(0).(*json.RawMessage) require.True(t, ok) - *msg = subscribeRequestJson + *msg = requestJson }). Return(nil). Once() - // Channel to signal the test flow completion done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) - // Simulate writing a successful subscription response back to the client conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { response, ok := msg.(models.SubscribeMessageResponse) require.True(t, ok) + require.Equal(t, request.Action, response.Action) require.True(t, response.Success) + require.Equal(t, id.String(), response.ID) + close(done) // Signal that response has been sent return websocket.ErrCloseSent }) - // Simulate client closing connection after receiving the response + controller.HandleConnection(context.Background()) + }) + + s.T().Run("Parse and validate error", func(t *testing.T) { + conn, dataProviderFactory, _ := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + type Request struct { + Action string `json:"action"` + } + + subscribeRequest := Request{ + Action: "SubscribeBlocks", + } + subscribeRequestJson, err := json.Marshal(subscribeRequest) + require.NoError(t, err) + + // Simulate receiving the subscription request from the client conn. On("ReadJSON", mock.Anything). - Return(func(interface{}) error { - <-done + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = subscribeRequestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Empty(t, response.Action) + require.False(t, response.Success) + require.NotNil(t, response.ErrorMessage) //TODO: add kinds of errors for different data flows + + s.T().Log(response.ErrorMessage) + + close(done) // Signal that response has been sent + return websocket.ErrCloseSent + }) + + controller.HandleConnection(context.Background()) + }) + + s.T().Run("Error creating data provider", func(t *testing.T) { + conn, dataProviderFactory, _ := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, fmt.Errorf("error creating data provider")). + Once() + + done := make(chan struct{}, 1) + s.expectSubscribeRequest(conn) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Equal(t, "subscribe", response.Action) + require.False(t, response.Success) + require.Equal(t, response.ErrorMessage, "error creating data provider") + + s.T().Log(response.ErrorMessage) + + close(done) // Signal that response has been sent return websocket.ErrCloseSent }) controller.HandleConnection(context.Background()) }) - s.T().Run("Parse request message error", func(t *testing.T) { - + s.T().Run("Run error", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProvider. + On("ID"). + Return(uuid.New()) + + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(fmt.Errorf("error running data provider")). + Once() + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + done := make(chan struct{}, 1) + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Equal(t, "", response.Action) + require.False(t, response.Success) + require.NotNil(t, response.ErrorMessage) //TODO: add kinds of errors for different data flows + + s.T().Log(response.ErrorMessage) + + close(done) // Signal that response has been sent + return websocket.ErrCloseSent + }) + + controller.HandleConnection(context.Background()) + }) +} + +func (s *WsControllerSuite) TestUnsubscribeRequest() { + s.T().Run("Happy path", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + dataProvider.On("Close").Return(nil) + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(nil). + Once() + + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + + request := models.UnsubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, + ID: id.String(), + } + requestJson, err := json.Marshal(request) + require.NoError(s.T(), err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.UnsubscribeMessageResponse) + require.True(t, ok) + require.Equal(t, request.Action, response.Action) + require.True(t, response.Success) + require.Empty(t, response.ErrorMessage) + require.Equal(t, request.ID, response.ID) + + close(done) + return websocket.ErrCloseSent + }). + Once() + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) + }) + + s.T().Run("Invalid subscription uuid", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(nil). + Once() + + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + + request := models.UnsubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, + ID: "invalid-uuid", + } + requestJson, err := json.Marshal(request) + require.NoError(s.T(), err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Equal(t, request.Action, response.Action) + require.False(t, response.Success) + require.NotEmpty(t, response.ErrorMessage) + + s.T().Log(response.ErrorMessage) + + close(done) + return websocket.ErrCloseSent + }). + Once() + + controller.HandleConnection(context.Background()) + }) + + s.T().Run("Unsubscribe from unknown subscription", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(nil). + Once() + + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + + request := models.UnsubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, + ID: uuid.New().String(), + } + requestJson, err := json.Marshal(request) + require.NoError(s.T(), err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Equal(t, request.Action, response.Action) + require.False(t, response.Success) + require.NotEmpty(t, response.ErrorMessage) + + s.T().Log(response.ErrorMessage) + + close(done) + return websocket.ErrCloseSent + }). + Once() + + controller.HandleConnection(context.Background()) + }) +} + +func (s *WsControllerSuite) TestListSubscriptions() { + s.T().Run("Happy path", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + topic := "blocks" + dataProvider.On("ID").Return(id) + dataProvider.On("Topic").Return(topic) + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(nil). + Once() + + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + + request := models.ListSubscriptionsMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "list_subscriptions"}, + } + requestJson, err := json.Marshal(request) + require.NoError(s.T(), err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.ListSubscriptionsMessageResponse) + require.True(t, ok) + require.Equal(t, 1, len(response.Subscriptions)) + require.Equal(t, id.String(), response.Subscriptions[0].ID) + require.Equal(t, topic, response.Subscriptions[0].Topic) + require.Equal(t, response.Action, "list_subscriptions") + require.True(t, response.Success) + require.Empty(t, response.ErrorMessage) + + close(done) + return websocket.ErrCloseSent + }). + Once() + + controller.HandleConnection(context.Background()) }) } @@ -110,6 +471,14 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + // Simulate data provider write a block to the controller expectedBlock := unittest.BlockFixture() dataProvider. @@ -121,8 +490,9 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - s.expectSubscribeRequest(conn, done) + s.expectSubscribeRequest(conn) s.expectSubscribeResponse(conn, true) + s.expectCloseConnection(conn, done) // Expect a valid block to be passed to WriteJSON. // If we got to this point, the controller executed all its logic properly @@ -146,6 +516,15 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + dataProvider.On("Close").Return(nil).Maybe() + // Simulate data provider writes some blocks to the controller expectedBlocks := unittest.BlockFixtures(100) dataProvider. @@ -159,8 +538,9 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - s.expectSubscribeRequest(conn, done) + s.expectSubscribeRequest(conn) s.expectSubscribeResponse(conn, true) + s.expectCloseConnection(conn, done) i := 0 actualBlocks := make([]*flow.Block, len(expectedBlocks)) @@ -199,22 +579,14 @@ func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Da conn.On("SetReadDeadline", mock.Anything).Return(nil) conn.On("SetWriteDeadline", mock.Anything).Return(nil) - id := uuid.New() dataProvider := dpmock.NewDataProvider(t) - dataProvider.On("ID").Return(id) - //dataProvider.On("Close").Return(nil). - factory := dpmock.NewDataProviderFactory(t) - factory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(dataProvider, nil). - Once() return conn, factory, dataProvider } // expectSubscribeRequest mocks the client's subscription request. -func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConnection, done <-chan struct{}) { +func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConnection) { subscribeRequest := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, Topic: "blocks", @@ -232,7 +604,9 @@ func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConne }). Return(nil). Once() +} +func (s *WsControllerSuite) expectCloseConnection(conn *connmock.WebsocketConnection, done <-chan struct{}) { // In the default case, no further communication is expected from the client. // We wait for the writer routine to signal completion, allowing us to close the connection gracefully conn. @@ -241,7 +615,8 @@ func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConne for range done { } return websocket.ErrCloseSent - }) + }). + Once() } // expectSubscribeResponse mocks the subscription response sent to the client. @@ -251,6 +626,7 @@ func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConn Run(func(args mock.Arguments) { response, ok := args.Get(0).(models.SubscribeMessageResponse) require.True(s.T(), ok) + require.Equal(s.T(), "subscribe", response.Action) require.Equal(s.T(), success, response.Success) }). Return(nil). @@ -266,5 +642,5 @@ func (s *WsControllerSuite) expectKeepaliveClose(conn *connmock.WebsocketConnect } return websocket.ErrCloseSent }). - Once() + Maybe() }