From a574fb680f981ddeba1e731736db461786eca413 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Tue, 26 Nov 2024 15:49:46 +0200 Subject: [PATCH 1/5] Add unit test for websocket controller * Add unit test for websocket controller * Add mock for websocket connection * Add mock for data provider * Add mock for data provider factory * Add mock for websocket connection The WebSocket Controller interacts with: 1. Data Provider: Supplies data to the controller. 2. WebSocket Connection: Handles communication with the client. To properly test the controller's logic, we mock these interactions. Since the controller runs two parallel routines (reader and writer), the tests also ensure both can shut down cleanly. A done channel is used in the tests to coordinate this process. --- Makefile | 3 + engine/access/rest/router/router.go | 4 +- engine/access/rest/websockets/connection.go | 39 +++ engine/access/rest/websockets/controller.go | 140 +++++++---- .../access/rest/websockets/controller_test.go | 235 ++++++++++++++++++ .../rest/websockets/data_provider/blocks.go | 3 +- .../rest/websockets/data_provider/factory.go | 12 +- .../data_provider/mock/data_provider.go | 17 +- .../websockets/data_provider/mock/factory.go | 47 ++++ .../rest/websockets/data_provider/provider.go | 2 +- engine/access/rest/websockets/handler.go | 27 +- engine/access/rest/websockets/handler_test.go | 86 ------- .../websockets/mock/websocket_connection.go | 78 ++++++ 13 files changed, 530 insertions(+), 163 deletions(-) create mode 100644 engine/access/rest/websockets/connection.go create mode 100644 engine/access/rest/websockets/controller_test.go create mode 100644 engine/access/rest/websockets/data_provider/mock/factory.go delete mode 100644 engine/access/rest/websockets/handler_test.go create mode 100644 engine/access/rest/websockets/mock/websocket_connection.go diff --git a/Makefile b/Makefile index 2578fffe4b6..36c495fb1b5 100644 --- a/Makefile +++ b/Makefile @@ -214,6 +214,9 @@ generate-mocks: install-mock-generators mockery --name 'Storage' --dir=module/executiondatasync/tracker --case=underscore --output="module/executiondatasync/tracker/mock" --outpkg="mocktracker" mockery --name 'ScriptExecutor' --dir=module/execution --case=underscore --output="module/execution/mock" --outpkg="mock" mockery --name 'StorageSnapshot' --dir=fvm/storage/snapshot --case=underscore --output="fvm/storage/snapshot/mock" --outpkg="mock" + mockery --name 'DataProvider' --dir=engine/access/rest/websockets/data_provider --case=underscore --output="engine/access/rest/websockets/data_provider/mock" --outpkg="mock" + mockery --name 'Factory' --dir=engine/access/rest/websockets/data_provider --case=underscore --output="engine/access/rest/websockets/data_provider/mock" --outpkg="mock" + mockery --name 'WebsocketConnection' --dir=engine/access/rest/websockets --case=underscore --output="engine/access/rest/websockets/mock" --outpkg="mock" #temporarily make insecure/ a non-module to allow mockery to create mocks mv insecure/go.mod insecure/go2.mod diff --git a/engine/access/rest/router/router.go b/engine/access/rest/router/router.go index a2d81cb0a58..37dac306fa2 100644 --- a/engine/access/rest/router/router.go +++ b/engine/access/rest/router/router.go @@ -14,6 +14,7 @@ import ( flowhttp "github.com/onflow/flow-go/engine/access/rest/http" "github.com/onflow/flow-go/engine/access/rest/http/models" "github.com/onflow/flow-go/engine/access/rest/websockets" + "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" legacyws "github.com/onflow/flow-go/engine/access/rest/websockets/legacy" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/access/state_stream/backend" @@ -93,7 +94,8 @@ func (b *RouterBuilder) AddWebsocketsRoute( streamConfig backend.Config, maxRequestSize int64, ) *RouterBuilder { - handler := websockets.NewWebSocketHandler(b.logger, config, chain, streamApi, streamConfig, maxRequestSize) + factory := data_provider.NewDataProviderFactory(b.logger, streamApi, streamConfig) + handler := websockets.NewWebSocketHandler(b.logger, config, chain, factory, maxRequestSize) b.v1SubRouter. Methods(http.MethodGet). Path("/ws"). diff --git a/engine/access/rest/websockets/connection.go b/engine/access/rest/websockets/connection.go new file mode 100644 index 00000000000..9f762f6d389 --- /dev/null +++ b/engine/access/rest/websockets/connection.go @@ -0,0 +1,39 @@ +package websockets + +import ( + "github.com/gorilla/websocket" +) + +type WebsocketConnection interface { + ReadJSON(v interface{}) error + WriteJSON(v interface{}) error + Close() error +} + +type GorillaWebsocketConnection struct { + conn *websocket.Conn +} + +func NewGorillaWebsocketConnection(conn *websocket.Conn) *GorillaWebsocketConnection { + return &GorillaWebsocketConnection{ + conn: conn, + } +} + +var _ WebsocketConnection = (*GorillaWebsocketConnection)(nil) + +func (m *GorillaWebsocketConnection) ReadJSON(v interface{}) error { + return m.conn.ReadJSON(v) +} + +func (m *GorillaWebsocketConnection) WriteJSON(v interface{}) error { + return m.conn.WriteJSON(v) +} + +func (m *GorillaWebsocketConnection) SetCloseHandler(handler func(code int, text string) error) { + m.conn.SetCloseHandler(handler) +} + +func (m *GorillaWebsocketConnection) Close() error { + return m.conn.Close() +} diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index fe873f5f61c..2b36b9303ae 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -3,7 +3,9 @@ package websockets import ( "context" "encoding/json" + "errors" "fmt" + "sync" "github.com/google/uuid" "github.com/gorilla/websocket" @@ -11,34 +13,35 @@ import ( dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" "github.com/onflow/flow-go/engine/access/rest/websockets/models" - "github.com/onflow/flow-go/engine/access/state_stream" - "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/utils/concurrentmap" ) +var ErrEmptyMessage = errors.New("empty message") + type Controller struct { logger zerolog.Logger config Config - conn *websocket.Conn + conn WebsocketConnection communicationChannel chan interface{} dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] - dataProvidersFactory *dp.Factory + dataProvidersFactory dp.Factory + shutdownOnce sync.Once } func NewWebSocketController( logger zerolog.Logger, config Config, - streamApi state_stream.API, - streamConfig backend.Config, - conn *websocket.Conn, + factory dp.Factory, + conn WebsocketConnection, ) *Controller { return &Controller{ logger: logger.With().Str("component", "websocket-controller").Logger(), config: config, conn: conn, - communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? + communicationChannel: make(chan interface{}, 10), //TODO: should it be buffered chan? dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), - dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig), + dataProvidersFactory: factory, + shutdownOnce: sync.Once{}, } } @@ -46,59 +49,73 @@ func NewWebSocketController( func (c *Controller) HandleConnection(ctx context.Context) { //TODO: configure the connection with ping-pong and deadlines //TODO: spin up a response limit tracker routine - go c.readMessagesFromClient(ctx) - c.writeMessagesToClient(ctx) + go c.readMessages(ctx) + c.writeMessages(ctx) } -// writeMessagesToClient reads a messages from communication channel and passes them on to a client WebSocket connection. +// writeMessages reads a messages from communication channel and passes them on to a client WebSocket connection. // The communication channel is filled by data providers. Besides, the response limit tracker is involved in // write message regulation -func (c *Controller) writeMessagesToClient(ctx context.Context) { - //TODO: can it run forever? maybe we should cancel the ctx in the reader routine +func (c *Controller) writeMessages(ctx context.Context) { + defer c.shutdownConnection() + for { select { case <-ctx.Done(): return - case msg := <-c.communicationChannel: - // TODO: handle 'response per second' limits + case msg, ok := <-c.communicationChannel: + if !ok { + return + } + c.logger.Debug().Msgf("read message from communication channel: %s", msg) + // TODO: handle 'response per second' limits err := c.conn.WriteJSON(msg) if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) || + websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + return + } + c.logger.Error().Err(err).Msg("error writing to connection") + return } + + c.logger.Debug().Msg("written message to client") } } } -// readMessagesFromClient continuously reads messages from a client WebSocket connection, +// readMessages continuously reads messages from a client WebSocket connection, // processes each message, and handles actions based on the message type. -func (c *Controller) readMessagesFromClient(ctx context.Context) { +func (c *Controller) readMessages(ctx context.Context) { defer c.shutdownConnection() for { - select { - case <-ctx.Done(): - c.logger.Info().Msg("context canceled, stopping read message loop") - return - default: - msg, err := c.readMessage() - if err != nil { - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) { - return - } - c.logger.Warn().Err(err).Msg("error reading message from client") + msg, err := c.readMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) || + websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { return + } else if errors.Is(err, ErrEmptyMessage) { + continue } - baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg) - if err != nil { - c.logger.Debug().Err(err).Msg("error parsing and validating client message") - return - } + c.logger.Debug().Err(err).Msg("error reading message from client") + continue + } - if err := c.handleAction(ctx, validatedMsg); err != nil { - c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action") - } + baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg) + if err != nil { + c.logger.Debug().Err(err).Msg("error parsing and validating client message") + //TODO: write error to error channel + continue + } + + if err := c.handleAction(ctx, validatedMsg); err != nil { + c.logger.Debug().Err(err).Str("action", baseMsg.Action).Msg("error handling action") + //TODO: write error to error channel + continue } } } @@ -108,6 +125,11 @@ func (c *Controller) readMessage() (json.RawMessage, error) { if err := c.conn.ReadJSON(&message); err != nil { return nil, fmt.Errorf("error reading JSON from client: %w", err) } + + if message == nil { + return nil, ErrEmptyMessage + } + return message, nil } @@ -166,10 +188,18 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { dp := c.dataProvidersFactory.NewDataProvider(c.communicationChannel, msg.Topic) c.dataProviders.Add(dp.ID(), dp) - dp.Run(ctx) - //TODO: return OK response to client - c.communicationChannel <- msg + // firstly, we want to write OK response to client and only after that we can start providing actual data + response := models.SubscribeMessageResponse{ + BaseMessageResponse: models.BaseMessageResponse{ + Success: true, + }, + Topic: dp.Topic(), + ID: dp.ID().String(), + } + c.communicationChannel <- response + + dp.Run(ctx) } func (c *Controller) handleUnsubscribe(_ context.Context, msg models.UnsubscribeMessageRequest) { @@ -193,20 +223,24 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis } func (c *Controller) shutdownConnection() { - defer close(c.communicationChannel) - defer func(conn *websocket.Conn) { - if err := c.conn.Close(); err != nil { - c.logger.Error().Err(err).Msg("error closing connection") + c.shutdownOnce.Do(func() { + defer close(c.communicationChannel) + defer func(conn WebsocketConnection) { + if err := c.conn.Close(); err != nil { + c.logger.Warn().Err(err).Msg("error closing connection") + } + }(c.conn) + + c.logger.Debug().Msg("shutting down connection") + + err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { + dp.Close() + return nil + }) + if err != nil { + c.logger.Error().Err(err).Msg("error closing data provider") } - }(c.conn) - err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { - dp.Close() - return nil + c.dataProviders.Clear() }) - if err != nil { - c.logger.Error().Err(err).Msg("error closing data provider") - } - - c.dataProviders.Clear() } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go new file mode 100644 index 00000000000..36375d9e733 --- /dev/null +++ b/engine/access/rest/websockets/controller_test.go @@ -0,0 +1,235 @@ +package websockets + +import ( + "context" + "encoding/json" + "testing" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + dpmock "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider/mock" + connmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" + "github.com/onflow/flow-go/engine/access/rest/websockets/models" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +type WsControllerSuite struct { + suite.Suite + + logger zerolog.Logger + wsConfig Config + streamApi *streammock.API + streamConfig backend.Config +} + +func (s *WsControllerSuite) SetupTest() { + //s.logger = unittest.LoggerWithWriterAndLevel(os.Stdout, zerolog.DebugLevel) + s.logger = unittest.Logger() + s.wsConfig = NewDefaultWebsocketConfig() + s.streamApi = streammock.NewAPI(s.T()) + s.streamConfig = backend.Config{} +} + +func TestWsControllerSuite(t *testing.T) { + suite.Run(t, new(WsControllerSuite)) +} + +func (s *WsControllerSuite) TestSubscribeRequest() { + s.T().Run("Happy path", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) + + dataProvider. + On("Run", mock.Anything). + Run(func(args mock.Arguments) {}). + Once() + + requestMessage := models.SubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, + Topic: "blocks", + Arguments: nil, + } + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + reqMsg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + msg, err := json.Marshal(requestMessage) + require.NoError(t, err) + *reqMsg = msg + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.SubscribeMessageResponse) + require.True(t, ok) + require.True(t, response.Success) + close(done) + return websocket.ErrCloseSent + }) + + conn. + On("ReadJSON", mock.Anything). + Return(func(interface{}) error { + _, ok := <-done + if !ok { + return websocket.ErrCloseSent + } + return nil + }) + + controller.HandleConnection(context.Background()) + }) +} + +func (s *WsControllerSuite) TestSubscribeBlocks() { + s.T().Run("Stream one block", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) + + // we want data provider to write some block to controller + expectedBlock := unittest.BlockFixture() + dataProvider. + On("Run", mock.Anything). + Run(func(args mock.Arguments) { + controller.communicationChannel <- expectedBlock + }). + Once() + + done := make(chan struct{}, 1) + var actualBlock flow.Block + + s.expectSubscriptionRequest(conn, done) + s.expectSubscriptionResponse(conn, true) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + block := msg.(flow.Block) + actualBlock = block + + close(done) + return websocket.ErrCloseSent + }) + + controller.HandleConnection(context.Background()) + require.Equal(t, expectedBlock, actualBlock) + }) + + s.T().Run("Stream many blocks", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) + + // we want data provider to write some block to controller + expectedBlocks := unittest.BlockFixtures(100) + dataProvider. + On("Run", mock.Anything). + Run(func(args mock.Arguments) { + for _, block := range expectedBlocks { + controller.communicationChannel <- *block + } + }). + Once() + + done := make(chan struct{}, 1) + actualBlocks := make([]*flow.Block, len(expectedBlocks)) + i := 0 + + s.expectSubscriptionRequest(conn, done) + s.expectSubscriptionResponse(conn, true) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + block := msg.(flow.Block) + actualBlocks[i] = &block + i += 1 + + if i == len(expectedBlocks) { + close(done) + return websocket.ErrCloseSent + } + + return nil + }). + Times(len(expectedBlocks)) + + controller.HandleConnection(context.Background()) + require.Equal(t, expectedBlocks, actualBlocks) + }) +} + +func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Factory, *dpmock.DataProvider) { + conn := connmock.NewWebsocketConnection(t) + conn.On("Close").Return(nil).Once() + + id := uuid.New() + topic := "blocks" + dataProvider := dpmock.NewDataProvider(t) + dataProvider.On("ID").Return(id) + dataProvider.On("Close").Return(nil) + dataProvider.On("Topic").Return(topic) + + factory := dpmock.NewFactory(t) + factory. + On("NewDataProvider", mock.Anything, mock.Anything). + Return(dataProvider). + Once() + + return conn, factory, dataProvider +} + +func (s *WsControllerSuite) expectSubscriptionRequest(conn *connmock.WebsocketConnection, done <-chan struct{}) { + requestMessage := models.SubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, + Topic: "blocks", + } + + // The very first message from a client is a request to subscribe to some topic + conn.On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + reqMsg, ok := args.Get(0).(*json.RawMessage) + require.True(s.T(), ok) + msg, err := json.Marshal(requestMessage) + require.NoError(s.T(), err) + *reqMsg = msg + }). + Return(nil). + Once() + + // In the default case, no further communication is expected from the client. + // We wait for the writer routine to signal completion, allowing us to close the connection gracefully + conn. + On("ReadJSON", mock.Anything). + Return(func(msg interface{}) error { + _, ok := <-done + if !ok { + return websocket.ErrCloseSent + } + return nil + }) +} + +func (s *WsControllerSuite) expectSubscriptionResponse(conn *connmock.WebsocketConnection, success bool) { + conn.On("WriteJSON", mock.Anything). + Run(func(args mock.Arguments) { + response, ok := args.Get(0).(models.SubscribeMessageResponse) + require.True(s.T(), ok) + require.Equal(s.T(), success, response.Success) + }). + Return(nil). + Once() +} diff --git a/engine/access/rest/websockets/data_provider/blocks.go b/engine/access/rest/websockets/data_provider/blocks.go index 01b4d07d2e7..6699abfea04 100644 --- a/engine/access/rest/websockets/data_provider/blocks.go +++ b/engine/access/rest/websockets/data_provider/blocks.go @@ -56,6 +56,7 @@ func (p *MockBlockProvider) Topic() string { return p.topic } -func (p *MockBlockProvider) Close() { +func (p *MockBlockProvider) Close() error { p.stopProviderFunc() + return nil } diff --git a/engine/access/rest/websockets/data_provider/factory.go b/engine/access/rest/websockets/data_provider/factory.go index 6a2658b1b95..8c19965ee90 100644 --- a/engine/access/rest/websockets/data_provider/factory.go +++ b/engine/access/rest/websockets/data_provider/factory.go @@ -7,21 +7,25 @@ import ( "github.com/onflow/flow-go/engine/access/state_stream/backend" ) -type Factory struct { +type Factory interface { + NewDataProvider(ch chan<- interface{}, topic string) DataProvider +} + +type SimpleFactory struct { logger zerolog.Logger streamApi state_stream.API streamConfig backend.Config } -func NewDataProviderFactory(logger zerolog.Logger, streamApi state_stream.API, streamConfig backend.Config) *Factory { - return &Factory{ +func NewDataProviderFactory(logger zerolog.Logger, streamApi state_stream.API, streamConfig backend.Config) *SimpleFactory { + return &SimpleFactory{ logger: logger, streamApi: streamApi, streamConfig: streamConfig, } } -func (f *Factory) NewDataProvider(ch chan<- interface{}, topic string) DataProvider { +func (f *SimpleFactory) NewDataProvider(ch chan<- interface{}, topic string) DataProvider { switch topic { case "blocks": return NewMockBlockProvider(ch, topic, f.logger, f.streamApi) diff --git a/engine/access/rest/websockets/data_provider/mock/data_provider.go b/engine/access/rest/websockets/data_provider/mock/data_provider.go index 4a2a22a44a0..9178a58d385 100644 --- a/engine/access/rest/websockets/data_provider/mock/data_provider.go +++ b/engine/access/rest/websockets/data_provider/mock/data_provider.go @@ -16,8 +16,21 @@ type DataProvider struct { } // Close provides a mock function with given fields: -func (_m *DataProvider) Close() { - _m.Called() +func (_m *DataProvider) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 } // ID provides a mock function with given fields: diff --git a/engine/access/rest/websockets/data_provider/mock/factory.go b/engine/access/rest/websockets/data_provider/mock/factory.go new file mode 100644 index 00000000000..9dd9d92372d --- /dev/null +++ b/engine/access/rest/websockets/data_provider/mock/factory.go @@ -0,0 +1,47 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import ( + data_provider "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" + mock "github.com/stretchr/testify/mock" +) + +// Factory is an autogenerated mock type for the Factory type +type Factory struct { + mock.Mock +} + +// NewDataProvider provides a mock function with given fields: ch, topic +func (_m *Factory) NewDataProvider(ch chan<- interface{}, topic string) data_provider.DataProvider { + ret := _m.Called(ch, topic) + + if len(ret) == 0 { + panic("no return value specified for NewDataProvider") + } + + var r0 data_provider.DataProvider + if rf, ok := ret.Get(0).(func(chan<- interface{}, string) data_provider.DataProvider); ok { + r0 = rf(ch, topic) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(data_provider.DataProvider) + } + } + + return r0 +} + +// NewFactory creates a new instance of Factory. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewFactory(t interface { + mock.TestingT + Cleanup(func()) +}) *Factory { + mock := &Factory{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/engine/access/rest/websockets/data_provider/provider.go b/engine/access/rest/websockets/data_provider/provider.go index ce2914140ba..9e0870ae274 100644 --- a/engine/access/rest/websockets/data_provider/provider.go +++ b/engine/access/rest/websockets/data_provider/provider.go @@ -10,5 +10,5 @@ type DataProvider interface { Run(ctx context.Context) ID() uuid.UUID Topic() string - Close() + Close() error } diff --git a/engine/access/rest/websockets/handler.go b/engine/access/rest/websockets/handler.go index 247890c2a62..08c27aa2682 100644 --- a/engine/access/rest/websockets/handler.go +++ b/engine/access/rest/websockets/handler.go @@ -8,18 +8,16 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/common" - "github.com/onflow/flow-go/engine/access/state_stream" - "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider" "github.com/onflow/flow-go/model/flow" ) type Handler struct { *common.HttpHandler - logger zerolog.Logger - websocketConfig Config - streamApi state_stream.API - streamConfig backend.Config + logger zerolog.Logger + websocketConfig Config + dataProviderFactory data_provider.Factory } var _ http.Handler = (*Handler)(nil) @@ -28,16 +26,14 @@ func NewWebSocketHandler( logger zerolog.Logger, config Config, chain flow.Chain, - streamApi state_stream.API, - streamConfig backend.Config, + factory data_provider.Factory, maxRequestSize int64, ) *Handler { return &Handler{ - HttpHandler: common.NewHttpHandler(logger, chain, maxRequestSize), - websocketConfig: config, - logger: logger, - streamApi: streamApi, - streamConfig: streamConfig, + HttpHandler: common.NewHttpHandler(logger, chain, maxRequestSize), + websocketConfig: config, + logger: logger, + dataProviderFactory: factory, } } @@ -65,6 +61,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - controller := NewWebSocketController(logger, h.websocketConfig, h.streamApi, h.streamConfig, conn) - controller.HandleConnection(context.TODO()) + newConn := NewGorillaWebsocketConnection(conn) + controller := NewWebSocketController(logger, h.websocketConfig, h.dataProviderFactory, newConn) + controller.HandleConnection(context.Background()) } diff --git a/engine/access/rest/websockets/handler_test.go b/engine/access/rest/websockets/handler_test.go deleted file mode 100644 index 6b9cce06572..00000000000 --- a/engine/access/rest/websockets/handler_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package websockets_test - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/gorilla/websocket" - "github.com/rs/zerolog" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - - "github.com/onflow/flow-go/engine/access/rest/websockets" - "github.com/onflow/flow-go/engine/access/rest/websockets/models" - "github.com/onflow/flow-go/engine/access/state_stream/backend" - streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" - "github.com/onflow/flow-go/model/flow" - "github.com/onflow/flow-go/utils/unittest" -) - -var ( - chainID = flow.Testnet -) - -type WsHandlerSuite struct { - suite.Suite - - logger zerolog.Logger - handler *websockets.Handler - wsConfig websockets.Config - streamApi *streammock.API - streamConfig backend.Config -} - -func (s *WsHandlerSuite) SetupTest() { - s.logger = unittest.Logger() - s.wsConfig = websockets.NewDefaultWebsocketConfig() - s.streamApi = streammock.NewAPI(s.T()) - s.streamConfig = backend.Config{} - s.handler = websockets.NewWebSocketHandler(s.logger, s.wsConfig, chainID.Chain(), s.streamApi, s.streamConfig, 1024) -} - -func TestWsHandlerSuite(t *testing.T) { - suite.Run(t, new(WsHandlerSuite)) -} - -func ClientConnection(url string) (*websocket.Conn, *http.Response, error) { - wsURL := "ws" + strings.TrimPrefix(url, "http") - return websocket.DefaultDialer.Dial(wsURL, nil) -} - -func (s *WsHandlerSuite) TestSubscribeRequest() { - s.Run("Happy path", func() { - server := httptest.NewServer(s.handler) - defer server.Close() - - conn, _, err := ClientConnection(server.URL) - defer func(conn *websocket.Conn) { - err := conn.Close() - require.NoError(s.T(), err) - }(conn) - require.NoError(s.T(), err) - - args := map[string]interface{}{ - "start_block_height": 10, - } - body := models.SubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: "blocks", - Arguments: args, - } - bodyJSON, err := json.Marshal(body) - require.NoError(s.T(), err) - - err = conn.WriteMessage(websocket.TextMessage, bodyJSON) - require.NoError(s.T(), err) - - _, msg, err := conn.ReadMessage() - require.NoError(s.T(), err) - - actualMsg := strings.Trim(string(msg), "\n\"\\ ") - require.Equal(s.T(), "block{height: 42}", actualMsg) - }) -} diff --git a/engine/access/rest/websockets/mock/websocket_connection.go b/engine/access/rest/websockets/mock/websocket_connection.go new file mode 100644 index 00000000000..e81b2bcec3f --- /dev/null +++ b/engine/access/rest/websockets/mock/websocket_connection.go @@ -0,0 +1,78 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mock + +import mock "github.com/stretchr/testify/mock" + +// WebsocketConnection is an autogenerated mock type for the WebsocketConnection type +type WebsocketConnection struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *WebsocketConnection) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ReadJSON provides a mock function with given fields: v +func (_m *WebsocketConnection) ReadJSON(v interface{}) error { + ret := _m.Called(v) + + if len(ret) == 0 { + panic("no return value specified for ReadJSON") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(v) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// WriteJSON provides a mock function with given fields: v +func (_m *WebsocketConnection) WriteJSON(v interface{}) error { + ret := _m.Called(v) + + if len(ret) == 0 { + panic("no return value specified for WriteJSON") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(v) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewWebsocketConnection creates a new instance of WebsocketConnection. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewWebsocketConnection(t interface { + mock.TestingT + Cleanup(func()) +}) *WebsocketConnection { + mock := &WebsocketConnection{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} From 8552a51124b5edaddf09e133ae21c8b794d263c2 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 28 Nov 2024 13:02:37 +0200 Subject: [PATCH 2/5] add commentaries for tests --- engine/access/rest/websockets/controller.go | 15 ++++--- .../access/rest/websockets/controller_test.go | 39 +++++++++++++------ 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 2b36b9303ae..b116f88af87 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -224,22 +224,21 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis func (c *Controller) shutdownConnection() { c.shutdownOnce.Do(func() { - defer close(c.communicationChannel) - defer func(conn WebsocketConnection) { + defer func() { + close(c.communicationChannel) + if err := c.conn.Close(); err != nil { c.logger.Warn().Err(err).Msg("error closing connection") } - }(c.conn) + }() c.logger.Debug().Msg("shutting down connection") - err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { - dp.Close() + _ = c.dataProviders.ForEach(func(id uuid.UUID, dp dp.DataProvider) error { + err := dp.Close() + c.logger.Error().Err(err).Msgf("error closing data provider: %s", id.String()) return nil }) - if err != nil { - c.logger.Error().Err(err).Msg("error closing data provider") - } c.dataProviders.Clear() }) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 36375d9e733..e35d7e737d8 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -31,7 +31,6 @@ type WsControllerSuite struct { } func (s *WsControllerSuite) SetupTest() { - //s.logger = unittest.LoggerWithWriterAndLevel(os.Stdout, zerolog.DebugLevel) s.logger = unittest.Logger() s.wsConfig = NewDefaultWebsocketConfig() s.streamApi = streammock.NewAPI(s.T()) @@ -42,6 +41,8 @@ func TestWsControllerSuite(t *testing.T) { suite.Run(t, new(WsControllerSuite)) } +// TestSubscribeRequest tests the subscribe to topic flow. +// We emulate a request message from a client, and a response message from a controller. func (s *WsControllerSuite) TestSubscribeRequest() { s.T().Run("Happy path", func(t *testing.T) { conn, dataProviderFactory, dataProvider := newControllerMocks(t) @@ -58,6 +59,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Arguments: nil, } + // Simulate receiving the subscription request from the client conn. On("ReadJSON", mock.Anything). Run(func(args mock.Arguments) { @@ -70,17 +72,21 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Return(nil). Once() + // Channel to signal the test flow completion done := make(chan struct{}, 1) + + // Simulate writing a successful subscription response back to the client conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { response, ok := msg.(models.SubscribeMessageResponse) require.True(t, ok) require.True(t, response.Success) - close(done) + close(done) // Signal that response has been sent return websocket.ErrCloseSent }) + // Simulate client closing connection after receiving the response conn. On("ReadJSON", mock.Anything). Return(func(interface{}) error { @@ -95,12 +101,13 @@ func (s *WsControllerSuite) TestSubscribeRequest() { }) } +// TestSubscribeBlocks tests the functionality for streaming blocks to a subscriber. func (s *WsControllerSuite) TestSubscribeBlocks() { s.T().Run("Stream one block", func(t *testing.T) { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) - // we want data provider to write some block to controller + // Simulate data provider write a block to the controller expectedBlock := unittest.BlockFixture() dataProvider. On("Run", mock.Anything). @@ -110,15 +117,17 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - var actualBlock flow.Block - s.expectSubscriptionRequest(conn, done) s.expectSubscriptionResponse(conn, true) + // Expect a valid block to be passed to WriteJSON. + // If we got to this point, the controller executed all its logic properly + var actualBlock flow.Block conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { - block := msg.(flow.Block) + block, ok := msg.(flow.Block) + require.True(t, ok) actualBlock = block close(done) @@ -133,7 +142,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) - // we want data provider to write some block to controller + // Simulate data provider writes some blocks to the controller expectedBlocks := unittest.BlockFixtures(100) dataProvider. On("Run", mock.Anything). @@ -145,16 +154,20 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - actualBlocks := make([]*flow.Block, len(expectedBlocks)) - i := 0 - s.expectSubscriptionRequest(conn, done) s.expectSubscriptionResponse(conn, true) + i := 0 + actualBlocks := make([]*flow.Block, len(expectedBlocks)) + + // Expect valid blocks to be passed to WriteJSON. + // If we got to this point, the controller executed all its logic properly conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { - block := msg.(flow.Block) + block, ok := msg.(flow.Block) + require.True(t, ok) + actualBlocks[i] = &block i += 1 @@ -172,6 +185,8 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { }) } +// newControllerMocks initializes mock WebSocket connection, data provider, and data provider factory. +// The mocked functions are expected to be called in a case when a test is expected to reach WriteJSON function. func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Factory, *dpmock.DataProvider) { conn := connmock.NewWebsocketConnection(t) conn.On("Close").Return(nil).Once() @@ -192,6 +207,7 @@ func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Fa return conn, factory, dataProvider } +// expectSubscriptionRequest mocks the client's subscription request. func (s *WsControllerSuite) expectSubscriptionRequest(conn *connmock.WebsocketConnection, done <-chan struct{}) { requestMessage := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, @@ -223,6 +239,7 @@ func (s *WsControllerSuite) expectSubscriptionRequest(conn *connmock.WebsocketCo }) } +// expectSubscriptionResponse mocks the subscription response sent to the client. func (s *WsControllerSuite) expectSubscriptionResponse(conn *connmock.WebsocketConnection, success bool) { conn.On("WriteJSON", mock.Anything). Run(func(args mock.Arguments) { From 495cf0381b8f2a050959571109ecacc46bc0c6ef Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 28 Nov 2024 13:30:14 +0200 Subject: [PATCH 3/5] Add comment to ws connection --- engine/access/rest/websockets/connection.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/engine/access/rest/websockets/connection.go b/engine/access/rest/websockets/connection.go index 9f762f6d389..5e1880f7ce8 100644 --- a/engine/access/rest/websockets/connection.go +++ b/engine/access/rest/websockets/connection.go @@ -4,6 +4,9 @@ import ( "github.com/gorilla/websocket" ) +// We wrap gorilla's websocket connection with interface +// to be able to mock it in order to test the types dependent on it + type WebsocketConnection interface { ReadJSON(v interface{}) error WriteJSON(v interface{}) error From 8496af356b67b731cfdfa1945db7fe1dcd921f90 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 29 Nov 2024 16:30:33 +0200 Subject: [PATCH 4/5] remove empty err message check --- engine/access/rest/websockets/controller.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index b116f88af87..a1bca1e5525 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -3,7 +3,6 @@ package websockets import ( "context" "encoding/json" - "errors" "fmt" "sync" @@ -16,8 +15,6 @@ import ( "github.com/onflow/flow-go/utils/concurrentmap" ) -var ErrEmptyMessage = errors.New("empty message") - type Controller struct { logger zerolog.Logger config Config @@ -97,8 +94,6 @@ func (c *Controller) readMessages(ctx context.Context) { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) || websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { return - } else if errors.Is(err, ErrEmptyMessage) { - continue } c.logger.Debug().Err(err).Msg("error reading message from client") @@ -126,10 +121,6 @@ func (c *Controller) readMessage() (json.RawMessage, error) { return nil, fmt.Errorf("error reading JSON from client: %w", err) } - if message == nil { - return nil, ErrEmptyMessage - } - return message, nil } From 4682bb13d78cf5209e510e78ffa6e2268bec7f8c Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Wed, 4 Dec 2024 17:04:45 +0200 Subject: [PATCH 5/5] simplify ReadJSON mock. refactor controller a bit --- engine/access/rest/websockets/controller.go | 16 +++++++------- .../access/rest/websockets/controller_test.go | 22 +++++++------------ 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index a1bca1e5525..623dfb3e12b 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -35,7 +35,7 @@ func NewWebSocketController( logger: logger.With().Str("component", "websocket-controller").Logger(), config: config, conn: conn, - communicationChannel: make(chan interface{}, 10), //TODO: should it be buffered chan? + communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), dataProvidersFactory: factory, shutdownOnce: sync.Once{}, @@ -46,6 +46,7 @@ func NewWebSocketController( func (c *Controller) HandleConnection(ctx context.Context) { //TODO: configure the connection with ping-pong and deadlines //TODO: spin up a response limit tracker routine + defer c.shutdownConnection() go c.readMessages(ctx) c.writeMessages(ctx) } @@ -54,8 +55,6 @@ func (c *Controller) HandleConnection(ctx context.Context) { // The communication channel is filled by data providers. Besides, the response limit tracker is involved in // write message regulation func (c *Controller) writeMessages(ctx context.Context) { - defer c.shutdownConnection() - for { select { case <-ctx.Done(): @@ -86,8 +85,6 @@ func (c *Controller) writeMessages(ctx context.Context) { // readMessages continuously reads messages from a client WebSocket connection, // processes each message, and handles actions based on the message type. func (c *Controller) readMessages(ctx context.Context) { - defer c.shutdownConnection() - for { msg, err := c.readMessage() if err != nil { @@ -188,7 +185,12 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe Topic: dp.Topic(), ID: dp.ID().String(), } - c.communicationChannel <- response + + select { + case <-ctx.Done(): + return + case c.communicationChannel <- response: + } dp.Run(ctx) } @@ -216,8 +218,6 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis func (c *Controller) shutdownConnection() { c.shutdownOnce.Do(func() { defer func() { - close(c.communicationChannel) - if err := c.conn.Close(); err != nil { c.logger.Warn().Err(err).Msg("error closing connection") } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index e35d7e737d8..4e27f9da9de 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -53,7 +53,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Run(func(args mock.Arguments) {}). Once() - requestMessage := models.SubscribeMessageRequest{ + subscribeRequest := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, Topic: "blocks", Arguments: nil, @@ -63,11 +63,11 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn. On("ReadJSON", mock.Anything). Run(func(args mock.Arguments) { - reqMsg, ok := args.Get(0).(*json.RawMessage) + requestMsg, ok := args.Get(0).(*json.RawMessage) require.True(t, ok) - msg, err := json.Marshal(requestMessage) + subscribeRequestMessage, err := json.Marshal(subscribeRequest) require.NoError(t, err) - *reqMsg = msg + *requestMsg = subscribeRequestMessage }). Return(nil). Once() @@ -90,11 +90,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn. On("ReadJSON", mock.Anything). Return(func(interface{}) error { - _, ok := <-done - if !ok { - return websocket.ErrCloseSent - } - return nil + <-done + return websocket.ErrCloseSent }) controller.HandleConnection(context.Background()) @@ -231,11 +228,8 @@ func (s *WsControllerSuite) expectSubscriptionRequest(conn *connmock.WebsocketCo conn. On("ReadJSON", mock.Anything). Return(func(msg interface{}) error { - _, ok := <-done - if !ok { - return websocket.ErrCloseSent - } - return nil + <-done + return websocket.ErrCloseSent }) }