Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket shutdown logic #2277

Merged
merged 10 commits into from
Dec 23, 2024
19 changes: 17 additions & 2 deletions jsonrpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@
log utils.SimpleLogger
connParams *WebsocketConnParams
listener NewRequestListener

shutdown <-chan struct{}
}

func NewWebsocket(rpc *Server, log utils.SimpleLogger) *Websocket {
func NewWebsocket(rpc *Server, shutdown <-chan struct{}, log utils.SimpleLogger) *Websocket {
ws := &Websocket{
rpc: rpc,
log: log,
connParams: DefaultWebsocketConnParams(),
listener: &SelectiveListener{},
shutdown: shutdown,
}

return ws
Expand Down Expand Up @@ -54,7 +57,19 @@

// TODO include connection information, such as the remote address, in the logs.

wsc := newWebsocketConn(r.Context(), conn, ws.connParams)
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
go func() {
select {
case <-ws.shutdown:
cancel()

Check warning on line 65 in jsonrpc/websocket.go

View check run for this annotation

Codecov / codecov/patch

jsonrpc/websocket.go#L64-L65

Added lines #L64 - L65 were not covered by tests
case <-ctx.Done():
// in case websocket connection is closed and server is not in shutdown mode
// we need to release this goroutine from waiting
}
kirugan marked this conversation as resolved.
Show resolved Hide resolved
}()

wsc := newWebsocketConn(ctx, conn, ws.connParams)

for {
_, wsc.r, err = wsc.conn.Reader(wsc.ctx)
Expand Down
2 changes: 1 addition & 1 deletion jsonrpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func testConnection(t *testing.T, ctx context.Context, method jsonrpc.Method, li
require.NoError(t, rpc.RegisterMethods(method))

// Server
srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, utils.NewNopZapLogger()))
srv := httptest.NewServer(jsonrpc.NewWebsocket(rpc, nil, utils.NewNopZapLogger()))
kirugan marked this conversation as resolved.
Show resolved Hide resolved

// Client
conn, resp, err := websocket.Dial(ctx, srv.URL, nil) //nolint:bodyclose // websocket package closes resp.Body for us.
Expand Down
15 changes: 13 additions & 2 deletions node/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@
}
}

func (h *httpService) registerOnShutdown(f func()) {
h.srv.RegisterOnShutdown(f)
}

func makeHTTPService(host string, port uint16, handler http.Handler) *httpService {
portStr := strconv.FormatUint(uint64(port), 10)
return &httpService{
Expand Down Expand Up @@ -108,9 +112,11 @@
listener = makeWSMetrics()
}

shutdown := make(chan struct{})

mux := http.NewServeMux()
for path, server := range servers {
wsHandler := jsonrpc.NewWebsocket(server, log)
wsHandler := jsonrpc.NewWebsocket(server, shutdown, log)
if listener != nil {
wsHandler = wsHandler.WithListener(listener)
}
Expand All @@ -124,7 +130,12 @@
if corsEnabled {
handler = cors.Default().Handler(handler)
}
return makeHTTPService(host, port, handler)

httpServ := makeHTTPService(host, port, handler)
httpServ.registerOnShutdown(func() {
close(shutdown)
})

Check warning on line 137 in node/http.go

View check run for this annotation

Codecov / codecov/patch

node/http.go#L136-L137

Added lines #L136 - L137 were not covered by tests
return httpServ
}

func makeMetrics(host string, port uint16) *httpService {
Expand Down
2 changes: 1 addition & 1 deletion rpc/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func TestMultipleSubscribeNewHeadsAndUnsubscribe(t *testing.T) {
Params: []jsonrpc.Parameter{{Name: "id"}},
Handler: handler.Unsubscribe,
}))
ws := jsonrpc.NewWebsocket(server, log)
ws := jsonrpc.NewWebsocket(server, nil, log)
kirugan marked this conversation as resolved.
Show resolved Hide resolved
httpSrv := httptest.NewServer(ws)
conn1, _, err := websocket.Dial(ctx, httpSrv.URL, nil)
require.NoError(t, err)
Expand Down
1 change: 0 additions & 1 deletion rpc/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys
case <-subscriptionCtx.Done():
return
case header := <-headerSub.Recv():

h.processEvents(subscriptionCtx, w, id, header.Number, header.Number, fromAddr, keys)
}
}
Expand Down
Loading