diff --git a/.changeset/moody-rules-agree.md b/.changeset/moody-rules-agree.md new file mode 100644 index 00000000000..ef1f3bcaf62 --- /dev/null +++ b/.changeset/moody-rules-agree.md @@ -0,0 +1,8 @@ +--- +"chainlink": patch +--- + +- register polling subscription to avoid subscription leaking when rpc client gets closed. +- add a temporary special treatment for SubscribeNewHead before we replace it with SubscribeToHeads. Add a goroutine that forwards new head from poller to caller channel. +- fix a deadlock in poller, by using a new lock for subs slice in rpc client. +#bugfix diff --git a/core/chains/evm/client/rpc_client.go b/core/chains/evm/client/rpc_client.go index 763348173aa..a29ed5e118c 100644 --- a/core/chains/evm/client/rpc_client.go +++ b/core/chains/evm/client/rpc_client.go @@ -129,7 +129,8 @@ type rpcClient struct { ws rawclient http *rawclient - stateMu sync.RWMutex // protects state* fields + stateMu sync.RWMutex // protects state* fields + subsSliceMu sync.RWMutex // protects subscription slice // Need to track subscriptions because closing the RPC does not (always?) // close the underlying subscription @@ -317,8 +318,8 @@ func (r *rpcClient) getRPCDomain() string { // registerSub adds the sub to the rpcClient list func (r *rpcClient) registerSub(sub ethereum.Subscription, stopInFLightCh chan struct{}) error { - r.stateMu.Lock() - defer r.stateMu.Unlock() + r.subsSliceMu.Lock() + defer r.subsSliceMu.Unlock() // ensure that the `sub` belongs to current life cycle of the `rpcClient` and it should not be killed due to // previous `DisconnectAll` call. select { @@ -335,12 +336,16 @@ func (r *rpcClient) registerSub(sub ethereum.Subscription, stopInFLightCh chan s // DisconnectAll disconnects all clients connected to the rpcClient func (r *rpcClient) DisconnectAll() { r.stateMu.Lock() - defer r.stateMu.Unlock() if r.ws.rpc != nil { r.ws.rpc.Close() } r.cancelInflightRequests() + r.stateMu.Unlock() + + r.subsSliceMu.Lock() r.unsubscribeAll() + r.subsSliceMu.Unlock() + r.chainInfoLock.Lock() r.latestChainInfo = commonclient.ChainInfo{} r.chainInfoLock.Unlock() @@ -496,11 +501,30 @@ func (r *rpcClient) SubscribeNewHead(ctx context.Context, channel chan<- *evmtyp if r.newHeadsPollInterval > 0 { interval := r.newHeadsPollInterval timeout := interval - poller, _ := commonclient.NewPoller[*evmtypes.Head](interval, r.latestBlock, timeout, r.rpcLog) + poller, pollerCh := commonclient.NewPoller[*evmtypes.Head](interval, r.latestBlock, timeout, r.rpcLog) if err = poller.Start(ctx); err != nil { return nil, err } + // NOTE this is a temporary special treatment for SubscribeNewHead before we refactor head tracker to use SubscribeToHeads + // as we need to forward new head from the poller channel to the channel passed from caller. + go func() { + for head := range pollerCh { + select { + case channel <- head: + // forwarding new head to the channel passed from caller + case <-poller.Err(): + // return as poller returns error + return + } + } + }() + + err = r.registerSub(&poller, chStopInFlight) + if err != nil { + return nil, err + } + lggr.Debugf("Polling new heads over http") return &poller, nil } @@ -547,6 +571,11 @@ func (r *rpcClient) SubscribeToHeads(ctx context.Context) (ch <-chan *evmtypes.H return nil, nil, err } + err = r.registerSub(&poller, chStopInFlight) + if err != nil { + return nil, nil, err + } + lggr.Debugf("Polling new heads over http") return channel, &poller, nil } @@ -579,6 +608,8 @@ func (r *rpcClient) SubscribeToHeads(ctx context.Context) (ch <-chan *evmtypes.H } func (r *rpcClient) SubscribeToFinalizedHeads(ctx context.Context) (<-chan *evmtypes.Head, commontypes.Subscription, error) { + ctx, cancel, chStopInFlight, _, _ := r.acquireQueryCtx(ctx, r.rpcTimeout) + defer cancel() interval := r.finalizedBlockPollInterval if interval == 0 { return nil, nil, errors.New("FinalizedBlockPollInterval is 0") @@ -588,6 +619,12 @@ func (r *rpcClient) SubscribeToFinalizedHeads(ctx context.Context) (<-chan *evmt if err := poller.Start(ctx); err != nil { return nil, nil, err } + + err := r.registerSub(&poller, chStopInFlight) + if err != nil { + return nil, nil, err + } + return channel, &poller, nil } diff --git a/core/chains/evm/client/rpc_client_test.go b/core/chains/evm/client/rpc_client_test.go index b594a0ca166..d959f8d1115 100644 --- a/core/chains/evm/client/rpc_client_test.go +++ b/core/chains/evm/client/rpc_client_test.go @@ -19,6 +19,8 @@ import ( "github.com/tidwall/gjson" "go.uber.org/zap" + commontypes "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -57,6 +59,25 @@ func TestRPCClient_SubscribeNewHead(t *testing.T) { } return } + + checkClosedRPCClientShouldRemoveExistingSub := func(t tests.TestingT, ctx context.Context, sub commontypes.Subscription, rpcClient client.RPCClient) { + errCh := sub.Err() + + // ensure sub exists + require.Equal(t, int32(1), rpcClient.SubscribersCount()) + rpcClient.DisconnectAll() + + // ensure sub is closed + select { + case <-errCh: // ok + default: + assert.Fail(t, "channel should be closed") + } + + require.NoError(t, rpcClient.Dial(ctx)) + require.Equal(t, int32(0), rpcClient.SubscribersCount()) + } + t.Run("Updates chain info on new blocks", func(t *testing.T) { server := testutils.NewWSServer(t, chainId, serverCallBack) wsURL := server.WSURL() @@ -131,6 +152,50 @@ func TestRPCClient_SubscribeNewHead(t *testing.T) { assert.Equal(t, int64(0), highestUserObservations.FinalizedBlockNumber) assert.Equal(t, (*big.Int)(nil), highestUserObservations.TotalDifficulty) }) + t.Run("SubscribeToHeads with http polling enabled will update new heads", func(t *testing.T) { + type rpcServer struct { + Head *evmtypes.Head + URL *url.URL + } + createRPCServer := func() *rpcServer { + server := &rpcServer{} + server.Head = &evmtypes.Head{Number: 127} + server.URL = testutils.NewWSServer(t, chainId, func(method string, params gjson.Result) (resp testutils.JSONRPCResponse) { + assert.Equal(t, "eth_getBlockByNumber", method) + if assert.True(t, params.IsArray()) && assert.Equal(t, "latest", params.Array()[0].String()) { + head := server.Head + jsonHead, err := json.Marshal(head) + if err != nil { + panic(fmt.Errorf("failed to marshal head: %w", err)) + } + resp.Result = string(jsonHead) + } + + return + }).WSURL() + return server + } + + server := createRPCServer() + rpc := client.NewRPCClient(lggr, *server.URL, nil, "rpc", 1, chainId, commonclient.Primary, 0, tests.TestInterval, commonclient.QueryTimeout, commonclient.QueryTimeout, "") + defer rpc.Close() + require.NoError(t, rpc.Dial(ctx)) + latest, highestUserObservations := rpc.GetInterceptedChainInfo() + // latest chain info hasn't been initialized + assert.Equal(t, int64(0), latest.BlockNumber) + assert.Equal(t, int64(0), highestUserObservations.BlockNumber) + + headCh, sub, err := rpc.SubscribeToHeads(commonclient.CtxAddHealthCheckFlag(tests.Context(t))) + require.NoError(t, err) + defer sub.Unsubscribe() + + head := <-headCh + assert.Equal(t, server.Head.Number, head.BlockNumber()) + // the http polling subscription should update the head block + latest, highestUserObservations = rpc.GetInterceptedChainInfo() + assert.Equal(t, server.Head.Number, latest.BlockNumber) + assert.Equal(t, server.Head.Number, highestUserObservations.BlockNumber) + }) t.Run("Concurrent Unsubscribe and onNewHead calls do not lead to a deadlock", func(t *testing.T) { const numberOfAttempts = 1000 // need a large number to increase the odds of reproducing the issue server := testutils.NewWSServer(t, chainId, serverCallBack) @@ -184,6 +249,68 @@ func TestRPCClient_SubscribeNewHead(t *testing.T) { require.ErrorContains(t, err, "RPCClient returned error (rpc)") tests.AssertLogEventually(t, observed, "evmclient.Client#EthSubscribe RPC call failure") }) + t.Run("Closed rpc client should remove existing SubscribeNewHead subscription with WS", func(t *testing.T) { + server := testutils.NewWSServer(t, chainId, serverCallBack) + wsURL := server.WSURL() + + rpc := client.NewRPCClient(lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary, 0, 0, commonclient.QueryTimeout, commonclient.QueryTimeout, "") + defer rpc.Close() + require.NoError(t, rpc.Dial(ctx)) + + ch := make(chan *evmtypes.Head) + sub, err := rpc.SubscribeNewHead(tests.Context(t), ch) + require.NoError(t, err) + checkClosedRPCClientShouldRemoveExistingSub(t, ctx, sub, rpc) + }) + t.Run("Closed rpc client should remove existing SubscribeNewHead subscription with HTTP polling", func(t *testing.T) { + server := testutils.NewWSServer(t, chainId, serverCallBack) + wsURL := server.WSURL() + + rpc := client.NewRPCClient(lggr, *wsURL, &url.URL{}, "rpc", 1, chainId, commonclient.Primary, 0, 1, commonclient.QueryTimeout, commonclient.QueryTimeout, "") + defer rpc.Close() + require.NoError(t, rpc.Dial(ctx)) + + ch := make(chan *evmtypes.Head) + sub, err := rpc.SubscribeNewHead(tests.Context(t), ch) + require.NoError(t, err) + checkClosedRPCClientShouldRemoveExistingSub(t, ctx, sub, rpc) + }) + t.Run("Closed rpc client should remove existing SubscribeToHeads subscription with WS", func(t *testing.T) { + server := testutils.NewWSServer(t, chainId, serverCallBack) + wsURL := server.WSURL() + + rpc := client.NewRPCClient(lggr, *wsURL, nil, "rpc", 1, chainId, commonclient.Primary, 0, 0, commonclient.QueryTimeout, commonclient.QueryTimeout, "") + defer rpc.Close() + require.NoError(t, rpc.Dial(ctx)) + + _, sub, err := rpc.SubscribeToHeads(tests.Context(t)) + require.NoError(t, err) + checkClosedRPCClientShouldRemoveExistingSub(t, ctx, sub, rpc) + }) + t.Run("Closed rpc client should remove existing SubscribeToHeads subscription with HTTP polling", func(t *testing.T) { + server := testutils.NewWSServer(t, chainId, serverCallBack) + wsURL := server.WSURL() + + rpc := client.NewRPCClient(lggr, *wsURL, &url.URL{}, "rpc", 1, chainId, commonclient.Primary, 0, 1, commonclient.QueryTimeout, commonclient.QueryTimeout, "") + defer rpc.Close() + require.NoError(t, rpc.Dial(ctx)) + + _, sub, err := rpc.SubscribeToHeads(tests.Context(t)) + require.NoError(t, err) + checkClosedRPCClientShouldRemoveExistingSub(t, ctx, sub, rpc) + }) + t.Run("Closed rpc client should remove existing SubscribeToFinalizedHeads subscription", func(t *testing.T) { + server := testutils.NewWSServer(t, chainId, serverCallBack) + wsURL := server.WSURL() + + rpc := client.NewRPCClient(lggr, *wsURL, &url.URL{}, "rpc", 1, chainId, commonclient.Primary, 1, 0, commonclient.QueryTimeout, commonclient.QueryTimeout, "") + defer rpc.Close() + require.NoError(t, rpc.Dial(ctx)) + + _, sub, err := rpc.SubscribeToFinalizedHeads(tests.Context(t)) + require.NoError(t, err) + checkClosedRPCClientShouldRemoveExistingSub(t, ctx, sub, rpc) + }) t.Run("Subscription error is properly wrapper", func(t *testing.T) { server := testutils.NewWSServer(t, chainId, serverCallBack) wsURL := server.WSURL()