From f7c7884f6fa695631005ef85a8b2f85f8297e2c6 Mon Sep 17 00:00:00 2001 From: canstand Date: Sat, 26 Oct 2024 01:03:35 +0800 Subject: [PATCH] fix: WebSocketRoute connect to server synchronously --- tests/route_web_socket_test.go | 44 +++++++++++++++++++++------------- tests/utils_test.go | 1 + websocket_route.go | 2 +- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/tests/route_web_socket_test.go b/tests/route_web_socket_test.go index 532f2b5..6556983 100644 --- a/tests/route_web_socket_test.go +++ b/tests/route_web_socket_test.go @@ -19,7 +19,7 @@ func assertSlicesEqual(t *testing.T, expected []interface{}, cb func() (interfac require.EventuallyWithT(t, func(collect *assert.CollectT) { actual, err := cb() require.NoError(t, err) - assert.EqualValues(t, expected, actual) + require.EqualValues(t, expected, actual) }, 5*time.Second, 200*time.Millisecond) } @@ -157,7 +157,8 @@ func TestRouteWebSocketShouldWorkWithServer(t *testing.T) { wsRouteChan := make(chan playwright.WebSocketRoute, 1) handleWS := func(ws playwright.WebSocketRoute) { - server, _ := ws.ConnectToServer() + server, err := ws.ConnectToServer() + require.NoError(t, err) ws.OnMessage(func(message interface{}) { msg := message.(string) @@ -192,24 +193,23 @@ func TestRouteWebSocketShouldWorkWithServer(t *testing.T) { wsRouteChan <- ws } - require.NoError(t, page.RouteWebSocket(regexp.MustCompile(".*"), handleWS)) - - wsConnChan := server.WaitForWebSocketConnection() log := newSyncSlice[string]() - server.OnceWebSocketConnection(func(c *websocket.Conn, r *http.Request) { - server.OnWebSocketMessage(func(c *websocket.Conn, r *http.Request, msgType websocket.MessageType, msg []byte) { - log.Append(fmt.Sprintf("message: %s", msg)) - }) - server.OnWebSocketClose(func(err *websocket.CloseError) { - log.Append(fmt.Sprintf("close: code=%d reason=%s", err.Code, err.Reason)) - }) + server.OnWebSocketMessage(func(c *websocket.Conn, r *http.Request, msgType websocket.MessageType, msg []byte) { + log.Append(fmt.Sprintf("message: %s", msg)) + }) + server.OnWebSocketClose(func(err *websocket.CloseError) { + log.Append(fmt.Sprintf("close: code=%d reason=%s", err.Code, err.Reason)) }) + require.NoError(t, page.RouteWebSocket(regexp.MustCompile(".*"), handleWS)) + + wsConnChan := server.WaitForWebSocketConnection() + setupWS(t, page, server.PORT, "blob") ws := <-wsConnChan require.EventuallyWithT(t, func(collect *assert.CollectT) { - assert.EqualValues(t, []string{"message: fake"}, log.Get()) + require.EqualValues(t, []string{"message: fake"}, log.Get()) }, 5*time.Second, 200*time.Millisecond) ws.SendMessage(websocket.MessageText, []byte("to-modify")) @@ -234,7 +234,7 @@ func TestRouteWebSocketShouldWorkWithServer(t *testing.T) { require.NoError(t, err) require.EventuallyWithT(t, func(collect *assert.CollectT) { - assert.EqualValues(t, []string{"message: fake", "message: modified", "message: pass-client"}, log.Get()) + require.EqualValues(t, []string{"message: fake", "message: modified", "message: pass-client"}, log.Get()) }, 5*time.Second, 200*time.Millisecond) assertSlicesEqual(t, []interface{}{ @@ -246,6 +246,18 @@ func TestRouteWebSocketShouldWorkWithServer(t *testing.T) { return page.Evaluate(`window.log`) }) + route := <-wsRouteChan + route.Send("another") + assertSlicesEqual(t, []interface{}{ + "open", + "message: data=modified origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=pass-server origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=response origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=another origin=ws://localhost:" + server.PORT + " lastEventId=", + }, func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) + _, err = page.Evaluate(` () => { window.ws.send('pass-client-2'); @@ -253,7 +265,7 @@ func TestRouteWebSocketShouldWorkWithServer(t *testing.T) { require.NoError(t, err) require.EventuallyWithT(t, func(collect *assert.CollectT) { - assert.EqualValues(t, []string{"message: fake", "message: modified", "message: pass-client", "message: pass-client-2"}, log.Get()) + require.EqualValues(t, []string{"message: fake", "message: modified", "message: pass-client", "message: pass-client-2"}, log.Get()) }, 5*time.Second, 200*time.Millisecond) _, err = page.Evaluate(` @@ -263,7 +275,7 @@ func TestRouteWebSocketShouldWorkWithServer(t *testing.T) { require.NoError(t, err) require.EventuallyWithT(t, func(collect *assert.CollectT) { - assert.EqualValues(t, []string{ + require.EqualValues(t, []string{ "message: fake", "message: modified", "message: pass-client", diff --git a/tests/utils_test.go b/tests/utils_test.go index 04e8911..2acc2f2 100644 --- a/tests/utils_test.go +++ b/tests/utils_test.go @@ -86,6 +86,7 @@ func (t *testServer) AfterEach() { t.requestSubscriberes = make(map[string][]chan *http.Request) t.eventEmitter.RemoveListeners("connection") t.eventEmitter.RemoveListeners("message") + t.eventEmitter.RemoveListeners("close") t.testServer.CloseClientConnections() } diff --git a/websocket_route.go b/websocket_route.go index 491f9ad..bb74ab9 100644 --- a/websocket_route.go +++ b/websocket_route.go @@ -77,7 +77,7 @@ func (r *webSocketRouteImpl) ConnectToServer() (WebSocketRoute, error) { if r.connected.Load() { return nil, fmt.Errorf("Already connected to the server") } - go r.channel.SendNoReply("connect") + r.channel.SendNoReply("connect") r.connected.Store(true) return r.server, nil }