diff --git a/rpc/events.go b/rpc/events.go index 002c0e077..85ee68f81 100644 --- a/rpc/events.go +++ b/rpc/events.go @@ -44,6 +44,10 @@ type EventsChunk struct { ContinuationToken string `json:"continuation_token,omitempty"` } +type SubscriptionID struct { + ID uint64 `json:"subscription_id"` +} + /**************************************************** Events Handlers *****************************************************/ diff --git a/rpc/handlers.go b/rpc/handlers.go index 4d4d35b50..3f07b9701 100644 --- a/rpc/handlers.go +++ b/rpc/handlers.go @@ -65,12 +65,14 @@ var ( ErrUnsupportedTxVersion = &jsonrpc.Error{Code: 61, Message: "the transaction version is not supported"} ErrUnsupportedContractClassVersion = &jsonrpc.Error{Code: 62, Message: "the contract class version is not supported"} ErrUnexpectedError = &jsonrpc.Error{Code: 63, Message: "An unexpected error occurred"} + ErrTooManyBlocksBack = &jsonrpc.Error{Code: 68, Message: "Cannot go back more than 1024 blocks"} // These errors can be only be returned by Juno-specific methods. ErrSubscriptionNotFound = &jsonrpc.Error{Code: 100, Message: "Subscription not found"} ) const ( + maxBlocksBack = 1024 maxEventChunkSize = 10240 maxEventFilterKeys = 1024 traceCacheSize = 128 @@ -334,6 +336,11 @@ func (h *Handler) Methods() ([]jsonrpc.Method, string) { //nolint: funlen Name: "starknet_specVersion", Handler: h.SpecVersion, }, + { + Name: "starknet_subscribeEvents", + Params: []jsonrpc.Parameter{{Name: "from_address"}, {Name: "keys"}, {Name: "block", Optional: true}}, + Handler: h.SubscribeEvents, + }, { Name: "juno_subscribeNewHeads", Handler: h.SubscribeNewHeads, diff --git a/rpc/subscriptions.go b/rpc/subscriptions.go new file mode 100644 index 000000000..2c9fbdf5a --- /dev/null +++ b/rpc/subscriptions.go @@ -0,0 +1,179 @@ +package rpc + +import ( + "context" + "encoding/json" + "sync" + + "github.com/NethermindEth/juno/blockchain" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/jsonrpc" +) + +const subscribeEventsChunkSize = 1024 + +func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys [][]felt.Felt, + blockID *BlockID, +) (*SubscriptionID, *jsonrpc.Error) { + w, ok := jsonrpc.ConnFromContext(ctx) + if !ok { + return nil, jsonrpc.Err(jsonrpc.MethodNotFound, nil) + } + + lenKeys := len(keys) + for _, k := range keys { + lenKeys += len(k) + } + if lenKeys > maxEventFilterKeys { + return nil, ErrTooManyKeysInFilter + } + + var requestedHeader *core.Header + headHeader, err := h.bcReader.HeadsHeader() + if err != nil { + return nil, ErrInternal.CloneWithData(err.Error()) + } + + if blockID == nil { + requestedHeader = headHeader + } else { + var rpcErr *jsonrpc.Error + requestedHeader, rpcErr = h.blockHeaderByID(blockID) + if rpcErr != nil { + return nil, rpcErr + } + + // Todo: should the pending block be included in the head count? + if headHeader.Number >= maxBlocksBack && requestedHeader.Number <= headHeader.Number-maxBlocksBack { + return nil, ErrTooManyBlocksBack + } + } + + id := h.idgen() + subscriptionCtx, subscriptionCtxCancel := context.WithCancel(ctx) + sub := &subscription{ + cancel: subscriptionCtxCancel, + conn: w, + } + h.mu.Lock() + h.subscriptions[id] = sub + h.mu.Unlock() + + headerSub := h.newHeads.Subscribe() + sub.wg.Go(func() { + defer func() { + h.unsubscribe(sub, id) + headerSub.Unsubscribe() + }() + + // The specification doesn't enforce ordering of events therefore events from new blocks can be sent before + // old blocks. + // Todo: see if sub's wg can be used? + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + defer wg.Done() + + for { + select { + case <-subscriptionCtx.Done(): + return + case header := <-headerSub.Recv(): + h.processEvents(subscriptionCtx, w, id, header.Number, header.Number, fromAddr, keys) + } + } + }() + + h.processEvents(subscriptionCtx, w, id, requestedHeader.Number, headHeader.Number, fromAddr, keys) + + wg.Wait() + }) + + return &SubscriptionID{ID: id}, nil +} + +func (h *Handler) processEvents(ctx context.Context, w jsonrpc.Conn, id, from, to uint64, fromAddr *felt.Felt, keys [][]felt.Felt) { + filter, err := h.bcReader.EventFilter(fromAddr, keys) + if err != nil { + h.log.Warnw("Error creating event filter", "err", err) + return + } + defer h.callAndLogErr(filter.Close, "Error closing event filter in events subscription") + + if err = setEventFilterRange(filter, &BlockID{Number: from}, &BlockID{Number: to}, to); err != nil { + h.log.Warnw("Error setting event filter range", "err", err) + return + } + + var cToken *blockchain.ContinuationToken + filteredEvents, cToken, err := filter.Events(cToken, subscribeEventsChunkSize) + if err != nil { + h.log.Warnw("Error filtering events", "err", err) + return + } + + err = sendEvents(ctx, w, filteredEvents, id) + if err != nil { + h.log.Warnw("Error sending events", "err", err) + return + } + + for cToken != nil { + filteredEvents, cToken, err = filter.Events(cToken, subscribeEventsChunkSize) + if err != nil { + h.log.Warnw("Error filtering events", "err", err) + return + } + + err = sendEvents(ctx, w, filteredEvents, id) + if err != nil { + h.log.Warnw("Error sending events", "err", err) + return + } + } +} + +func sendEvents(ctx context.Context, w jsonrpc.Conn, events []*blockchain.FilteredEvent, id uint64) error { + for _, event := range events { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Pending block doesn't have a number + var blockNumber *uint64 + if event.BlockHash != nil { + blockNumber = &(event.BlockNumber) + } + emittedEvent := &EmittedEvent{ + BlockNumber: blockNumber, + BlockHash: event.BlockHash, + TransactionHash: event.TransactionHash, + Event: &Event{ + From: event.From, + Keys: event.Keys, + Data: event.Data, + }, + } + + resp, err := json.Marshal(jsonrpc.Request{ + Version: "2.0", + Method: "starknet_subscriptionEvents", + Params: map[string]any{ + "subscription_id": id, + "result": emittedEvent, + }, + }) + if err != nil { + return err + } + + _, err = w.Write(resp) + if err != nil { + return err + } + } + } + return nil +} diff --git a/rpc/subscriptions_test.go b/rpc/subscriptions_test.go new file mode 100644 index 000000000..d20a93ad6 --- /dev/null +++ b/rpc/subscriptions_test.go @@ -0,0 +1,135 @@ +package rpc_test + +import ( + "context" + "fmt" + "net" + "testing" + + "github.com/NethermindEth/juno/clients/feeder" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/jsonrpc" + "github.com/NethermindEth/juno/mocks" + "github.com/NethermindEth/juno/rpc" + adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestSubscribeEventsAndUnsubscribe(t *testing.T) { + log := utils.NewNopZapLogger() + + t.Run("Too many keys in filter", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockChain := mocks.NewMockReader(mockCtrl) + mockSyncer := mocks.NewMockSyncReader(mockCtrl) + handler := rpc.New(mockChain, mockSyncer, nil, "", log) + + keys := make([][]felt.Felt, 1024+1) + fromAddr := new(felt.Felt).SetBytes([]byte("from_address")) + + serverConn, clientConn := net.Pipe() + t.Cleanup(func() { + require.NoError(t, serverConn.Close()) + require.NoError(t, clientConn.Close()) + }) + + subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) + + id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, nil) + assert.Zero(t, id) + assert.Equal(t, rpc.ErrTooManyKeysInFilter, rpcErr) + }) + + t.Run("Too many blocks back", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockChain := mocks.NewMockReader(mockCtrl) + mockSyncer := mocks.NewMockSyncReader(mockCtrl) + handler := rpc.New(mockChain, mockSyncer, nil, "", log) + + keys := make([][]felt.Felt, 1) + fromAddr := new(felt.Felt).SetBytes([]byte("from_address")) + blockID := &rpc.BlockID{Number: 0} + + serverConn, clientConn := net.Pipe() + t.Cleanup(func() { + require.NoError(t, serverConn.Close()) + require.NoError(t, clientConn.Close()) + }) + + subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) + + // Note the end of the window doesn't need to be tested because if requested block number is more than the + // head, a block not found error will be returned. This behaviour has been tested in various other test, and we + // don't need to test it here again. + t.Run("head is 1024", func(t *testing.T) { + mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 1024}, nil) + mockChain.EXPECT().BlockHeaderByNumber(blockID.Number).Return(&core.Header{Number: 0}, nil) + + id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID) + assert.Zero(t, id) + assert.Equal(t, rpc.ErrTooManyBlocksBack, rpcErr) + }) + + t.Run("head is more than 1024", func(t *testing.T) { + mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 2024}, nil) + mockChain.EXPECT().BlockHeaderByNumber(blockID.Number).Return(&core.Header{Number: 0}, nil) + + id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID) + assert.Zero(t, id) + assert.Equal(t, rpc.ErrTooManyBlocksBack, rpcErr) + }) + }) + + t.Run("Events from old blocks and new", func(t *testing.T) { + n := utils.Ptr(utils.Sepolia) + client := feeder.NewTestClient(t, n) + gw := adaptfeeder.New(client) + + b1, err := gw.BlockByNumber(context.Background(), 56377) + require.NoError(t, err) + + // Make a shallow copy of b1 into b2 and b3. Then modify them accordingly. + b2, b3 := new(core.Block), new(core.Block) + b2.Header, b3.Header = new(core.Header), new(core.Header) + *b2.Header, *b3.Header = *b1.Header, *b1.Header + b2.Number = b1.Number + 1 + b3.Number = b2.Number + 1 + fmt.Println(b1.Number, b2.Number, b3.Number) + + serverConn, clientConn := net.Pipe() + t.Cleanup(func() { + require.NoError(t, serverConn.Close()) + require.NoError(t, clientConn.Close()) + }) + + subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) + fromAddr := b1.Receipts[0].Events[0].From + keys := make([][]felt.Felt, 1) + for _, k := range b1.Receipts[0].Events[0].Keys { + keys[0] = append(keys[0], *k) + } + + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockChain := mocks.NewMockReader(mockCtrl) + mockSyncer := mocks.NewMockSyncReader(mockCtrl) + handler := rpc.New(mockChain, mockSyncer, nil, "", log) + + mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: b2.Number}, nil) + mockChain.EXPECT().BlockHeaderByNumber(b1.Number).Return(b1.Header, nil) + + _, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &rpc.BlockID{Number: b1.Number}) + require.Nil(t, rpcErr) + + // Check from the conn that the correct id has been passed + }) +}