From e942fcdcfe040e8fec6e94d0e0af424bda4f7879 Mon Sep 17 00:00:00 2001 From: Elad Gildnur <6321801+shleikes@users.noreply.github.com> Date: Tue, 6 Aug 2024 15:05:49 +0200 Subject: [PATCH] feat: PRT-1178: Subscriptions phase 1 (without handover) (#1462) * Enable the subscription again * Add GetParams() to GenericMessage struct * Fix a small tendermint bug * Make RelayProcessor skip some logic on subscription * Add subscription management in the provider - some TODOs are left * Fix the subscription category for supported chains * Remove unused functions * Small fix to the provider websocket manager * Small fixes to the provider * Move finalization consensus to a new package * Split the SendRelay into 2 functions: ParseRelay & SendParsedRelay * Add a utility function for readable hex * Remove the logic of disconnecting consumers at the end of the epoch * Fix to the safe channel sender * Use a utility function CreateHashFromParams * Add some trace logs * Rename IsSubscription -> IsSubscriptionCategory * Add LongLastingProvidersStorage * Add the cancellable context to subscription relays in the rpcconsumer_server * Add ConsumerWebsocketManager & ConsumerWSSubscriptionManager * Move the subscription closing logic to a new function * Rename to make the code make more sense * Implemented a 15-minute timeout for open subscriptions * Add the subscription function tags * Updated cosmos and ethereum specs with new subscription parse directives * Implemented unsubscribe logic in the consumer * Small log improvements * Rename CreateSubscriptionKey -> CreateDappKey * Use hashed params as map keys for subscription context * Move the websocket replies channel creation to the subscription manager * Rethink the storing of connected dapps and active subscriptions in the consumer, which now supports multiple subscriptions in one websocket * Some small fixes to the Unsubscribe function * Small rename * Implement unsubscribe_all for consumer * Typo fixes * Improved some logs * Add a missing lock * Make the websocket channel writing a little bit safer * Fix a small bug * Don't support unsubscribe_all in Relay flow * Split the TryRelay into smaller pieces * Add GetID to RPCInput interface * Add support for unsubscribe relay in the provider * Safe channel sender improvements * Use SafeChannelSender for the subscription in the consumer and fix bugs * Use formatter to preserve the original jsonrpc ID * Add some comments * Small fixes * Rename and fix AnalyzeWebSocketErrorAndWriteMessage * Add a comment * Remove done TODO * Some log improvements * Provider subscription manager fix * Move UpdateCU call to start of subscription instead of its end * Updated subscription CU cost * Fix finalization consensus test * Lint fixes * Add LongLastingProvidersStorage tests * Add SafeChannelSender tests * Add a 10-second timeout to handle a hang when waiting for the first message from a subscription * Move the first subscription reply verification into the rpcconsumer_server so we can better control the OnSessionFailure call * Handle bad provider signature better * Fix a small bug and remove redundant return value * Fix typos * Move ProviderNodeSubscriptionManager to chainlib * Fix to the UnsubscribeAll call * Create a const for the subscription relay timeout * Sync the consumer and provider on session failure * Change the type of replyServer so it is not a pointer to an interface * Small fix for the consumer subscription manager * Fix a small bug in the provider node subscription manager * Small fix to the consumer subscription manager * Remove epochs from provider subscription manager * Small fix to the safe channel sender * Tiny log fix * Allow websocket and http connectors in jsonRPC * Small log fix to connector.go * Add websocket server to mock chain lib * Add tests for provider and consumer subscription managers * Post merge fixes * Fix lint errors * Verify webSocket is up in the consumer in protocol integration tests * Revert "Allow websocket and http connectors in jsonRPC" This reverts commit 0a5d142969d0f574961936be2e00f57855847fdf. * WIP * change the parse directive to be saved on the chain message instead of being looked up from scratch every time * rename long lasting to active subscription provider storage * adding purge callback. * two consumers setup * fix a bug where 2 consumers wouldn't be able to subscribe. * fixing the subscription test, fixing a bug in the chain router, and managing the state properly. * fixing safe channel sender functionality * fixing provider subscription manager test. * fix a problem with the websocket listener in tests. * fixing test routine condition. * lint * fix ws issue on generic chain lib mocks * fix e2e for jsonrpc. * removing unused code. * fix typo * adding comments * comment fix * adding more comments. * removing spam trace logs * undoing WIP changes * adding seen block to rpc consumer's consistency. * handling same consumer subscription hash. * rename to a proper convention * fixing a case where two simultaneous subscriptions could trigger a race and a hanging lock. * improve readability * adding comments for better readability * insert sdk address inside the consumer container to avoid unnecessary unmarshalling * improve json marshalling by using gojson :) * remove unused code * change the json marshalling package * improve flow on rpcconsumer server. * improved encoding on rpc consumer server * fix a mistakenly forgotten unlock. * Small log fix * Also pass the ws connection to the provider in the setup_providers.sh script * Add websocket subscription test in e2e * fix mock tests * add replace channel method to safe channel sender * fix test * use replace channel instead of close * add more documentation * fix sub manager test * Fix lint * fix pending subscription race issue and add tests * Check for errors when writing to a file in e2e * Add websocket response to log * Allow more than 2 websockets in the e2e test for subscriptions * fix a nil deref on a race between reading and closing the connection * Remove log * add comments * remove log spam and make addon print better * fix ws message bug * Remove leftovers from committed WIP * Small code cleanup * Post merge fix * Small test fix * fix: PRT-1178: Subscription phase 1 unsubscribe fix (#1575) * Fix init_chain command for macOS * Fix the bug with unsubscribing from jsonrpc * Send an error message to the user when a subscription is not found * Updated the ethereum spec to match the fix * Delete the dappkey from connected dapps when empty * Test fix * Fix lint issues * Make the logs clearer to investigate the test failure on GH * Add some logs for the SDK to trace the E2E test failure * Attempt to fix the test by removing goroutines * Another attempt to fix the protocol tests * Small lint fix * merged * Adding a better search func for requirements * fix unused else * adding consumer guid for subscription requests from multiple consumer processes with the same key * consumer websocket manager fixes * fix 7 more comments * fixing more comments :) * set all id parsing in chain message.
* fixing more comments * merged v2 changes * merge conflict fixes * terminate connection in handleNewNodeMessage on error * fixing missing websocket id for unique dapp key for websocket subscriptions from the same user. * another review bites the dust * another one bites the dust * logs * fixing a bug in unsubscribe brackets * do not purge providers if they have more than one subscription * solving issue after issue. I love to call it, Hero mode. * solve the safe channel sender issue when sending a message and waiting the entire time until the consumer responds * remove log. * fix lint * lint v2 * lint v3 * lint v4 * fixing context issue reading headers on subscription. * nil deref fix * fix initialize redundancy * adding multiple unique id on same dapp etc.. * fixing pending race for more than one pending subscriptions * adding a test to validate the queue mechanism * fixed. --------- Co-authored-by: Ran Mishael Co-authored-by: Omer <100387053+omerlavanet@users.noreply.github.com> --- .github/workflows/lava.yml | 9 +- .../avalanch_internal_paths_example.yml | 26 +- cookbook/specs/ethereum.json | 21 +- cookbook/specs/fantom.json | 6 +- cookbook/specs/tendermint.json | 32 +- .../stateQuery/state_badge_query.ts | 4 +- .../stateQuery/state_chain_query.ts | 4 +- proto/lavanet/lava/spec/api_collection.proto | 3 + protocol/chainlib/base_chain_parser.go | 10 +- protocol/chainlib/chain_fetcher.go | 7 +- protocol/chainlib/chain_message.go | 109 +- protocol/chainlib/chain_message_queries.go | 27 +- protocol/chainlib/chain_router.go | 63 +- protocol/chainlib/chain_router_test.go | 496 +++++++++- protocol/chainlib/chainlib.go | 36 +- protocol/chainlib/chainlib_mock.go | 872 ++++++++++++++++ protocol/chainlib/chainproxy/common.go | 4 + protocol/chainlib/chainproxy/connector.go | 5 +- .../chainlib/chainproxy/connector_test.go | 44 +- .../chainproxy/rpcInterfaceMessages/common.go | 4 + .../rpcInterfaceMessages/grpcMessage.go | 10 + .../rpcInterfaceMessages/jsonRPCMessage.go | 38 +- .../rpcInterfaceMessages/restMessage.go | 25 +- .../tendermintRPCMessage.go | 22 +- .../chainlib/chainproxy/rpcclient/client.go | 5 + .../chainlib/chainproxy/rpcclient/utils.go | 9 + protocol/chainlib/common.go | 50 +- protocol/chainlib/common_test.go | 5 + protocol/chainlib/common_test_utils.go | 64 +- .../chainlib/consumer_websocket_manager.go | 273 +++++ .../consumer_ws_subscription_manager.go | 930 ++++++++++++++++++ .../consumer_ws_subscription_manager_test.go | 714 ++++++++++++++ protocol/chainlib/grpc.go | 1 + protocol/chainlib/grpc_test.go | 12 +- protocol/chainlib/jsonRPC.go | 211 ++-- protocol/chainlib/jsonRPC_test.go | 183 +++- .../provider_node_subscription_manager.go | 650 ++++++++++++ ...provider_node_subscription_manager_test.go | 441 +++++++++ protocol/chainlib/rest.go | 4 + protocol/chainlib/rest_test.go | 12 +- protocol/chainlib/tendermintRPC.go | 265 ++--- protocol/chainlib/tendermintRPC_test.go | 8 +- protocol/chaintracker/chain_tracker.go | 19 +- protocol/chaintracker/errors.go | 2 +- protocol/common/endpoints.go | 38 +- protocol/common/errors.go | 1 + protocol/common/return_errors.go | 11 + protocol/common/safe_channel_sender.go | 99 ++ protocol/common/safe_channel_sender_test.go | 41 + protocol/common/strings.go | 18 + protocol/common/timeout.go | 5 + protocol/integration/mocks.go | 4 +- protocol/integration/protocol_test.go | 160 ++- protocol/lavaprotocol/errors.go | 12 +- .../finalization_consensus.go | 9 +- .../finalization_consensus_test.go | 6 +- protocol/lavaprotocol/response_builder.go | 66 +- 
.../active_subscription_provider_storage.go | 71 ++ ...tive_subscription_provider_storage_test.go | 65 ++ .../lavasession/consumer_session_manager.go | 42 +- .../consumer_session_manager_test.go | 2 +- .../lavasession/provider_session_manager.go | 105 -- .../provider_session_manager_test.go | 188 ---- protocol/lavasession/used_providers.go | 1 + protocol/metrics/rpcconsumerlogs.go | 12 +- protocol/metrics/rpcconsumerlogs_test.go | 4 +- protocol/parser/parser.go | 1 + protocol/parser/parser_test.go | 4 + protocol/performance/cache.go | 4 +- protocol/performance/errors.go | 2 +- protocol/rpcconsumer/consumer_consistency.go | 18 +- protocol/rpcconsumer/relay_processor.go | 7 + protocol/rpcconsumer/relay_processor_test.go | 18 +- protocol/rpcconsumer/rpcconsumer.go | 27 +- protocol/rpcconsumer/rpcconsumer_server.go | 288 +++++- protocol/rpcprovider/provider_listener.go | 4 +- .../reliability_manager_test.go | 2 +- protocol/rpcprovider/rpcprovider.go | 15 +- protocol/rpcprovider/rpcprovider_server.go | 711 +++++++------ .../rpcprovider/rpcprovider_server_test.go | 2 +- .../statetracker/consumer_state_tracker.go | 4 +- .../finalization_consensus_updater.go | 8 +- scripts/init_chain.sh | 11 +- .../pre_setups/init_lava_only_with_node.sh | 6 +- .../init_lava_only_with_node_protocol_only.sh | 36 + .../init_lava_only_with_node_two_consumers.sh | 69 ++ scripts/setup_providers.sh | 6 +- .../e2eConfigs/provider/jsonrpcProvider1.yml | 8 +- .../e2eConfigs/provider/jsonrpcProvider2.yml | 1 + .../e2eConfigs/provider/jsonrpcProvider3.yml | 1 + .../e2eConfigs/provider/jsonrpcProvider4.yml | 1 + .../e2eConfigs/provider/jsonrpcProvider5.yml | 1 + testutil/e2e/protocolE2E.go | 185 +++- testutil/e2e/proxy/proxy.go | 55 +- utils/lavalog.go | 4 + x/pairing/types/relay_mock.pb.go | 410 ++++++++ x/spec/types/api_collection.pb.go | 9 + 97 files changed, 7326 insertions(+), 1287 deletions(-) create mode 100644 protocol/chainlib/chainlib_mock.go create mode 100644 protocol/chainlib/chainproxy/rpcclient/utils.go create mode 100644 protocol/chainlib/consumer_websocket_manager.go create mode 100644 protocol/chainlib/consumer_ws_subscription_manager.go create mode 100644 protocol/chainlib/consumer_ws_subscription_manager_test.go create mode 100644 protocol/chainlib/provider_node_subscription_manager.go create mode 100644 protocol/chainlib/provider_node_subscription_manager_test.go create mode 100644 protocol/common/safe_channel_sender.go create mode 100644 protocol/common/safe_channel_sender_test.go create mode 100644 protocol/common/strings.go rename protocol/lavaprotocol/{ => finalizationconsensus}/finalization_consensus.go (96%) rename protocol/lavaprotocol/{ => finalizationconsensus}/finalization_consensus_test.go (99%) create mode 100644 protocol/lavasession/active_subscription_provider_storage.go create mode 100644 protocol/lavasession/active_subscription_provider_storage_test.go create mode 100755 scripts/pre_setups/init_lava_only_with_node_protocol_only.sh create mode 100755 scripts/pre_setups/init_lava_only_with_node_two_consumers.sh create mode 100644 x/pairing/types/relay_mock.pb.go diff --git a/.github/workflows/lava.yml b/.github/workflows/lava.yml index 996f8f178f..5b0dad8e36 100644 --- a/.github/workflows/lava.yml +++ b/.github/workflows/lava.yml @@ -369,7 +369,14 @@ jobs: report-tests-results: runs-on: ubuntu-latest - needs: [test-consensus, test-protocol, test-protocol-e2e, test-payment-e2e] # test-sdk-e2e, + needs: + [ + test-consensus, + test-protocol, + test-protocol-e2e, + # test-sdk-e2e, + 
test-payment-e2e, + ] if: always() steps: - name: Download Artifacts diff --git a/config/provider_examples/avalanch_internal_paths_example.yml b/config/provider_examples/avalanch_internal_paths_example.yml index 9d6fc74509..bf69abb6ad 100644 --- a/config/provider_examples/avalanch_internal_paths_example.yml +++ b/config/provider_examples/avalanch_internal_paths_example.yml @@ -1,14 +1,14 @@ -# this example show cases how you can setup Avalanche +# this example show cases how you can setup Avalanche endpoints: - - api-interface: jsonrpc - chain-id: AVAX - network-address: 127.0.0.1:2221 - node-urls: - - url: ws://127.0.0.1:3333/C/rpc/ws - internal-path: "/C/rpc" # c chain like specified in the spec - - url: https://127.0.0.1:3334/C/avax - internal-path: "/C/avax" # c/avax like specified in the spec - - url: https://127.0.0.1:3335/X - internal-path: "/X" # x chain like specified in the spec - - url: https://127.0.0.1:3336/P - internal-path: "/P" # p chain like specified in the spec \ No newline at end of file + - api-interface: jsonrpc + chain-id: AVAX + network-address: 127.0.0.1:2221 + node-urls: + - url: ws://127.0.0.1:3333/C/rpc/ws + internal-path: "/C/rpc" # c chain like specified in the spec + - url: https://127.0.0.1:3334/C/avax + internal-path: "/C/avax" # c/avax like specified in the spec + - url: https://127.0.0.1:3335/X + internal-path: "/X" # x chain like specified in the spec + - url: https://127.0.0.1:3336/P + internal-path: "/P" # p chain like specified in the spec diff --git a/cookbook/specs/ethereum.json b/cookbook/specs/ethereum.json index f3cddb7c6d..c16b30a64a 100644 --- a/cookbook/specs/ethereum.json +++ b/cookbook/specs/ethereum.json @@ -384,7 +384,7 @@ "category": { "deterministic": false, "local": true, - "subscription": true, + "subscription": false, "stateful": 0 }, "extra_compute_units": 0 @@ -833,7 +833,7 @@ ], "parser_func": "DEFAULT" }, - "compute_units": 10, + "compute_units": 1000, "enabled": true, "category": { "deterministic": false, @@ -874,7 +874,7 @@ "category": { "deterministic": false, "local": true, - "subscription": true, + "subscription": false, "stateful": 0 }, "extra_compute_units": 0 @@ -883,16 +883,16 @@ "name": "eth_unsubscribe", "block_parsing": { "parser_arg": [ - "" + "latest" ], - "parser_func": "EMPTY" + "parser_func": "DEFAULT" }, "compute_units": 10, "enabled": true, "category": { "deterministic": false, "local": true, - "subscription": false, + "subscription": true, "stateful": 0 }, "extra_compute_units": 0 @@ -1045,6 +1045,15 @@ "encoding": "hex" }, "api_name": "eth_getBlockByNumber" + }, + { + "function_tag": "SUBSCRIBE", + "api_name": "eth_subscribe" + }, + { + "function_template": "{\"jsonrpc\":\"2.0\",\"method\":\"eth_unsubscribe\",\"params\":[\"%s\"],\"id\":1}", + "function_tag": "UNSUBSCRIBE", + "api_name": "eth_unsubscribe" } ], "verifications": [ diff --git a/cookbook/specs/fantom.json b/cookbook/specs/fantom.json index 286e9e81d0..7e14dbad37 100644 --- a/cookbook/specs/fantom.json +++ b/cookbook/specs/fantom.json @@ -98,7 +98,7 @@ "category": { "deterministic": false, "local": true, - "subscription": false, + "subscription": true, "stateful": 0 }, "extra_compute_units": 0 @@ -387,7 +387,7 @@ "category": { "deterministic": false, "local": true, - "subscription": true, + "subscription": false, "stateful": 0 }, "extra_compute_units": 0 @@ -477,7 +477,7 @@ "category": { "deterministic": false, "local": true, - "subscription": true, + "subscription": false, "stateful": 0 }, "extra_compute_units": 0 diff --git 
a/cookbook/specs/tendermint.json b/cookbook/specs/tendermint.json index 3886cb9828..3472a379f8 100644 --- a/cookbook/specs/tendermint.json +++ b/cookbook/specs/tendermint.json @@ -449,11 +449,11 @@ "name": "subscribe", "block_parsing": { "parser_arg": [ - "" + "latest" ], - "parser_func": "EMPTY" + "parser_func": "DEFAULT" }, - "compute_units": 10, + "compute_units": 1000, "enabled": true, "category": { "deterministic": false, @@ -521,16 +521,16 @@ "name": "unsubscribe", "block_parsing": { "parser_arg": [ - "" + "latest" ], - "parser_func": "EMPTY" + "parser_func": "DEFAULT" }, "compute_units": 10, "enabled": true, "category": { "deterministic": false, "local": true, - "subscription": false, + "subscription": true, "stateful": 0 }, "extra_compute_units": 0 @@ -539,16 +539,16 @@ "name": "unsubscribe_all", "block_parsing": { "parser_arg": [ - "" + "latest" ], - "parser_func": "EMPTY" + "parser_func": "DEFAULT" }, "compute_units": 10, "enabled": true, "category": { "deterministic": false, "local": true, - "subscription": false, + "subscription": true, "stateful": 0 }, "extra_compute_units": 0 @@ -618,6 +618,20 @@ "encoding": "base64" }, "api_name": "earliest_block" + }, + { + "function_tag": "SUBSCRIBE", + "api_name": "subscribe" + }, + { + "function_template": "{\"jsonrpc\":\"2.0\",\"method\":\"unsubscribe\",\"params\":%s,\"id\":1}", + "function_tag": "UNSUBSCRIBE", + "api_name": "unsubscribe" + }, + { + "function_template": "{\"jsonrpc\":\"2.0\",\"method\":\"unsubscribe_all\",\"params\":[],\"id\":1}", + "function_tag": "UNSUBSCRIBE_ALL", + "api_name": "unsubscribe_all" } ], "verifications": [ diff --git a/ecosystem/lava-sdk/src/stateTracker/stateQuery/state_badge_query.ts b/ecosystem/lava-sdk/src/stateTracker/stateQuery/state_badge_query.ts index dd818005fc..564150eec9 100644 --- a/ecosystem/lava-sdk/src/stateTracker/stateQuery/state_badge_query.ts +++ b/ecosystem/lava-sdk/src/stateTracker/stateQuery/state_badge_query.ts @@ -43,7 +43,7 @@ export class StateBadgeQuery { // fetchPairing fetches pairing for all chainIDs we support public async fetchPairing(): Promise { - Logger.debug("Fetching pairing started"); + Logger.debug("Fetching pairing from badge started"); let timeLeftToNextPairing; let virtualEpoch; @@ -110,7 +110,7 @@ export class StateBadgeQuery { this.virtualEpoch = virtualEpoch; this.currentEpoch = currentEpoch; - Logger.debug("Fetching pairing ended"); + Logger.debug("Fetching pairing from badge ended", timeLeftToNextPairing); return timeLeftToNextPairing; } diff --git a/ecosystem/lava-sdk/src/stateTracker/stateQuery/state_chain_query.ts b/ecosystem/lava-sdk/src/stateTracker/stateQuery/state_chain_query.ts index 5cf6247041..0b3601f12d 100644 --- a/ecosystem/lava-sdk/src/stateTracker/stateQuery/state_chain_query.ts +++ b/ecosystem/lava-sdk/src/stateTracker/stateQuery/state_chain_query.ts @@ -81,7 +81,7 @@ export class StateChainQuery { // fetchPairing fetches pairing for all chainIDs we support public async fetchPairing(): Promise { try { - Logger.debug("Fetching pairing started"); + Logger.debug("Fetching pairing from chain started"); // Save time till next epoch let timeLeftToNextPairing; let currentEpoch; @@ -154,7 +154,7 @@ export class StateChainQuery { this.currentEpoch = currentEpoch; this.downtimeParams = downtimeParams; - Logger.debug("Fetching pairing ended"); + Logger.debug("Fetching pairing from chain ended", timeLeftToNextPairing); // Return timeLeftToNextPairing return timeLeftToNextPairing; diff --git a/proto/lavanet/lava/spec/api_collection.proto 
b/proto/lavanet/lava/spec/api_collection.proto index 9b60c97532..b872121e3a 100644 --- a/proto/lavanet/lava/spec/api_collection.proto +++ b/proto/lavanet/lava/spec/api_collection.proto @@ -111,6 +111,9 @@ enum FUNCTION_TAG { SET_LATEST_IN_BODY = 4; VERIFICATION = 5; GET_EARLIEST_BLOCK = 6; + SUBSCRIBE = 7; + UNSUBSCRIBE = 8; + UNSUBSCRIBE_ALL = 9; } enum PARSER_TYPE { diff --git a/protocol/chainlib/base_chain_parser.go b/protocol/chainlib/base_chain_parser.go index 7a703109dc..f94dfe9b49 100644 --- a/protocol/chainlib/base_chain_parser.go +++ b/protocol/chainlib/base_chain_parser.go @@ -189,12 +189,14 @@ func (bcp *BaseChainParser) SeparateAddonsExtensions(supported []string) (addons if supportedToCheck == "" { continue } - if bcp.isExtension(supportedToCheck) { + if bcp.isExtension(supportedToCheck) || supportedToCheck == WebSocketExtension { extensions = append(extensions, supportedToCheck) continue } // neither is an error - err = utils.LavaFormatError("invalid supported to check, is neither an addon or an extension", err, utils.Attribute{Key: "spec", Value: bcp.spec.Index}, utils.Attribute{Key: "supported", Value: supportedToCheck}) + err = utils.LavaFormatError("invalid supported to check, is neither an addon or an extension", err, + utils.Attribute{Key: "spec", Value: bcp.spec.Index}, + utils.Attribute{Key: "supported", Value: supportedToCheck}) } } return addons, extensions, err @@ -252,7 +254,7 @@ func (bcp *BaseChainParser) Construct(spec spectypes.Spec, internalPaths map[str bcp.extensionParser.SetConfiguredExtensions(extensionParser.GetConfiguredExtensions()) } -func (bcp *BaseChainParser) GetParsingByTag(tag spectypes.FUNCTION_TAG) (parsing *spectypes.ParseDirective, collectionData *spectypes.CollectionData, existed bool) { +func (bcp *BaseChainParser) GetParsingByTag(tag spectypes.FUNCTION_TAG) (parsing *spectypes.ParseDirective, apiCollection *spectypes.ApiCollection, existed bool) { bcp.rwLock.RLock() defer bcp.rwLock.RUnlock() @@ -260,7 +262,7 @@ func (bcp *BaseChainParser) GetParsingByTag(tag spectypes.FUNCTION_TAG) (parsing if !ok { return nil, nil, false } - return val.Parsing, &val.ApiCollection.CollectionData, ok + return val.Parsing, val.ApiCollection, ok } func (bcp *BaseChainParser) ExtensionParsing(addon string, parsedMessageArg *baseChainMessageContainer, extensionInfo extensionslib.ExtensionInfo) { diff --git a/protocol/chainlib/chain_fetcher.go b/protocol/chainlib/chain_fetcher.go index b72a9a30fb..8decdba7e7 100644 --- a/protocol/chainlib/chain_fetcher.go +++ b/protocol/chainlib/chain_fetcher.go @@ -262,11 +262,12 @@ func (cf *ChainFetcher) ChainFetcherMetadata() []pairingtypes.Metadata { } func (cf *ChainFetcher) FetchLatestBlockNum(ctx context.Context) (int64, error) { - parsing, collectionData, ok := cf.chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) + parsing, apiCollection, ok := cf.chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) tagName := spectypes.FUNCTION_TAG_GET_BLOCKNUM.String() if !ok { return spectypes.NOT_APPLICABLE, utils.LavaFormatError(tagName+" tag function not found", nil, []utils.Attribute{{Key: "chainID", Value: cf.endpoint.ChainID}, {Key: "APIInterface", Value: cf.endpoint.ApiInterface}}...) 
} + collectionData := apiCollection.CollectionData var craftData *CraftData if parsing.FunctionTemplate != "" { path := parsing.ApiName @@ -321,11 +322,13 @@ func (cf *ChainFetcher) constructRelayData(conectionType string, path string, da } func (cf *ChainFetcher) FetchBlockHashByNum(ctx context.Context, blockNum int64) (string, error) { - parsing, collectionData, ok := cf.chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCK_BY_NUM) + parsing, apiCollection, ok := cf.chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCK_BY_NUM) tagName := spectypes.FUNCTION_TAG_GET_BLOCK_BY_NUM.String() if !ok { return "", utils.LavaFormatError(tagName+" tag function not found", nil, []utils.Attribute{{Key: "chainID", Value: cf.endpoint.ChainID}, {Key: "APIInterface", Value: cf.endpoint.ApiInterface}}...) } + collectionData := apiCollection.CollectionData + if parsing.FunctionTemplate == "" { return "", utils.LavaFormatError(tagName+" missing function template", nil, []utils.Attribute{{Key: "chainID", Value: cf.endpoint.ChainID}, {Key: "APIInterface", Value: cf.endpoint.ApiInterface}}...) } diff --git a/protocol/chainlib/chain_message.go b/protocol/chainlib/chain_message.go index e87049d4ca..f3d915743b 100644 --- a/protocol/chainlib/chain_message.go +++ b/protocol/chainlib/chain_message.go @@ -5,6 +5,7 @@ import ( "time" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" + "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" "github.com/lavanet/lava/v2/utils" pairingtypes "github.com/lavanet/lava/v2/x/pairing/types" @@ -15,6 +16,7 @@ type updatableRPCInput interface { rpcInterfaceMessages.GenericMessage UpdateLatestBlockInMessage(latestBlock uint64, modifyContent bool) (success bool) AppendHeader(metadata []pairingtypes.Metadata) + SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string GetRawRequestHash() ([]byte, error) } @@ -27,12 +29,23 @@ type baseChainMessageContainer struct { extensions []*spectypes.Extension timeoutOverride time.Duration forceCacheRefresh bool - inputHashCache []byte + parseDirective *spectypes.ParseDirective // setting the parse directive related to the api, can be nil + + inputHashCache []byte // resultErrorParsingMethod passed by each api interface message to parse the result of the message // and validate it doesn't contain a node error resultErrorParsingMethod func(data []byte, httpStatusCode int) (hasError bool, errorMessage string) } +func (bcnc *baseChainMessageContainer) SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string { + return bcnc.msg.SubscriptionIdExtractor(reply) +} + +// returning parse directive for the api. can be nil. +func (bcnc *baseChainMessageContainer) GetParseDirective() *spectypes.ParseDirective { + return bcnc.parseDirective +} + func (pm *baseChainMessageContainer) GetRawRequestHash() ([]byte, error) { if pm.inputHashCache != nil && len(pm.inputHashCache) > 0 { // Get the cached value @@ -47,79 +60,79 @@ func (pm *baseChainMessageContainer) GetRawRequestHash() ([]byte, error) { } // not necessary for base chain message. 
-func (pm *baseChainMessageContainer) CheckResponseError(data []byte, httpStatusCode int) (hasError bool, errorMessage string) { - if pm.resultErrorParsingMethod == nil { +func (bcnc *baseChainMessageContainer) CheckResponseError(data []byte, httpStatusCode int) (hasError bool, errorMessage string) { + if bcnc.resultErrorParsingMethod == nil { utils.LavaFormatError("tried calling resultErrorParsingMethod when it is not set", nil) return false, "" } - return pm.resultErrorParsingMethod(data, httpStatusCode) + return bcnc.resultErrorParsingMethod(data, httpStatusCode) } -func (pm *baseChainMessageContainer) TimeoutOverride(override ...time.Duration) time.Duration { +func (bcnc *baseChainMessageContainer) TimeoutOverride(override ...time.Duration) time.Duration { if len(override) > 0 { - pm.timeoutOverride = override[0] + bcnc.timeoutOverride = override[0] } - return pm.timeoutOverride + return bcnc.timeoutOverride } -func (pm *baseChainMessageContainer) SetForceCacheRefresh(force bool) bool { - pm.forceCacheRefresh = force - return pm.forceCacheRefresh +func (bcnc *baseChainMessageContainer) SetForceCacheRefresh(force bool) bool { + bcnc.forceCacheRefresh = force + return bcnc.forceCacheRefresh } -func (pm *baseChainMessageContainer) GetForceCacheRefresh() bool { - return pm.forceCacheRefresh +func (bcnc *baseChainMessageContainer) GetForceCacheRefresh() bool { + return bcnc.forceCacheRefresh } -func (pm *baseChainMessageContainer) DisableErrorHandling() { - pm.msg.DisableErrorHandling() +func (bcnc *baseChainMessageContainer) DisableErrorHandling() { + bcnc.msg.DisableErrorHandling() } -func (pm baseChainMessageContainer) AppendHeader(metadata []pairingtypes.Metadata) { - pm.msg.AppendHeader(metadata) +func (bcnc baseChainMessageContainer) AppendHeader(metadata []pairingtypes.Metadata) { + bcnc.msg.AppendHeader(metadata) } -func (pm baseChainMessageContainer) GetApi() *spectypes.Api { - return pm.api +func (bcnc baseChainMessageContainer) GetApi() *spectypes.Api { + return bcnc.api } -func (pm baseChainMessageContainer) GetApiCollection() *spectypes.ApiCollection { - return pm.apiCollection +func (bcnc baseChainMessageContainer) GetApiCollection() *spectypes.ApiCollection { + return bcnc.apiCollection } -func (pm baseChainMessageContainer) RequestedBlock() (latest int64, earliest int64) { - if pm.earliestRequestedBlock == 0 { +func (bcnc baseChainMessageContainer) RequestedBlock() (latest int64, earliest int64) { + if bcnc.earliestRequestedBlock == 0 { // earliest is optional and not set here - return pm.latestRequestedBlock, pm.latestRequestedBlock + return bcnc.latestRequestedBlock, bcnc.latestRequestedBlock } - return pm.latestRequestedBlock, pm.earliestRequestedBlock + return bcnc.latestRequestedBlock, bcnc.earliestRequestedBlock } -func (pm baseChainMessageContainer) GetRPCMessage() rpcInterfaceMessages.GenericMessage { - return pm.msg +func (bcnc baseChainMessageContainer) GetRPCMessage() rpcInterfaceMessages.GenericMessage { + return bcnc.msg } -func (pm *baseChainMessageContainer) UpdateLatestBlockInMessage(latestBlock int64, modifyContent bool) (modifiedOnLatestReq bool) { - requestedBlock, _ := pm.RequestedBlock() +func (bcnc *baseChainMessageContainer) UpdateLatestBlockInMessage(latestBlock int64, modifyContent bool) (modifiedOnLatestReq bool) { + requestedBlock, _ := bcnc.RequestedBlock() if latestBlock <= spectypes.NOT_APPLICABLE || requestedBlock != spectypes.LATEST_BLOCK { return false } - success := pm.msg.UpdateLatestBlockInMessage(uint64(latestBlock), modifyContent) + 
success := bcnc.msg.UpdateLatestBlockInMessage(uint64(latestBlock), modifyContent) if success { - pm.latestRequestedBlock = latestBlock + bcnc.latestRequestedBlock = latestBlock return true } return false } -func (pm *baseChainMessageContainer) GetExtensions() []*spectypes.Extension { - return pm.extensions +func (bcnc *baseChainMessageContainer) GetExtensions() []*spectypes.Extension { + return bcnc.extensions } // adds the following extensions -func (pm *baseChainMessageContainer) OverrideExtensions(extensionNames []string, extensionParser *extensionslib.ExtensionParser) { +func (bcnc *baseChainMessageContainer) OverrideExtensions(extensionNames []string, extensionParser *extensionslib.ExtensionParser) { existingExtensions := map[string]struct{}{} - for _, extension := range pm.extensions { + for _, extension := range bcnc.extensions { existingExtensions[extension.Name] = struct{}{} } for _, extensionName := range extensionNames { @@ -127,38 +140,38 @@ func (pm *baseChainMessageContainer) OverrideExtensions(extensionNames []string, existingExtensions[extensionName] = struct{}{} extensionKey := extensionslib.ExtensionKey{ Extension: extensionName, - ConnectionType: pm.apiCollection.CollectionData.Type, - InternalPath: pm.apiCollection.CollectionData.InternalPath, - Addon: pm.apiCollection.CollectionData.AddOn, + ConnectionType: bcnc.apiCollection.CollectionData.Type, + InternalPath: bcnc.apiCollection.CollectionData.InternalPath, + Addon: bcnc.apiCollection.CollectionData.AddOn, } extension := extensionParser.GetExtension(extensionKey) if extension != nil { - pm.extensions = append(pm.extensions, extension) - pm.updateCUForApi(extension) + bcnc.extensions = append(bcnc.extensions, extension) + bcnc.updateCUForApi(extension) } } } } -func (pm *baseChainMessageContainer) SetExtension(extension *spectypes.Extension) { - if len(pm.extensions) > 0 { - for _, ext := range pm.extensions { +func (bcnc *baseChainMessageContainer) SetExtension(extension *spectypes.Extension) { + if len(bcnc.extensions) > 0 { + for _, ext := range bcnc.extensions { if ext.Name == extension.Name { // already existing, no need to add return } } - pm.extensions = append(pm.extensions, extension) + bcnc.extensions = append(bcnc.extensions, extension) } else { - pm.extensions = []*spectypes.Extension{extension} + bcnc.extensions = []*spectypes.Extension{extension} } - pm.updateCUForApi(extension) + bcnc.updateCUForApi(extension) } -func (pm *baseChainMessageContainer) updateCUForApi(extension *spectypes.Extension) { - copyApi := *pm.api // we can't modify this because it points to an object inside the chainParser +func (bcnc *baseChainMessageContainer) updateCUForApi(extension *spectypes.Extension) { + copyApi := *bcnc.api // we can't modify this because it points to an object inside the chainParser copyApi.ComputeUnits = uint64(math.Floor(float64(extension.GetCuMultiplier()) * float64(copyApi.ComputeUnits))) - pm.api = ©Api + bcnc.api = ©Api } type CraftData struct { diff --git a/protocol/chainlib/chain_message_queries.go b/protocol/chainlib/chain_message_queries.go index 389d80e9d7..479cf259b2 100644 --- a/protocol/chainlib/chain_message_queries.go +++ b/protocol/chainlib/chain_message_queries.go @@ -1,6 +1,9 @@ package chainlib -import "github.com/lavanet/lava/v2/protocol/common" +import ( + "github.com/lavanet/lava/v2/protocol/common" + "github.com/lavanet/lava/v2/x/spec/types" +) func ShouldSendToAllProviders(chainMessage ChainMessage) bool { return chainMessage.GetApi().Category.Stateful == 
common.CONSISTENCY_SELECT_ALL_PROVIDERS @@ -10,10 +13,6 @@ func GetAddon(chainMessage ChainMessageForSend) string { return chainMessage.GetApiCollection().CollectionData.AddOn } -func IsSubscription(chainMessage ChainMessageForSend) bool { - return chainMessage.GetApi().Category.Subscription -} - func IsHangingApi(chainMessage ChainMessageForSend) bool { return chainMessage.GetApi().Category.HangingApi } @@ -25,3 +24,21 @@ func GetComputeUnits(chainMessage ChainMessageForSend) uint64 { func GetStateful(chainMessage ChainMessageForSend) uint32 { return chainMessage.GetApi().Category.Stateful } + +func GetParseDirective(api *types.Api, apiCollection *types.ApiCollection) *types.ParseDirective { + chainMessageApiName := api.Name + for _, parseDirective := range apiCollection.GetParseDirectives() { + if parseDirective.ApiName == chainMessageApiName { + return parseDirective + } + } + return nil +} + +func IsFunctionTagOfType(chainMessage ChainMessageForSend, functionTag types.FUNCTION_TAG) bool { + parseDirective := chainMessage.GetParseDirective() + if parseDirective != nil { + return parseDirective.FunctionTag == functionTag + } + return false +} diff --git a/protocol/chainlib/chain_router.go b/protocol/chainlib/chain_router.go index 100b54227f..8530d30a74 100644 --- a/protocol/chainlib/chain_router.go +++ b/protocol/chainlib/chain_router.go @@ -2,12 +2,16 @@ package chainlib import ( "context" + "net/url" + "strings" "sync" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/protocol/lavasession" "github.com/lavanet/lava/v2/utils" + spectypes "github.com/lavanet/lava/v2/x/spec/types" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -77,8 +81,30 @@ func (cri chainRouterImpl) SendNodeMsg(ctx context.Context, ch chan interface{}, func batchNodeUrlsByServices(rpcProviderEndpoint lavasession.RPCProviderEndpoint) map[lavasession.RouterKey]lavasession.RPCProviderEndpoint { returnedBatch := map[lavasession.RouterKey]lavasession.RPCProviderEndpoint{} for _, nodeUrl := range rpcProviderEndpoint.NodeUrls { - if existingEndpoint, ok := returnedBatch[lavasession.NewRouterKey(nodeUrl.Addons)]; !ok { - returnedBatch[lavasession.NewRouterKey(nodeUrl.Addons)] = lavasession.RPCProviderEndpoint{ + routerKey := lavasession.NewRouterKey(nodeUrl.Addons) + + u, err := url.Parse(nodeUrl.Url) + // Some parsing may fail because of gRPC + if err == nil && (u.Scheme == "ws" || u.Scheme == "wss") { + // if websocket, check if we have a router key for http already. if not add a websocket router key + // so in case we didn't get an http endpoint, we can use the ws one. + if _, ok := returnedBatch[routerKey]; !ok { + returnedBatch[routerKey] = lavasession.RPCProviderEndpoint{ + NetworkAddress: rpcProviderEndpoint.NetworkAddress, + ChainID: rpcProviderEndpoint.ChainID, + ApiInterface: rpcProviderEndpoint.ApiInterface, + Geolocation: rpcProviderEndpoint.Geolocation, + NodeUrls: []common.NodeUrl{nodeUrl}, // add existing nodeUrl to the batch + } + } + + // now change the router key to fit the websocket extension key. 
+ nodeUrl.Addons = append(nodeUrl.Addons, WebSocketExtension) + routerKey = lavasession.NewRouterKey(nodeUrl.Addons) + } + + if existingEndpoint, ok := returnedBatch[routerKey]; !ok { + returnedBatch[routerKey] = lavasession.RPCProviderEndpoint{ NetworkAddress: rpcProviderEndpoint.NetworkAddress, ChainID: rpcProviderEndpoint.ChainID, ApiInterface: rpcProviderEndpoint.ApiInterface, @@ -86,10 +112,12 @@ func batchNodeUrlsByServices(rpcProviderEndpoint lavasession.RPCProviderEndpoint NodeUrls: []common.NodeUrl{nodeUrl}, // add existing nodeUrl to the batch } } else { - existingEndpoint.NodeUrls = append(existingEndpoint.NodeUrls, nodeUrl) - returnedBatch[lavasession.NewRouterKey(nodeUrl.Addons)] = existingEndpoint + // setting the incoming url first as it might be http while existing is websocket. (we prioritize http over ws when possible) + existingEndpoint.NodeUrls = append([]common.NodeUrl{nodeUrl}, existingEndpoint.NodeUrls...) + returnedBatch[routerKey] = existingEndpoint } } + return returnedBatch } @@ -140,6 +168,25 @@ func newChainRouter(ctx context.Context, nConns uint, rpcProviderEndpoint lavase return nil, utils.LavaFormatError("not all requirements supported in chainRouter, missing extensions or addons in definitions", nil, utils.Attribute{Key: "required", Value: requiredMap}, utils.Attribute{Key: "supported", Value: supportedMap}) } + _, apiCollection, hasSubscriptionInSpec := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_SUBSCRIBE) + // validating we have websocket support for subscription supported specs. + webSocketSupported := false + for key := range supportedMap { + if key.IsRequirementMet(WebSocketExtension) { + webSocketSupported = true + } + } + if hasSubscriptionInSpec && apiCollection.Enabled && !webSocketSupported { + err := utils.LavaFormatError("subscriptions are applicable for this chain, but websocket is not provided in 'supported' map. 
By not setting ws/wss your provider wont be able to accept ws subscriptions, therefore might receive less rewards and lower QOS score.", nil, + utils.LogAttr("apiInterface", apiCollection.CollectionData.ApiInterface), + utils.LogAttr("supportedMap", supportedMap), + utils.LogAttr("required", WebSocketExtension), + ) + if !IgnoreSubscriptionNotConfiguredError { + return nil, err + } + } + cri := chainRouterImpl{ lock: &sync.RWMutex{}, chainProxyRouter: chainProxyRouter, @@ -152,6 +199,14 @@ type requirementSt struct { addon string } +func (rs *requirementSt) String() string { + return string(rs.extensions) + rs.addon +} + +func (rs *requirementSt) IsRequirementMet(requirement string) bool { + return strings.Contains(string(rs.extensions), requirement) || strings.Contains(rs.addon, requirement) +} + func populateRequiredForAddon(addon string, extensions []string, required map[requirementSt]struct{}) { if len(extensions) == 0 { required[requirementSt{ diff --git a/protocol/chainlib/chain_router_test.go b/protocol/chainlib/chain_router_test.go index 24d8dc31f1..56e650380e 100644 --- a/protocol/chainlib/chain_router_test.go +++ b/protocol/chainlib/chain_router_test.go @@ -2,26 +2,433 @@ package chainlib import ( "context" + "log" + "net" + "os" "testing" + "time" + gojson "github.com/goccy/go-json" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/compress" + "github.com/gofiber/fiber/v2/middleware/favicon" + "github.com/gofiber/websocket/v2" + "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/protocol/lavasession" testcommon "github.com/lavanet/lava/v2/testutil/common" + "github.com/lavanet/lava/v2/utils" spectypes "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/require" ) -func TestChainRouter(t *testing.T) { +var ( + listenerAddressTcp = "localhost:0" + listenerAddressHttp = "" + listenerAddressWs = "" +) + +type TimeServer int64 + +func TestChainRouterWithDisabledWebSocketInSpec(t *testing.T) { + ctx := context.Background() + apiInterface := spectypes.APIInterfaceJsonRPC + chainParser, err := NewChainParser(apiInterface) + require.NoError(t, err) + + IgnoreSubscriptionNotConfiguredError = false + + addonsOptions := []string{"-addon-", "-addon2-"} + extensionsOptions := []string{"-test-", "-test2-", "-test3-"} + + spec := testcommon.CreateMockSpec() + spec.ApiCollections = []*spectypes.ApiCollection{ + { + Enabled: false, + CollectionData: spectypes.CollectionData{ + ApiInterface: apiInterface, + InternalPath: "", + Type: "", + AddOn: "", + }, + Extensions: []*spectypes.Extension{ + { + Name: extensionsOptions[0], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[1], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[2], + CuMultiplier: 1, + }, + }, + ParseDirectives: []*spectypes.ParseDirective{{ + FunctionTag: spectypes.FUNCTION_TAG_SUBSCRIBE, + }}, + }, + { + Enabled: true, + CollectionData: spectypes.CollectionData{ + ApiInterface: apiInterface, + InternalPath: "", + Type: "", + AddOn: "", + }, + Extensions: []*spectypes.Extension{ + { + Name: extensionsOptions[0], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[1], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[2], + CuMultiplier: 1, + }, + }, + }, + { + Enabled: true, + CollectionData: spectypes.CollectionData{ + ApiInterface: apiInterface, + InternalPath: "", + Type: "", + AddOn: addonsOptions[0], + }, + Extensions: []*spectypes.Extension{ + { + Name: extensionsOptions[0], + 
CuMultiplier: 1, + }, + { + Name: extensionsOptions[1], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[2], + CuMultiplier: 1, + }, + }, + }, + { + Enabled: true, + CollectionData: spectypes.CollectionData{ + ApiInterface: apiInterface, + InternalPath: "", + Type: "", + AddOn: addonsOptions[1], + }, + Extensions: []*spectypes.Extension{ + { + Name: extensionsOptions[0], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[1], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[2], + CuMultiplier: 1, + }, + }, + }, + } + chainParser.SetSpec(spec) + endpoint := &lavasession.RPCProviderEndpoint{ + NetworkAddress: lavasession.NetworkAddressData{}, + ChainID: spec.Index, + ApiInterface: apiInterface, + Geolocation: 1, + NodeUrls: []common.NodeUrl{}, + } + + type servicesStruct struct { + services []string + } + + playBook := []struct { + name string + services []servicesStruct + success bool + }{ + { + name: "empty services", + services: []servicesStruct{{ + services: []string{}, + }}, + success: true, + }, + { + name: "one-addon", + services: []servicesStruct{{ + services: []string{addonsOptions[0]}, + }}, + success: true, + }, + { + name: "one-extension", + services: []servicesStruct{{ + services: []string{extensionsOptions[0]}, + }}, + success: false, + }, + { + name: "one-extension with empty services", + services: []servicesStruct{ + { + services: []string{extensionsOptions[0]}, + }, + { + services: []string{}, + }, + }, + success: true, + }, + { + name: "two-addons together", + services: []servicesStruct{{ + services: addonsOptions, + }}, + success: true, + }, + { + name: "two-addons, separated", + services: []servicesStruct{{ + services: []string{addonsOptions[0]}, + }, { + services: []string{addonsOptions[1]}, + }}, + success: true, + }, + { + name: "addon + extension only", + services: []servicesStruct{{ + services: []string{addonsOptions[0], extensionsOptions[0]}, + }}, + success: false, + }, + { + name: "addon + extension, addon", + services: []servicesStruct{{ + services: []string{addonsOptions[0], extensionsOptions[0]}, + }, { + services: []string{addonsOptions[0]}, + }}, + success: true, + }, + { + name: "two addons + extension, addon", + services: []servicesStruct{{ + services: []string{addonsOptions[0], addonsOptions[1], extensionsOptions[0]}, + }, { + services: []string{addonsOptions[0]}, + }}, + success: false, + }, + { + name: "addons + extension, two addons", + services: []servicesStruct{{ + services: []string{addonsOptions[0], extensionsOptions[0]}, + }, { + services: []string{addonsOptions[0], addonsOptions[1]}, + }}, + success: true, + }, + { + name: "addons + two extensions, addon extension", + services: []servicesStruct{{ + services: []string{addonsOptions[0], extensionsOptions[0], extensionsOptions[1]}, + }, { + services: []string{addonsOptions[0], extensionsOptions[1]}, + }}, + success: false, + }, + { + name: "addons + two extensions, addon", + services: []servicesStruct{{ + services: []string{addonsOptions[0], extensionsOptions[0], extensionsOptions[1]}, + }, { + services: []string{addonsOptions[0]}, + }}, + success: false, + }, + { + name: "addons + two extensions, other addon", + services: []servicesStruct{ + { + services: []string{addonsOptions[0], extensionsOptions[0], extensionsOptions[1]}, + }, + { + services: []string{addonsOptions[1], extensionsOptions[0]}, + }, + { + services: []string{addonsOptions[0], extensionsOptions[1]}, + }, + }, + success: false, + }, + { + name: "addons + two extensions, addon ext1, addon ext2", + services: 
[]servicesStruct{ + { + services: []string{addonsOptions[0], extensionsOptions[0], extensionsOptions[1]}, + }, + { + services: []string{addonsOptions[0], extensionsOptions[0]}, + }, + { + services: []string{addonsOptions[0], extensionsOptions[1]}, + }, + }, + success: false, + }, + { + name: "addons + two extensions, works", + services: []servicesStruct{ + { + services: []string{addonsOptions[0], extensionsOptions[0], extensionsOptions[1]}, + }, + { + services: []string{addonsOptions[0], extensionsOptions[0]}, + }, + { + services: []string{addonsOptions[0], extensionsOptions[1]}, + }, + { + services: []string{addonsOptions[0]}, + }, + }, + success: true, + }, + { + name: "addons + two extensions, works, addon2", + services: []servicesStruct{ + { + services: []string{addonsOptions[0], extensionsOptions[0], extensionsOptions[1]}, + }, + { + services: []string{addonsOptions[0], extensionsOptions[0]}, + }, + { + services: []string{addonsOptions[0], extensionsOptions[1]}, + }, + { + services: []string{addonsOptions[0], addonsOptions[1]}, + }, + }, + success: true, + }, + { + name: "addon1 + ext, addon 2 + ext, addon 1", + services: []servicesStruct{ + { + services: []string{addonsOptions[0], extensionsOptions[0]}, + }, + { + services: []string{addonsOptions[1], extensionsOptions[0]}, + }, + { + services: []string{addonsOptions[0]}, + }, + }, + success: false, + }, + { + name: "addon1 + ext, addon 2 + ext, addon 1,addon2", + services: []servicesStruct{ + { + services: []string{addonsOptions[0], extensionsOptions[0]}, + }, + { + services: []string{addonsOptions[1], extensionsOptions[0]}, + }, + { + services: []string{addonsOptions[0]}, + }, + { + services: []string{addonsOptions[1]}, + }, + }, + success: true, + }, + { + name: "addon, ext", + services: []servicesStruct{ + { + services: []string{addonsOptions[0]}, + }, + { + services: []string{extensionsOptions[0]}, + }, + }, + success: true, + }, + } + for _, play := range playBook { + t.Run(play.name, func(t *testing.T) { + nodeUrls := []common.NodeUrl{} + for _, service := range play.services { + nodeUrl := common.NodeUrl{Url: listenerAddressHttp} + nodeUrl.Addons = service.services + nodeUrls = append(nodeUrls, nodeUrl) + } + + endpoint.NodeUrls = nodeUrls + _, err := GetChainRouter(ctx, 1, endpoint, chainParser) + if play.success { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +func TestChainRouterWithEnabledWebSocketInSpec(t *testing.T) { ctx := context.Background() apiInterface := spectypes.APIInterfaceJsonRPC chainParser, err := NewChainParser(apiInterface) require.NoError(t, err) + IgnoreSubscriptionNotConfiguredError = false + addonsOptions := []string{"-addon-", "-addon2-"} extensionsOptions := []string{"-test-", "-test2-", "-test3-"} spec := testcommon.CreateMockSpec() spec.ApiCollections = []*spectypes.ApiCollection{ + { + Enabled: true, + CollectionData: spectypes.CollectionData{ + ApiInterface: apiInterface, + InternalPath: "", + Type: "", + AddOn: "", + }, + Extensions: []*spectypes.Extension{ + { + Name: extensionsOptions[0], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[1], + CuMultiplier: 1, + }, + { + Name: extensionsOptions[2], + CuMultiplier: 1, + }, + }, + ParseDirectives: []*spectypes.ParseDirective{{ + FunctionTag: spectypes.FUNCTION_TAG_SUBSCRIBE, + }}, + }, { Enabled: true, CollectionData: spectypes.CollectionData{ @@ -132,7 +539,7 @@ func TestChainRouter(t *testing.T) { success: false, }, { - name: "one-extension works", + name: "one-extension with empty services", services: 
[]servicesStruct{ { services: []string{extensionsOptions[0]}, @@ -327,9 +734,11 @@ func TestChainRouter(t *testing.T) { t.Run(play.name, func(t *testing.T) { nodeUrls := []common.NodeUrl{} for _, service := range play.services { - nodeUrl := common.NodeUrl{Url: "http://127.0.0.1:0"} + nodeUrl := common.NodeUrl{Url: listenerAddressHttp} nodeUrl.Addons = service.services nodeUrls = append(nodeUrls, nodeUrl) + nodeUrl.Url = listenerAddressWs + nodeUrls = append(nodeUrls, nodeUrl) } endpoint.NodeUrls = nodeUrls _, err := GetChainRouter(ctx, 1, endpoint, chainParser) @@ -341,3 +750,84 @@ func TestChainRouter(t *testing.T) { }) } } + +func createRPCServer() net.Listener { + listener, err := net.Listen("tcp", listenerAddressTcp) + if err != nil { + log.Fatal("Listener error: ", err) + } + + app := fiber.New(fiber.Config{ + JSONEncoder: gojson.Marshal, + JSONDecoder: gojson.Unmarshal, + }) + app.Use(favicon.New()) + app.Use(compress.New(compress.Config{Level: compress.LevelBestSpeed})) + app.Use("/ws", func(c *fiber.Ctx) error { + // IsWebSocketUpgrade returns true if the client + // requested upgrade to the WebSocket protocol. + if websocket.IsWebSocketUpgrade(c) { + c.Locals("allowed", true) + return c.Next() + } + return fiber.ErrUpgradeRequired + }) + + app.Get("/ws", websocket.New(func(c *websocket.Conn) { + defer c.Close() + for { + // Read message from WebSocket + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("Read error:", err) + break + } + + // Print the message to the console + log.Printf("Received: %s", message) + + // Echo the message back + err = c.WriteMessage(mt, message) + if err != nil { + log.Println("Write error:", err) + break + } + } + })) + + listenerAddressTcp = listener.Addr().String() + listenerAddressHttp = "http://" + listenerAddressTcp + listenerAddressWs = "ws://" + listenerAddressTcp + "/ws" + // Serve accepts incoming HTTP connections on the listener l, creating + // a new service goroutine for each. The service goroutines read requests + // and then call handler to reply to them + go app.Listener(listener) + + return listener +} + +func TestMain(m *testing.M) { + listener := createRPCServer() + for { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + _, err := rpcclient.DialContext(ctx, listenerAddressHttp) + _, err2 := rpcclient.DialContext(ctx, listenerAddressWs) + if err2 != nil { + utils.LavaFormatDebug("waiting for grpc server to launch") + continue + } + if err != nil { + utils.LavaFormatDebug("waiting for grpc server to launch") + continue + } + cancel() + break + } + + utils.LavaFormatDebug("listening on", utils.LogAttr("address", listenerAddressHttp)) + + // Start running tests. 
+ code := m.Run() + listener.Close() + os.Exit(code) +} diff --git a/protocol/chainlib/chainlib.go b/protocol/chainlib/chainlib.go index 1fd57dc493..01dfd237a6 100644 --- a/protocol/chainlib/chainlib.go +++ b/protocol/chainlib/chainlib.go @@ -15,6 +15,11 @@ import ( spectypes "github.com/lavanet/lava/v2/x/spec/types" ) +var ( + IgnoreSubscriptionNotConfiguredError = true + IgnoreSubscriptionNotConfiguredErrorFlag = "ignore-subscription-not-configured-error" +) + func NewChainParser(apiInterface string) (chainParser ChainParser, err error) { switch apiInterface { case spectypes.APIInterfaceJsonRPC: @@ -37,12 +42,13 @@ func NewChainListener( rpcConsumerLogs *metrics.RPCConsumerLogs, chainParser ChainParser, refererData *RefererData, + consumerWsSubscriptionManager *ConsumerWSSubscriptionManager, ) (ChainListener, error) { switch listenEndpoint.ApiInterface { case spectypes.APIInterfaceJsonRPC: - return NewJrpcChainListener(ctx, listenEndpoint, relaySender, healthReporter, rpcConsumerLogs, refererData), nil + return NewJrpcChainListener(ctx, listenEndpoint, relaySender, healthReporter, rpcConsumerLogs, refererData, consumerWsSubscriptionManager), nil case spectypes.APIInterfaceTendermintRPC: - return NewTendermintRpcChainListener(ctx, listenEndpoint, relaySender, healthReporter, rpcConsumerLogs, refererData), nil + return NewTendermintRpcChainListener(ctx, listenEndpoint, relaySender, healthReporter, rpcConsumerLogs, refererData, consumerWsSubscriptionManager), nil case spectypes.APIInterfaceRest: return NewRestChainListener(ctx, listenEndpoint, relaySender, healthReporter, rpcConsumerLogs, refererData), nil case spectypes.APIInterfaceGrpc: @@ -56,7 +62,7 @@ type ChainParser interface { SetSpec(spec spectypes.Spec) DataReliabilityParams() (enabled bool, dataReliabilityThreshold uint32) ChainBlockStats() (allowedBlockLagForQosSync int64, averageBlockTime time.Duration, blockDistanceForFinalizedData, blocksInFinalizationProof uint32) - GetParsingByTag(tag spectypes.FUNCTION_TAG) (parsing *spectypes.ParseDirective, collectionData *spectypes.CollectionData, existed bool) + GetParsingByTag(tag spectypes.FUNCTION_TAG) (parsing *spectypes.ParseDirective, apiCollection *spectypes.ApiCollection, existed bool) CraftMessage(parser *spectypes.ParseDirective, connectionType string, craftData *CraftData, metadata []pairingtypes.Metadata) (ChainMessageForSend, error) HandleHeaders(metadata []pairingtypes.Metadata, apiCollection *spectypes.ApiCollection, headersDirection spectypes.Header_HeaderType) (filtered []pairingtypes.Metadata, overwriteReqBlock string, ignoredMetadata []pairingtypes.Metadata) GetVerifications(supported []string) ([]VerificationContainer, error) @@ -70,6 +76,7 @@ type ChainParser interface { } type ChainMessage interface { + SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string RequestedBlock() (latest int64, earliest int64) UpdateLatestBlockInMessage(latestBlock int64, modifyContent bool) (modified bool) AppendHeader(metadata []pairingtypes.Metadata) @@ -90,6 +97,7 @@ type ChainMessageForSend interface { GetApi() *spectypes.Api GetRPCMessage() rpcInterfaceMessages.GenericMessage GetApiCollection() *spectypes.ApiCollection + GetParseDirective() *spectypes.ParseDirective CheckResponseError(data []byte, httpStatusCode int) (hasError bool, errorMessage string) } @@ -108,6 +116,28 @@ type RelaySender interface { analytics *metrics.RelayMetrics, metadataValues []pairingtypes.Metadata, ) (*common.RelayResult, error) + ParseRelay( + ctx context.Context, + url string, + req string, + 
connectionType string, + dappID string, + consumerIp string, + analytics *metrics.RelayMetrics, + metadata []pairingtypes.Metadata, + ) (ChainMessage, map[string]string, *pairingtypes.RelayPrivateData, error) + SendParsedRelay( + ctx context.Context, + dappID string, + consumerIp string, + analytics *metrics.RelayMetrics, + chainMessage ChainMessage, + directiveHeaders map[string]string, + relayRequestData *pairingtypes.RelayPrivateData, + ) (relayResult *common.RelayResult, errRet error) + CreateDappKey(dappID, consumerIp string) string + CancelSubscriptionContext(subscriptionKey string) + SetConsistencySeenBlock(blockSeen int64, key string) } type ChainListener interface { diff --git a/protocol/chainlib/chainlib_mock.go b/protocol/chainlib/chainlib_mock.go new file mode 100644 index 0000000000..284a6c9b26 --- /dev/null +++ b/protocol/chainlib/chainlib_mock.go @@ -0,0 +1,872 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: protocol/chainlib/chainlib.go +// +// Generated by this command: +// +// mockgen -source protocol/chainlib/chainlib.go -destination protocol/chainlib/chainlib_mock.go -package chainlib +// + +// Package chainlib is a generated GoMock package. +package chainlib + +import ( + context "context" + reflect "reflect" + time "time" + + rpcInterfaceMessages "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" + rpcclient "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" + extensionslib "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" + common "github.com/lavanet/lava/v2/protocol/common" + metrics "github.com/lavanet/lava/v2/protocol/metrics" + types "github.com/lavanet/lava/v2/x/pairing/types" + types0 "github.com/lavanet/lava/v2/x/spec/types" + gomock "go.uber.org/mock/gomock" +) + +// MockChainParser is a mock of ChainParser interface. +type MockChainParser struct { + ctrl *gomock.Controller + recorder *MockChainParserMockRecorder +} + +// MockChainParserMockRecorder is the mock recorder for MockChainParser. +type MockChainParserMockRecorder struct { + mock *MockChainParser +} + +// NewMockChainParser creates a new mock instance. +func NewMockChainParser(ctrl *gomock.Controller) *MockChainParser { + mock := &MockChainParser{ctrl: ctrl} + mock.recorder = &MockChainParserMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockChainParser) EXPECT() *MockChainParserMockRecorder { + return m.recorder +} + +// Activate mocks base method. +func (m *MockChainParser) Activate() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Activate") +} + +// Activate indicates an expected call of Activate. +func (mr *MockChainParserMockRecorder) Activate() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Activate", reflect.TypeOf((*MockChainParser)(nil).Activate)) +} + +// Active mocks base method. +func (m *MockChainParser) Active() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Active") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Active indicates an expected call of Active. +func (mr *MockChainParserMockRecorder) Active() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Active", reflect.TypeOf((*MockChainParser)(nil).Active)) +} + +// ChainBlockStats mocks base method. 
+func (m *MockChainParser) ChainBlockStats() (int64, time.Duration, uint32, uint32) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ChainBlockStats") + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(time.Duration) + ret2, _ := ret[2].(uint32) + ret3, _ := ret[3].(uint32) + return ret0, ret1, ret2, ret3 +} + +// ChainBlockStats indicates an expected call of ChainBlockStats. +func (mr *MockChainParserMockRecorder) ChainBlockStats() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChainBlockStats", reflect.TypeOf((*MockChainParser)(nil).ChainBlockStats)) +} + +// CraftMessage mocks base method. +func (m *MockChainParser) CraftMessage(parser *types0.ParseDirective, connectionType string, craftData *CraftData, metadata []types.Metadata) (ChainMessageForSend, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CraftMessage", parser, connectionType, craftData, metadata) + ret0, _ := ret[0].(ChainMessageForSend) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CraftMessage indicates an expected call of CraftMessage. +func (mr *MockChainParserMockRecorder) CraftMessage(parser, connectionType, craftData, metadata any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CraftMessage", reflect.TypeOf((*MockChainParser)(nil).CraftMessage), parser, connectionType, craftData, metadata) +} + +// DataReliabilityParams mocks base method. +func (m *MockChainParser) DataReliabilityParams() (bool, uint32) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DataReliabilityParams") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(uint32) + return ret0, ret1 +} + +// DataReliabilityParams indicates an expected call of DataReliabilityParams. +func (mr *MockChainParserMockRecorder) DataReliabilityParams() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DataReliabilityParams", reflect.TypeOf((*MockChainParser)(nil).DataReliabilityParams)) +} + +// ExtensionsParser mocks base method. +func (m *MockChainParser) ExtensionsParser() *extensionslib.ExtensionParser { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExtensionsParser") + ret0, _ := ret[0].(*extensionslib.ExtensionParser) + return ret0 +} + +// ExtensionsParser indicates an expected call of ExtensionsParser. +func (mr *MockChainParserMockRecorder) ExtensionsParser() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtensionsParser", reflect.TypeOf((*MockChainParser)(nil).ExtensionsParser)) +} + +// GetParsingByTag mocks base method. +func (m *MockChainParser) GetParsingByTag(tag types0.FUNCTION_TAG) (*types0.ParseDirective, *types0.ApiCollection, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetParsingByTag", tag) + ret0, _ := ret[0].(*types0.ParseDirective) + ret1, _ := ret[1].(*types0.ApiCollection) + ret2, _ := ret[2].(bool) + return ret0, ret1, ret2 +} + +// GetParsingByTag indicates an expected call of GetParsingByTag. +func (mr *MockChainParserMockRecorder) GetParsingByTag(tag any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetParsingByTag", reflect.TypeOf((*MockChainParser)(nil).GetParsingByTag), tag) +} + +// GetUniqueName mocks base method. +func (m *MockChainParser) GetUniqueName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUniqueName") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetUniqueName indicates an expected call of GetUniqueName. 
+func (mr *MockChainParserMockRecorder) GetUniqueName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUniqueName", reflect.TypeOf((*MockChainParser)(nil).GetUniqueName)) +} + +// GetVerifications mocks base method. +func (m *MockChainParser) GetVerifications(supported []string) ([]VerificationContainer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetVerifications", supported) + ret0, _ := ret[0].([]VerificationContainer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetVerifications indicates an expected call of GetVerifications. +func (mr *MockChainParserMockRecorder) GetVerifications(supported any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVerifications", reflect.TypeOf((*MockChainParser)(nil).GetVerifications), supported) +} + +// HandleHeaders mocks base method. +func (m *MockChainParser) HandleHeaders(metadata []types.Metadata, apiCollection *types0.ApiCollection, headersDirection types0.Header_HeaderType) ([]types.Metadata, string, []types.Metadata) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleHeaders", metadata, apiCollection, headersDirection) + ret0, _ := ret[0].([]types.Metadata) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].([]types.Metadata) + return ret0, ret1, ret2 +} + +// HandleHeaders indicates an expected call of HandleHeaders. +func (mr *MockChainParserMockRecorder) HandleHeaders(metadata, apiCollection, headersDirection any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleHeaders", reflect.TypeOf((*MockChainParser)(nil).HandleHeaders), metadata, apiCollection, headersDirection) +} + +// ParseMsg mocks base method. +func (m *MockChainParser) ParseMsg(url string, data []byte, connectionType string, metadata []types.Metadata, extensionInfo extensionslib.ExtensionInfo) (ChainMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ParseMsg", url, data, connectionType, metadata, extensionInfo) + ret0, _ := ret[0].(ChainMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ParseMsg indicates an expected call of ParseMsg. +func (mr *MockChainParserMockRecorder) ParseMsg(url, data, connectionType, metadata, extensionInfo any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseMsg", reflect.TypeOf((*MockChainParser)(nil).ParseMsg), url, data, connectionType, metadata, extensionInfo) +} + +// SeparateAddonsExtensions mocks base method. +func (m *MockChainParser) SeparateAddonsExtensions(supported []string) ([]string, []string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SeparateAddonsExtensions", supported) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].([]string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// SeparateAddonsExtensions indicates an expected call of SeparateAddonsExtensions. +func (mr *MockChainParserMockRecorder) SeparateAddonsExtensions(supported any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SeparateAddonsExtensions", reflect.TypeOf((*MockChainParser)(nil).SeparateAddonsExtensions), supported) +} + +// SetPolicy mocks base method. +func (m *MockChainParser) SetPolicy(policy PolicyInf, chainId, apiInterface string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetPolicy", policy, chainId, apiInterface) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetPolicy indicates an expected call of SetPolicy. 
+func (mr *MockChainParserMockRecorder) SetPolicy(policy, chainId, apiInterface any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPolicy", reflect.TypeOf((*MockChainParser)(nil).SetPolicy), policy, chainId, apiInterface) +} + +// SetSpec mocks base method. +func (m *MockChainParser) SetSpec(spec types0.Spec) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetSpec", spec) +} + +// SetSpec indicates an expected call of SetSpec. +func (mr *MockChainParserMockRecorder) SetSpec(spec any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSpec", reflect.TypeOf((*MockChainParser)(nil).SetSpec), spec) +} + +// UpdateBlockTime mocks base method. +func (m *MockChainParser) UpdateBlockTime(newBlockTime time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateBlockTime", newBlockTime) +} + +// UpdateBlockTime indicates an expected call of UpdateBlockTime. +func (mr *MockChainParserMockRecorder) UpdateBlockTime(newBlockTime any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateBlockTime", reflect.TypeOf((*MockChainParser)(nil).UpdateBlockTime), newBlockTime) +} + +// MockChainMessage is a mock of ChainMessage interface. +type MockChainMessage struct { + ctrl *gomock.Controller + recorder *MockChainMessageMockRecorder +} + +// MockChainMessageMockRecorder is the mock recorder for MockChainMessage. +type MockChainMessageMockRecorder struct { + mock *MockChainMessage +} + +// NewMockChainMessage creates a new mock instance. +func NewMockChainMessage(ctrl *gomock.Controller) *MockChainMessage { + mock := &MockChainMessage{ctrl: ctrl} + mock.recorder = &MockChainMessageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockChainMessage) EXPECT() *MockChainMessageMockRecorder { + return m.recorder +} + +// AppendHeader mocks base method. +func (m *MockChainMessage) AppendHeader(metadata []types.Metadata) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AppendHeader", metadata) +} + +// AppendHeader indicates an expected call of AppendHeader. +func (mr *MockChainMessageMockRecorder) AppendHeader(metadata any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendHeader", reflect.TypeOf((*MockChainMessage)(nil).AppendHeader), metadata) +} + +// CheckResponseError mocks base method. +func (m *MockChainMessage) CheckResponseError(data []byte, httpStatusCode int) (bool, string) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckResponseError", data, httpStatusCode) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(string) + return ret0, ret1 +} + +// CheckResponseError indicates an expected call of CheckResponseError. +func (mr *MockChainMessageMockRecorder) CheckResponseError(data, httpStatusCode any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckResponseError", reflect.TypeOf((*MockChainMessage)(nil).CheckResponseError), data, httpStatusCode) +} + +// DisableErrorHandling mocks base method. +func (m *MockChainMessage) DisableErrorHandling() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisableErrorHandling") +} + +// DisableErrorHandling indicates an expected call of DisableErrorHandling. 
+func (mr *MockChainMessageMockRecorder) DisableErrorHandling() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisableErrorHandling", reflect.TypeOf((*MockChainMessage)(nil).DisableErrorHandling)) +} + +// GetApi mocks base method. +func (m *MockChainMessage) GetApi() *types0.Api { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetApi") + ret0, _ := ret[0].(*types0.Api) + return ret0 +} + +// GetApi indicates an expected call of GetApi. +func (mr *MockChainMessageMockRecorder) GetApi() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetApi", reflect.TypeOf((*MockChainMessage)(nil).GetApi)) +} + +// GetApiCollection mocks base method. +func (m *MockChainMessage) GetApiCollection() *types0.ApiCollection { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetApiCollection") + ret0, _ := ret[0].(*types0.ApiCollection) + return ret0 +} + +// GetApiCollection indicates an expected call of GetApiCollection. +func (mr *MockChainMessageMockRecorder) GetApiCollection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetApiCollection", reflect.TypeOf((*MockChainMessage)(nil).GetApiCollection)) +} + +// GetExtensions mocks base method. +func (m *MockChainMessage) GetExtensions() []*types0.Extension { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExtensions") + ret0, _ := ret[0].([]*types0.Extension) + return ret0 +} + +// GetExtensions indicates an expected call of GetExtensions. +func (mr *MockChainMessageMockRecorder) GetExtensions() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExtensions", reflect.TypeOf((*MockChainMessage)(nil).GetExtensions)) +} + +// GetForceCacheRefresh mocks base method. +func (m *MockChainMessage) GetForceCacheRefresh() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetForceCacheRefresh") + ret0, _ := ret[0].(bool) + return ret0 +} + +// GetForceCacheRefresh indicates an expected call of GetForceCacheRefresh. +func (mr *MockChainMessageMockRecorder) GetForceCacheRefresh() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForceCacheRefresh", reflect.TypeOf((*MockChainMessage)(nil).GetForceCacheRefresh)) +} + +// GetParseDirective mocks base method. +func (m *MockChainMessage) GetParseDirective() *types0.ParseDirective { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetParseDirective") + ret0, _ := ret[0].(*types0.ParseDirective) + return ret0 +} + +// GetParseDirective indicates an expected call of GetParseDirective. +func (mr *MockChainMessageMockRecorder) GetParseDirective() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetParseDirective", reflect.TypeOf((*MockChainMessage)(nil).GetParseDirective)) +} + +// GetRPCMessage mocks base method. +func (m *MockChainMessage) GetRPCMessage() rpcInterfaceMessages.GenericMessage { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRPCMessage") + ret0, _ := ret[0].(rpcInterfaceMessages.GenericMessage) + return ret0 +} + +// GetRPCMessage indicates an expected call of GetRPCMessage. +func (mr *MockChainMessageMockRecorder) GetRPCMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRPCMessage", reflect.TypeOf((*MockChainMessage)(nil).GetRPCMessage)) +} + +// OverrideExtensions mocks base method. 
+func (m *MockChainMessage) OverrideExtensions(extensionNames []string, extensionParser *extensionslib.ExtensionParser) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OverrideExtensions", extensionNames, extensionParser) +} + +// OverrideExtensions indicates an expected call of OverrideExtensions. +func (mr *MockChainMessageMockRecorder) OverrideExtensions(extensionNames, extensionParser any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OverrideExtensions", reflect.TypeOf((*MockChainMessage)(nil).OverrideExtensions), extensionNames, extensionParser) +} + +// RequestedBlock mocks base method. +func (m *MockChainMessage) RequestedBlock() (int64, int64) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RequestedBlock") + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(int64) + return ret0, ret1 +} + +// RequestedBlock indicates an expected call of RequestedBlock. +func (mr *MockChainMessageMockRecorder) RequestedBlock() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestedBlock", reflect.TypeOf((*MockChainMessage)(nil).RequestedBlock)) +} + +// SetForceCacheRefresh mocks base method. +func (m *MockChainMessage) SetForceCacheRefresh(force bool) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetForceCacheRefresh", force) + ret0, _ := ret[0].(bool) + return ret0 +} + +// SetForceCacheRefresh indicates an expected call of SetForceCacheRefresh. +func (mr *MockChainMessageMockRecorder) SetForceCacheRefresh(force any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetForceCacheRefresh", reflect.TypeOf((*MockChainMessage)(nil).SetForceCacheRefresh), force) +} + +// TimeoutOverride mocks base method. +func (m *MockChainMessage) TimeoutOverride(arg0 ...time.Duration) time.Duration { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "TimeoutOverride", varargs...) + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// TimeoutOverride indicates an expected call of TimeoutOverride. +func (mr *MockChainMessageMockRecorder) TimeoutOverride(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeoutOverride", reflect.TypeOf((*MockChainMessage)(nil).TimeoutOverride), arg0...) +} + +// UpdateLatestBlockInMessage mocks base method. +func (m *MockChainMessage) UpdateLatestBlockInMessage(latestBlock int64, modifyContent bool) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateLatestBlockInMessage", latestBlock, modifyContent) + ret0, _ := ret[0].(bool) + return ret0 +} + +// UpdateLatestBlockInMessage indicates an expected call of UpdateLatestBlockInMessage. +func (mr *MockChainMessageMockRecorder) UpdateLatestBlockInMessage(latestBlock, modifyContent any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLatestBlockInMessage", reflect.TypeOf((*MockChainMessage)(nil).UpdateLatestBlockInMessage), latestBlock, modifyContent) +} + +// MockChainMessageForSend is a mock of ChainMessageForSend interface. +type MockChainMessageForSend struct { + ctrl *gomock.Controller + recorder *MockChainMessageForSendMockRecorder +} + +// MockChainMessageForSendMockRecorder is the mock recorder for MockChainMessageForSend. +type MockChainMessageForSendMockRecorder struct { + mock *MockChainMessageForSend +} + +// NewMockChainMessageForSend creates a new mock instance. 
+func NewMockChainMessageForSend(ctrl *gomock.Controller) *MockChainMessageForSend { + mock := &MockChainMessageForSend{ctrl: ctrl} + mock.recorder = &MockChainMessageForSendMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockChainMessageForSend) EXPECT() *MockChainMessageForSendMockRecorder { + return m.recorder +} + +// GetApi mocks base method. +func (m *MockChainMessageForSend) GetApi() *types0.Api { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetApi") + ret0, _ := ret[0].(*types0.Api) + return ret0 +} + +// GetApi indicates an expected call of GetApi. +func (mr *MockChainMessageForSendMockRecorder) GetApi() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetApi", reflect.TypeOf((*MockChainMessageForSend)(nil).GetApi)) +} + +// GetApiCollection mocks base method. +func (m *MockChainMessageForSend) GetApiCollection() *types0.ApiCollection { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetApiCollection") + ret0, _ := ret[0].(*types0.ApiCollection) + return ret0 +} + +// GetApiCollection indicates an expected call of GetApiCollection. +func (mr *MockChainMessageForSendMockRecorder) GetApiCollection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetApiCollection", reflect.TypeOf((*MockChainMessageForSend)(nil).GetApiCollection)) +} + +// GetParseDirective mocks base method. +func (m *MockChainMessageForSend) GetParseDirective() *types0.ParseDirective { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetParseDirective") + ret0, _ := ret[0].(*types0.ParseDirective) + return ret0 +} + +// GetParseDirective indicates an expected call of GetParseDirective. +func (mr *MockChainMessageForSendMockRecorder) GetParseDirective() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetParseDirective", reflect.TypeOf((*MockChainMessageForSend)(nil).GetParseDirective)) +} + +// GetRPCMessage mocks base method. +func (m *MockChainMessageForSend) GetRPCMessage() rpcInterfaceMessages.GenericMessage { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRPCMessage") + ret0, _ := ret[0].(rpcInterfaceMessages.GenericMessage) + return ret0 +} + +// GetRPCMessage indicates an expected call of GetRPCMessage. +func (mr *MockChainMessageForSendMockRecorder) GetRPCMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRPCMessage", reflect.TypeOf((*MockChainMessageForSend)(nil).GetRPCMessage)) +} + +// TimeoutOverride mocks base method. +func (m *MockChainMessageForSend) TimeoutOverride(arg0 ...time.Duration) time.Duration { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "TimeoutOverride", varargs...) + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// TimeoutOverride indicates an expected call of TimeoutOverride. +func (mr *MockChainMessageForSendMockRecorder) TimeoutOverride(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeoutOverride", reflect.TypeOf((*MockChainMessageForSend)(nil).TimeoutOverride), arg0...) +} + +// MockHealthReporter is a mock of HealthReporter interface. +type MockHealthReporter struct { + ctrl *gomock.Controller + recorder *MockHealthReporterMockRecorder +} + +// MockHealthReporterMockRecorder is the mock recorder for MockHealthReporter. 
+type MockHealthReporterMockRecorder struct { + mock *MockHealthReporter +} + +// NewMockHealthReporter creates a new mock instance. +func NewMockHealthReporter(ctrl *gomock.Controller) *MockHealthReporter { + mock := &MockHealthReporter{ctrl: ctrl} + mock.recorder = &MockHealthReporterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHealthReporter) EXPECT() *MockHealthReporterMockRecorder { + return m.recorder +} + +// IsHealthy mocks base method. +func (m *MockHealthReporter) IsHealthy() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsHealthy") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsHealthy indicates an expected call of IsHealthy. +func (mr *MockHealthReporterMockRecorder) IsHealthy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsHealthy", reflect.TypeOf((*MockHealthReporter)(nil).IsHealthy)) +} + +// MockRelaySender is a mock of RelaySender interface. +type MockRelaySender struct { + ctrl *gomock.Controller + recorder *MockRelaySenderMockRecorder +} + +// MockRelaySenderMockRecorder is the mock recorder for MockRelaySender. +type MockRelaySenderMockRecorder struct { + mock *MockRelaySender +} + +// NewMockRelaySender creates a new mock instance. +func NewMockRelaySender(ctrl *gomock.Controller) *MockRelaySender { + mock := &MockRelaySender{ctrl: ctrl} + mock.recorder = &MockRelaySenderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRelaySender) EXPECT() *MockRelaySenderMockRecorder { + return m.recorder +} + +// CancelSubscriptionContext mocks base method. +func (m *MockRelaySender) CancelSubscriptionContext(subscriptionKey string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelSubscriptionContext", subscriptionKey) +} + +// CancelSubscriptionContext indicates an expected call of CancelSubscriptionContext. +func (mr *MockRelaySenderMockRecorder) CancelSubscriptionContext(subscriptionKey any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelSubscriptionContext", reflect.TypeOf((*MockRelaySender)(nil).CancelSubscriptionContext), subscriptionKey) +} + +// CreateDappKey mocks base method. +func (m *MockRelaySender) CreateDappKey(dappID, consumerIp string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateDappKey", dappID, consumerIp) + ret0, _ := ret[0].(string) + return ret0 +} + +// CreateDappKey indicates an expected call of CreateDappKey. +func (mr *MockRelaySenderMockRecorder) CreateDappKey(dappID, consumerIp any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDappKey", reflect.TypeOf((*MockRelaySender)(nil).CreateDappKey), dappID, consumerIp) +} + +// ParseRelay mocks base method. +func (m *MockRelaySender) ParseRelay(ctx context.Context, url, req, connectionType, dappID, consumerIp string, analytics *metrics.RelayMetrics, metadata []types.Metadata) (ChainMessage, map[string]string, *types.RelayPrivateData, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ParseRelay", ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) + ret0, _ := ret[0].(ChainMessage) + ret1, _ := ret[1].(map[string]string) + ret2, _ := ret[2].(*types.RelayPrivateData) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// ParseRelay indicates an expected call of ParseRelay. 
+func (mr *MockRelaySenderMockRecorder) ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseRelay", reflect.TypeOf((*MockRelaySender)(nil).ParseRelay), ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) +} + +// SendParsedRelay mocks base method. +func (m *MockRelaySender) SendParsedRelay(ctx context.Context, dappID, consumerIp string, analytics *metrics.RelayMetrics, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *types.RelayPrivateData) (*common.RelayResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendParsedRelay", ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData) + ret0, _ := ret[0].(*common.RelayResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SendParsedRelay indicates an expected call of SendParsedRelay. +func (mr *MockRelaySenderMockRecorder) SendParsedRelay(ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendParsedRelay", reflect.TypeOf((*MockRelaySender)(nil).SendParsedRelay), ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData) +} + +// SendRelay mocks base method. +func (m *MockRelaySender) SendRelay(ctx context.Context, url, req, connectionType, dappID, consumerIp string, analytics *metrics.RelayMetrics, metadataValues []types.Metadata) (*common.RelayResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendRelay", ctx, url, req, connectionType, dappID, consumerIp, analytics, metadataValues) + ret0, _ := ret[0].(*common.RelayResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SendRelay indicates an expected call of SendRelay. +func (mr *MockRelaySenderMockRecorder) SendRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadataValues any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRelay", reflect.TypeOf((*MockRelaySender)(nil).SendRelay), ctx, url, req, connectionType, dappID, consumerIp, analytics, metadataValues) +} + +// SetConsistencySeenBlock mocks base method. +func (m *MockRelaySender) SetConsistencySeenBlock(blockSeen int64, key string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetConsistencySeenBlock", blockSeen, key) +} + +// SetConsistencySeenBlock indicates an expected call of SetConsistencySeenBlock. +func (mr *MockRelaySenderMockRecorder) SetConsistencySeenBlock(blockSeen, key any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetConsistencySeenBlock", reflect.TypeOf((*MockRelaySender)(nil).SetConsistencySeenBlock), blockSeen, key) +} + +// MockChainListener is a mock of ChainListener interface. +type MockChainListener struct { + ctrl *gomock.Controller + recorder *MockChainListenerMockRecorder +} + +// MockChainListenerMockRecorder is the mock recorder for MockChainListener. +type MockChainListenerMockRecorder struct { + mock *MockChainListener +} + +// NewMockChainListener creates a new mock instance. +func NewMockChainListener(ctrl *gomock.Controller) *MockChainListener { + mock := &MockChainListener{ctrl: ctrl} + mock.recorder = &MockChainListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. 
+func (m *MockChainListener) EXPECT() *MockChainListenerMockRecorder { + return m.recorder +} + +// Serve mocks base method. +func (m *MockChainListener) Serve(ctx context.Context, cmdFlags common.ConsumerCmdFlags) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Serve", ctx, cmdFlags) +} + +// Serve indicates an expected call of Serve. +func (mr *MockChainListenerMockRecorder) Serve(ctx, cmdFlags any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockChainListener)(nil).Serve), ctx, cmdFlags) +} + +// MockChainRouter is a mock of ChainRouter interface. +type MockChainRouter struct { + ctrl *gomock.Controller + recorder *MockChainRouterMockRecorder +} + +// MockChainRouterMockRecorder is the mock recorder for MockChainRouter. +type MockChainRouterMockRecorder struct { + mock *MockChainRouter +} + +// NewMockChainRouter creates a new mock instance. +func NewMockChainRouter(ctrl *gomock.Controller) *MockChainRouter { + mock := &MockChainRouter{ctrl: ctrl} + mock.recorder = &MockChainRouterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockChainRouter) EXPECT() *MockChainRouterMockRecorder { + return m.recorder +} + +// ExtensionsSupported mocks base method. +func (m *MockChainRouter) ExtensionsSupported(arg0 []string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExtensionsSupported", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// ExtensionsSupported indicates an expected call of ExtensionsSupported. +func (mr *MockChainRouterMockRecorder) ExtensionsSupported(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtensionsSupported", reflect.TypeOf((*MockChainRouter)(nil).ExtensionsSupported), arg0) +} + +// SendNodeMsg mocks base method. +func (m *MockChainRouter) SendNodeMsg(ctx context.Context, ch chan any, chainMessage ChainMessageForSend, extensions []string) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, common.NodeUrl, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendNodeMsg", ctx, ch, chainMessage, extensions) + ret0, _ := ret[0].(*RelayReplyWrapper) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(*rpcclient.ClientSubscription) + ret3, _ := ret[3].(common.NodeUrl) + ret4, _ := ret[4].(string) + ret5, _ := ret[5].(error) + return ret0, ret1, ret2, ret3, ret4, ret5 +} + +// SendNodeMsg indicates an expected call of SendNodeMsg. +func (mr *MockChainRouterMockRecorder) SendNodeMsg(ctx, ch, chainMessage, extensions any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendNodeMsg", reflect.TypeOf((*MockChainRouter)(nil).SendNodeMsg), ctx, ch, chainMessage, extensions) +} + +// MockChainProxy is a mock of ChainProxy interface. +type MockChainProxy struct { + ctrl *gomock.Controller + recorder *MockChainProxyMockRecorder +} + +// MockChainProxyMockRecorder is the mock recorder for MockChainProxy. +type MockChainProxyMockRecorder struct { + mock *MockChainProxy +} + +// NewMockChainProxy creates a new mock instance. +func NewMockChainProxy(ctrl *gomock.Controller) *MockChainProxy { + mock := &MockChainProxy{ctrl: ctrl} + mock.recorder = &MockChainProxyMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockChainProxy) EXPECT() *MockChainProxyMockRecorder { + return m.recorder +} + +// GetChainProxyInformation mocks base method. 
+func (m *MockChainProxy) GetChainProxyInformation() (common.NodeUrl, string) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChainProxyInformation") + ret0, _ := ret[0].(common.NodeUrl) + ret1, _ := ret[1].(string) + return ret0, ret1 +} + +// GetChainProxyInformation indicates an expected call of GetChainProxyInformation. +func (mr *MockChainProxyMockRecorder) GetChainProxyInformation() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChainProxyInformation", reflect.TypeOf((*MockChainProxy)(nil).GetChainProxyInformation)) +} + +// SendNodeMsg mocks base method. +func (m *MockChainProxy) SendNodeMsg(ctx context.Context, ch chan any, chainMessage ChainMessageForSend) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendNodeMsg", ctx, ch, chainMessage) + ret0, _ := ret[0].(*RelayReplyWrapper) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(*rpcclient.ClientSubscription) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// SendNodeMsg indicates an expected call of SendNodeMsg. +func (mr *MockChainProxyMockRecorder) SendNodeMsg(ctx, ch, chainMessage any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendNodeMsg", reflect.TypeOf((*MockChainProxy)(nil).SendNodeMsg), ctx, ch, chainMessage) +} diff --git a/protocol/chainlib/chainproxy/common.go b/protocol/chainlib/chainproxy/common.go index 97322ecea7..b09f6caa70 100644 --- a/protocol/chainlib/chainproxy/common.go +++ b/protocol/chainlib/chainproxy/common.go @@ -89,6 +89,10 @@ func (dri DefaultRPCInput) GetResult() json.RawMessage { return dri.Result } +func (dri DefaultRPCInput) GetID() json.RawMessage { + return nil +} + func (dri DefaultRPCInput) ParseBlock(inp string) (int64, error) { return parser.ParseDefaultBlockParameter(inp) } diff --git a/protocol/chainlib/chainproxy/connector.go b/protocol/chainlib/chainproxy/connector.go index 48883865f0..3c56f398ba 100644 --- a/protocol/chainlib/chainproxy/connector.go +++ b/protocol/chainlib/chainproxy/connector.go @@ -124,11 +124,12 @@ func (connector *Connector) createConnection(ctx context.Context, nodeUrl common timeout := common.AverageWorldLatency * (1 + time.Duration(numberOfConnectionAttempts)) nctx, cancel := nodeUrl.LowerContextTimeoutWithDuration(ctx, timeout) // add auth path - rpcClient, err = rpcclient.DialContext(nctx, nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url)) + authPathNodeUrl := nodeUrl.AuthConfig.AddAuthPath(nodeUrl.Url) + rpcClient, err = rpcclient.DialContext(nctx, authPathNodeUrl) if err != nil { utils.LavaFormatWarning("Could not connect to the node, retrying", err, []utils.Attribute{ {Key: "Current Number Of Connections", Value: currentNumberOfConnections}, - {Key: "Network Address", Value: nodeUrl.UrlStr()}, + {Key: "Network Address", Value: authPathNodeUrl}, {Key: "Number Of Attempts", Value: numberOfConnectionAttempts}, {Key: "timeout", Value: timeout}, }...) 
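The generated GoMock stubs above are what the new subscription tests build on. As a minimal sketch (assuming the standard go.uber.org/mock workflow; the test name and the expected key value below are illustrative, not taken from this patch), the new RelaySender methods can be stubbed like so:

package chainlib

import (
	"testing"

	"go.uber.org/mock/gomock"
)

// Illustrative sketch: stub the new CreateDappKey method on the generated
// MockRelaySender and verify the stubbed value is returned to the caller.
func TestMockRelaySenderSketch(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	relaySender := NewMockRelaySender(ctrl)

	// Expect exactly one CreateDappKey call with these arguments.
	relaySender.EXPECT().
		CreateDappKey("myDapp", "127.0.0.1").
		Return("myDapp__127.0.0.1").
		Times(1)

	if key := relaySender.CreateDappKey("myDapp", "127.0.0.1"); key != "myDapp__127.0.0.1" {
		t.Fatalf("unexpected dapp key: %s", key)
	}
}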
diff --git a/protocol/chainlib/chainproxy/connector_test.go b/protocol/chainlib/chainproxy/connector_test.go index b1ffc0a183..7b052295e0 100644 --- a/protocol/chainlib/chainproxy/connector_test.go +++ b/protocol/chainlib/chainproxy/connector_test.go @@ -70,28 +70,6 @@ func createGRPCServerWithRegisteredProto(t *testing.T) *grpc.Server { return s } -func createRPCServer() net.Listener { - timeserver := new(TimeServer) - // Register the timeserver object upon which the GiveServerTime - // function will be called from the RPC server (from the client) - rpc.Register(timeserver) - // Registers an HTTP handler for RPC messages - rpc.HandleHTTP() - // Start listening for the requests on port 1234 - listener, err := net.Listen("tcp", listenerAddress) - if err != nil { - log.Fatal("Listener error: ", err) - } - listenerAddress = listener.Addr().String() - listenerAddressTcp = "http://" + listenerAddress - // Serve accepts incoming HTTP connections on the listener l, creating - // a new service goroutine for each. The service goroutines read requests - // and then call handler to reply to them - go http.Serve(listener, nil) - - return listener -} - func TestConnector(t *testing.T) { ctx := context.Background() conn, err := NewConnector(ctx, numberOfClients, common.NodeUrl{Url: listenerAddressTcp}) @@ -181,6 +159,28 @@ func TestHashing(t *testing.T) { require.Equal(t, conn.hashedNodeUrl, HashURL(listenerAddressTcp)) } +func createRPCServer() net.Listener { + timeserver := new(TimeServer) + // Register the timeserver object upon which the GiveServerTime + // function will be called from the RPC server (from the client) + rpc.Register(timeserver) + // Registers an HTTP handler for RPC messages + rpc.HandleHTTP() + // Start listening for requests on the configured listener address + listener, err := net.Listen("tcp", listenerAddress) + if err != nil { + log.Fatal("Listener error: ", err) + } + listenerAddress = listener.Addr().String() + listenerAddressTcp = "http://" + listenerAddress + // Serve accepts incoming HTTP connections on the listener l, creating + // a new service goroutine for each.
The service goroutines read requests + // and then call handler to reply to them + go http.Serve(listener, nil) + + return listener +} + func TestMain(m *testing.M) { listener := createRPCServer() for { diff --git a/protocol/chainlib/chainproxy/rpcInterfaceMessages/common.go b/protocol/chainlib/chainproxy/rpcInterfaceMessages/common.go index 026c01f0e3..58a6041be1 100644 --- a/protocol/chainlib/chainproxy/rpcInterfaceMessages/common.go +++ b/protocol/chainlib/chainproxy/rpcInterfaceMessages/common.go @@ -32,6 +32,10 @@ func (pri ParsableRPCInput) GetResult() json.RawMessage { return pri.Result } +func (pri ParsableRPCInput) GetID() json.RawMessage { + return nil +} + type GenericMessage interface { GetHeaders() []pairingtypes.Metadata DisableErrorHandling() diff --git a/protocol/chainlib/chainproxy/rpcInterfaceMessages/grpcMessage.go b/protocol/chainlib/chainproxy/rpcInterfaceMessages/grpcMessage.go index 04f4d4ada8..403f5303f4 100644 --- a/protocol/chainlib/chainproxy/rpcInterfaceMessages/grpcMessage.go +++ b/protocol/chainlib/chainproxy/rpcInterfaceMessages/grpcMessage.go @@ -14,10 +14,12 @@ import ( "github.com/jhump/protoreflect/dynamic" "github.com/jhump/protoreflect/grpcreflect" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy" + "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" dyncodec "github.com/lavanet/lava/v2/protocol/chainlib/grpcproxy/dyncodec" "github.com/lavanet/lava/v2/protocol/parser" "github.com/lavanet/lava/v2/utils" "github.com/lavanet/lava/v2/utils/sigs" + "google.golang.org/grpc/codes" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/dynamicpb" @@ -34,6 +36,10 @@ type GrpcMessage struct { chainproxy.BaseMessage } +func (gm *GrpcMessage) SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string { + return "" +} + // get msg hash byte array containing all the relevant information for a unique request. (headers / api / params) func (gm *GrpcMessage) GetRawRequestHash() ([]byte, error) { headers := gm.GetHeaders() @@ -116,6 +122,10 @@ func (gm GrpcMessage) GetMethod() string { return gm.Path } +func (gm GrpcMessage) GetID() json.RawMessage { + return nil +} + func (gm GrpcMessage) NewParsableRPCInput(input json.RawMessage) (parser.RPCInput, error) { msgFactory := dynamic.NewMessageFactoryWithDefaults() if gm.methodDesc == nil { diff --git a/protocol/chainlib/chainproxy/rpcInterfaceMessages/jsonRPCMessage.go b/protocol/chainlib/chainproxy/rpcInterfaceMessages/jsonRPCMessage.go index 357683c156..285927184d 100644 --- a/protocol/chainlib/chainproxy/rpcInterfaceMessages/jsonRPCMessage.go +++ b/protocol/chainlib/chainproxy/rpcInterfaceMessages/jsonRPCMessage.go @@ -25,6 +25,10 @@ type JsonrpcMessage struct { chainproxy.BaseMessage `json:"-"` } +func (jm *JsonrpcMessage) SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string { + return string(reply.Result) +} + // get msg hash byte array containing all the relevant information for a unique request.
(headers / api / params) func (jm *JsonrpcMessage) GetRawRequestHash() ([]byte, error) { headers := jm.GetHeaders() @@ -106,11 +110,11 @@ func ConvertBatchElement(batchElement rpcclient.BatchElemWithId) (JsonrpcMessage return msg, nil } -func (gm *JsonrpcMessage) UpdateLatestBlockInMessage(latestBlock uint64, modifyContent bool) (success bool) { +func (jm *JsonrpcMessage) UpdateLatestBlockInMessage(latestBlock uint64, modifyContent bool) (success bool) { return false } -func (gm JsonrpcMessage) NewParsableRPCInput(input json.RawMessage) (parser.RPCInput, error) { +func (jm JsonrpcMessage) NewParsableRPCInput(input json.RawMessage) (parser.RPCInput, error) { msg := &JsonrpcMessage{} err := json.Unmarshal(input, msg) if err != nil { @@ -124,22 +128,26 @@ func (gm JsonrpcMessage) NewParsableRPCInput(input json.RawMessage) (parser.RPCI return ParsableRPCInput{Result: msg.Result}, nil } -func (cp JsonrpcMessage) GetParams() interface{} { - return cp.Params +func (jm JsonrpcMessage) GetParams() interface{} { + return jm.Params } -func (cp JsonrpcMessage) GetMethod() string { - return cp.Method +func (jm JsonrpcMessage) GetMethod() string { + return jm.Method } -func (cp JsonrpcMessage) GetResult() json.RawMessage { - if cp.Error != nil { - utils.LavaFormatWarning("GetResult() Request got an error from the node", nil, utils.Attribute{Key: "error", Value: cp.Error}) +func (jm JsonrpcMessage) GetResult() json.RawMessage { + if jm.Error != nil { + utils.LavaFormatWarning("GetResult() Request got an error from the node", nil, utils.Attribute{Key: "error", Value: jm.Error}) } - return cp.Result + return jm.Result +} + +func (jm JsonrpcMessage) GetID() json.RawMessage { + return jm.ID } -func (cp JsonrpcMessage) ParseBlock(inp string) (int64, error) { +func (jm JsonrpcMessage) ParseBlock(inp string) (int64, error) { return parser.ParseDefaultBlockParameter(inp) } @@ -170,8 +178,8 @@ type JsonrpcBatchMessage struct { chainproxy.BaseMessage } -func (jbm JsonrpcBatchMessage) GetParams() interface{} { - return nil +func (jbm *JsonrpcBatchMessage) SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string { + return "" } // on batches we don't want to calculate the batch hash as its impossible to get the args @@ -188,6 +196,10 @@ func (jbm *JsonrpcBatchMessage) GetBatch() []rpcclient.BatchElemWithId { return jbm.batch } +func (jbm JsonrpcBatchMessage) GetParams() interface{} { + return [][]byte{} +} + func NewBatchMessage(msgs []JsonrpcMessage) (JsonrpcBatchMessage, error) { batch := make([]rpcclient.BatchElemWithId, len(msgs)) for idx, msg := range msgs { diff --git a/protocol/chainlib/chainproxy/rpcInterfaceMessages/restMessage.go b/protocol/chainlib/chainproxy/rpcInterfaceMessages/restMessage.go index 75f85a4f9f..eda3b0435c 100644 --- a/protocol/chainlib/chainproxy/rpcInterfaceMessages/restMessage.go +++ b/protocol/chainlib/chainproxy/rpcInterfaceMessages/restMessage.go @@ -7,6 +7,7 @@ import ( "github.com/goccy/go-json" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy" + "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" "github.com/lavanet/lava/v2/protocol/parser" "github.com/lavanet/lava/v2/utils" "github.com/lavanet/lava/v2/utils/sigs" @@ -19,6 +20,10 @@ type RestMessage struct { chainproxy.BaseMessage } +func (rm *RestMessage) SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string { + return "" +} + // get msg hash byte array containing all the relevant information for a unique request. 
(headers / api / params) func (rm *RestMessage) GetRawRequestHash() ([]byte, error) { headers := rm.GetHeaders() @@ -53,13 +58,13 @@ func (jm RestMessage) CheckResponseError(data []byte, httpStatusCode int) (hasEr // GetParams will be deprecated after we remove old client // Currently needed because of parser.RPCInput interface -func (cp RestMessage) GetParams() interface{} { - urlObj, err := url.Parse(cp.Path) +func (rm RestMessage) GetParams() interface{} { + urlObj, err := url.Parse(rm.Path) if err != nil { return nil } parsedMethod := urlObj.Path - objectSpec := strings.Split(cp.SpecPath, "/") + objectSpec := strings.Split(rm.SpecPath, "/") objectPath := strings.Split(parsedMethod, "/") parameters := map[string]interface{}{} @@ -81,22 +86,26 @@ func (cp RestMessage) GetParams() interface{} { func (rm *RestMessage) UpdateLatestBlockInMessage(latestBlock uint64, modifyContent bool) (success bool) { // return rm.SetLatestBlockWithHeader(latestBlock, modifyContent) - // removed until behaviour inconsistency with the cosmos sdk header is solved + // removed until behavior inconsistency with the cosmos sdk header is solved return false // if !done else we need a different setter } // GetResult will be deprecated after we remove old client // Currently needed because of parser.RPCInput interface -func (cp RestMessage) GetResult() json.RawMessage { +func (rm RestMessage) GetResult() json.RawMessage { return nil } -func (cp RestMessage) GetMethod() string { - return cp.Path +func (rm RestMessage) GetMethod() string { + return rm.Path +} + +func (rm RestMessage) GetID() json.RawMessage { + return nil } // ParseBlock parses default block number from string to int -func (cp RestMessage) ParseBlock(inp string) (int64, error) { +func (rm RestMessage) ParseBlock(inp string) (int64, error) { return parser.ParseDefaultBlockParameter(inp) } diff --git a/protocol/chainlib/chainproxy/rpcInterfaceMessages/tendermintRPCMessage.go b/protocol/chainlib/chainproxy/rpcInterfaceMessages/tendermintRPCMessage.go index 0f43c64677..5a81767acf 100644 --- a/protocol/chainlib/chainproxy/rpcInterfaceMessages/tendermintRPCMessage.go +++ b/protocol/chainlib/chainproxy/rpcInterfaceMessages/tendermintRPCMessage.go @@ -4,9 +4,8 @@ import ( "fmt" "reflect" - "github.com/goccy/go-json" - tenderminttypes "github.com/cometbft/cometbft/rpc/jsonrpc/types" + "github.com/goccy/go-json" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" "github.com/lavanet/lava/v2/protocol/parser" @@ -20,6 +19,15 @@ type TendermintrpcMessage struct { Path string } +func (tm TendermintrpcMessage) SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string { + params, err := json.Marshal(tm.GetParams()) + if err != nil { + utils.LavaFormatWarning("failed marshaling params", err, utils.LogAttr("request", tm)) + return "" + } + return string(params) +} + // get msg hash byte array containing all the relevant information for a unique request. 
(headers / api / params) func (tm *TendermintrpcMessage) GetRawRequestHash() ([]byte, error) { headers := tm.GetHeaders() @@ -42,14 +50,14 @@ func (cp TendermintrpcMessage) GetParams() interface{} { return cp.Params } -func (cp TendermintrpcMessage) GetResult() json.RawMessage { - if cp.Error != nil { - utils.LavaFormatWarning("GetResult() Request got an error from the node", nil, utils.Attribute{Key: "error", Value: cp.Error}) +func (tm TendermintrpcMessage) GetResult() json.RawMessage { + if tm.Error != nil { + utils.LavaFormatWarning("GetResult() Request got an error from the node", nil, utils.Attribute{Key: "error", Value: tm.Error}) } - return cp.Result + return tm.Result } -func (cp TendermintrpcMessage) ParseBlock(inp string) (int64, error) { +func (tm TendermintrpcMessage) ParseBlock(inp string) (int64, error) { return parser.ParseDefaultBlockParameter(inp) } diff --git a/protocol/chainlib/chainproxy/rpcclient/client.go b/protocol/chainlib/chainproxy/rpcclient/client.go index dbbfdb7074..6e4e4c00cd 100644 --- a/protocol/chainlib/chainproxy/rpcclient/client.go +++ b/protocol/chainlib/chainproxy/rpcclient/client.go @@ -514,6 +514,11 @@ func (c *Client) Subscribe(ctx context.Context, id json.RawMessage, method strin return nil, nil, err } resp, err := op.wait(ctx, c) + // If the response contains an error message, we want to return it to the user as-is + if err != nil && resp != nil && resp.Error != nil { + return nil, resp, nil + } + if err != nil { return nil, nil, err } diff --git a/protocol/chainlib/chainproxy/rpcclient/utils.go b/protocol/chainlib/chainproxy/rpcclient/utils.go new file mode 100644 index 0000000000..4c7d7abd8c --- /dev/null +++ b/protocol/chainlib/chainproxy/rpcclient/utils.go @@ -0,0 +1,9 @@ +package rpcclient + +import ( + "github.com/lavanet/lava/v2/utils/sigs" +) + +func CreateHashFromParams(params []byte) string { + return string(sigs.HashMsg(params)) +} diff --git a/protocol/chainlib/common.go b/protocol/chainlib/common.go index 88f819d2e6..4cfda1c374 100644 --- a/protocol/chainlib/common.go +++ b/protocol/chainlib/common.go @@ -5,10 +5,10 @@ import ( "fmt" "net" "net/http" - "net/url" "strings" "time" + sdkerrors "cosmossdk.io/errors" gojson "github.com/goccy/go-json" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/compress" @@ -30,9 +30,14 @@ const ( relayMsgLogMaxChars = 200 RPCProviderNodeAddressHash = "Lava-Provider-Node-Address-Hash" RPCProviderNodeExtension = "Lava-Provider-Node-Extension" + WebSocketExtension = "websocket" ) -var InvalidResponses = []string{"null", "", "nil", "undefined"} +var ( + InvalidResponses = []string{"null", "", "nil", "undefined"} + FailedSendingSubscriptionToClients = sdkerrors.New("failed Sending Subscription To Clients", 1015, "Failed Sending Subscription To Clients, connection might have been closed by the user") + NoActiveSubscriptionFound = sdkerrors.New("failed finding an active subscription on provider side", 1016, "no active subscriptions for hashed params.") +) type RelayReplyWrapper struct { StatusCode int @@ -181,47 +186,10 @@ func addAttributeToError(key, value, errorMessage string) string { return errorMessage + fmt.Sprintf(`, "%v": "%v"`, key, value) } -// rpc default endpoint should be websocket.
otherwise return an error -func verifyRPCEndpoint(endpoint string) { - u, err := url.Parse(endpoint) - if err != nil { - utils.LavaFormatFatal("unparsable url", err, utils.Attribute{Key: "url", Value: endpoint}) - } - switch u.Scheme { - case "ws", "wss": - return - default: - utils.LavaFormatWarning("URL scheme should be websocket (ws/wss), got: "+u.Scheme+", By not setting ws/wss your provider wont be able to accept ws subscriptions, therefore might receive less rewards and lower QOS score. if subscriptions are not applicable for this chain you can ignore this warning", nil) - } -} - -// rpc default endpoint should be websocket. otherwise return an error -func verifyTendermintEndpoint(endpoints []common.NodeUrl) (websocketEndpoint, httpEndpoint common.NodeUrl) { +func validateEndpoints(endpoints []common.NodeUrl, apiInterface string) { for _, endpoint := range endpoints { - u, err := url.Parse(endpoint.Url) - if err != nil { - utils.LavaFormatFatal("unparsable url", err, utils.Attribute{Key: "url", Value: endpoint.UrlStr()}) - } - switch u.Scheme { - case "http", "https": - httpEndpoint = endpoint - case "ws", "wss": - websocketEndpoint = endpoint - default: - utils.LavaFormatFatal("URL scheme should be websocket (ws/wss) or (http/https), got: "+u.Scheme, nil) - } - } - - if websocketEndpoint.String() == "" || httpEndpoint.String() == "" { - utils.LavaFormatError("Tendermint Provider was not provided with both http and websocket urls. please provide both", nil, - utils.Attribute{Key: "websocket", Value: websocketEndpoint.String()}, utils.Attribute{Key: "http", Value: httpEndpoint.String()}) - if httpEndpoint.String() != "" { - return httpEndpoint, httpEndpoint - } else { - utils.LavaFormatFatal("Tendermint Provider was not provided with http url. please provide a url that starts with http/https", nil) - } + common.ValidateEndpoint(endpoint.Url, apiInterface) } - return websocketEndpoint, httpEndpoint } func ListenWithRetry(app *fiber.App, address string) { diff --git a/protocol/chainlib/common_test.go b/protocol/chainlib/common_test.go index c637fb7199..757b3ba2f6 100644 --- a/protocol/chainlib/common_test.go +++ b/protocol/chainlib/common_test.go @@ -12,6 +12,7 @@ import ( "github.com/gofiber/websocket/v2" websocket2 "github.com/gorilla/websocket" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy" + "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" spectypes "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/assert" ) @@ -313,6 +314,10 @@ type mockRPCInput struct { chainproxy.BaseMessage } +func (m *mockRPCInput) SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string { + return "" +} + func (m *mockRPCInput) GetRawRequestHash() ([]byte, error) { return nil, fmt.Errorf("test") } diff --git a/protocol/chainlib/common_test_utils.go b/protocol/chainlib/common_test_utils.go index 8449016e64..2782799b23 100644 --- a/protocol/chainlib/common_test_utils.go +++ b/protocol/chainlib/common_test_utils.go @@ -2,13 +2,17 @@ package chainlib import ( "context" + "fmt" "net" "net/http" "net/http/httptest" "strconv" + "strings" "testing" "time" + "github.com/gorilla/websocket" + "github.com/lavanet/lava/v2/utils" "github.com/lavanet/lava/v2/utils/rand" "github.com/lavanet/lava/v2/utils/sigs" @@ -84,9 +88,43 @@ func generateCombinations(arr []string) [][]string { return append(combinationsWithoutFirst, combinationsWithFirst...) 
} +func genericWebSocketHandler() http.HandlerFunc { + upGrader := websocket.Upgrader{} + + // Create a simple websocket server that mocks the node + return func(w http.ResponseWriter, r *http.Request) { + conn, err := upGrader.Upgrade(w, r, nil) + if err != nil { + fmt.Println(err) + panic("got error in upgrader") + } + defer conn.Close() + + for { + // Read the request + messageType, message, err := conn.ReadMessage() + if err != nil { + panic("got error in ReadMessage") + } + fmt.Println("got ws message", string(message), messageType) + conn.WriteMessage(messageType, message) + fmt.Println("writing ws message", string(message), messageType) + } + } +} + // generates a chain parser, a chain fetcher messages based on it // apiInterface can either be an ApiInterface string as in spectypes.ApiInterfaceXXX or a number for an index in the apiCollections -func CreateChainLibMocks(ctx context.Context, specIndex string, apiInterface string, serverCallback http.HandlerFunc, getToTopMostPath string, services []string) (cpar ChainParser, crout ChainRouter, cfetc chaintracker.ChainFetcher, closeServer func(), endpointRet *lavasession.RPCProviderEndpoint, errRet error) { +func CreateChainLibMocks( + ctx context.Context, + specIndex string, + apiInterface string, + httpServerCallback http.HandlerFunc, + wsServerCallback http.HandlerFunc, + getToTopMostPath string, + services []string, +) (cpar ChainParser, crout ChainRouter, cfetc chaintracker.ChainFetcher, closeServer func(), endpointRet *lavasession.RPCProviderEndpoint, errRet error) { + utils.SetGlobalLoggingLevel("debug") closeServer = nil spec, err := keepertest.GetASpec(specIndex, getToTopMostPath, nil, nil) if err != nil { @@ -114,6 +152,14 @@ func CreateChainLibMocks(ctx context.Context, specIndex string, apiInterface str return nil, nil, nil, nil, nil, err } + if httpServerCallback == nil { + httpServerCallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + } + + if wsServerCallback == nil { + wsServerCallback = genericWebSocketHandler() + } + if apiInterface == spectypes.APIInterfaceGrpc { // Start a new gRPC server using the buffered connection grpcServer := grpc.NewServer() @@ -127,7 +173,7 @@ func CreateChainLibMocks(ctx context.Context, specIndex string, apiInterface str endpoint.NodeUrls = append(endpoint.NodeUrls, common.NodeUrl{Url: lis.Addr().String(), Addons: append(addons, extensionsList...)}) } go func() { - service := myServiceImplementation{serverCallback: serverCallback} + service := myServiceImplementation{serverCallback: httpServerCallback} tmservice.RegisterServiceServer(grpcServer, service) gogoreflection.Register(grpcServer) // Serve requests on the buffered connection @@ -141,9 +187,17 @@ func CreateChainLibMocks(ctx context.Context, specIndex string, apiInterface str return nil, nil, nil, closeServer, nil, err } } else { - mockServer := httptest.NewServer(serverCallback) - closeServer = mockServer.Close - endpoint.NodeUrls = append(endpoint.NodeUrls, common.NodeUrl{Url: mockServer.URL, Addons: addons}) + var mockWebSocketServer *httptest.Server + var wsUrl string + mockWebSocketServer = httptest.NewServer(wsServerCallback) + wsUrl = "ws" + strings.TrimPrefix(mockWebSocketServer.URL, "http") + mockHttpServer := httptest.NewServer(httpServerCallback) + closeServer = func() { + mockHttpServer.Close() + mockWebSocketServer.Close() + } + endpoint.NodeUrls = append(endpoint.NodeUrls, common.NodeUrl{Url: mockHttpServer.URL, Addons: addons}) + endpoint.NodeUrls = 
append(endpoint.NodeUrls, common.NodeUrl{Url: wsUrl, Addons: nil}) chainRouter, err = GetChainRouter(ctx, 1, endpoint, chainParser) if err != nil { return nil, nil, nil, closeServer, nil, err diff --git a/protocol/chainlib/consumer_websocket_manager.go b/protocol/chainlib/consumer_websocket_manager.go new file mode 100644 index 0000000000..3017328db4 --- /dev/null +++ b/protocol/chainlib/consumer_websocket_manager.go @@ -0,0 +1,273 @@ +package chainlib + +import ( + "context" + "strconv" + "time" + + gojson "github.com/goccy/go-json" + "github.com/gofiber/websocket/v2" + formatter "github.com/lavanet/lava/v2/ecosystem/cache/format" + "github.com/lavanet/lava/v2/protocol/common" + "github.com/lavanet/lava/v2/protocol/metrics" + "github.com/lavanet/lava/v2/utils" + spectypes "github.com/lavanet/lava/v2/x/spec/types" +) + +type ConsumerWebsocketManager struct { + websocketConn *websocket.Conn + rpcConsumerLogs *metrics.RPCConsumerLogs + cmdFlags common.ConsumerCmdFlags + refererMatchString string + relayMsgLogMaxChars int + chainId string + apiInterface string + connectionType string + refererData *RefererData + relaySender RelaySender + consumerWsSubscriptionManager *ConsumerWSSubscriptionManager + WebsocketConnectionUID string +} + +type ConsumerWebsocketManagerOptions struct { + WebsocketConn *websocket.Conn + RpcConsumerLogs *metrics.RPCConsumerLogs + RefererMatchString string + CmdFlags common.ConsumerCmdFlags + RelayMsgLogMaxChars int + ChainID string + ApiInterface string + ConnectionType string + RefererData *RefererData + RelaySender RelaySender + ConsumerWsSubscriptionManager *ConsumerWSSubscriptionManager + WebsocketConnectionUID string +} + +func NewConsumerWebsocketManager(options ConsumerWebsocketManagerOptions) *ConsumerWebsocketManager { + cwm := &ConsumerWebsocketManager{ + websocketConn: options.WebsocketConn, + relaySender: options.RelaySender, + rpcConsumerLogs: options.RpcConsumerLogs, + cmdFlags: options.CmdFlags, + refererMatchString: options.RefererMatchString, + relayMsgLogMaxChars: options.RelayMsgLogMaxChars, + chainId: options.ChainID, + apiInterface: options.ApiInterface, + connectionType: options.ConnectionType, + refererData: options.RefererData, + consumerWsSubscriptionManager: options.ConsumerWsSubscriptionManager, + WebsocketConnectionUID: options.WebsocketConnectionUID, + } + return cwm +} + +func (cwm *ConsumerWebsocketManager) GetWebSocketConnectionUniqueId(dappId, userIp string) string { + return dappId + "__" + userIp + "__" + cwm.WebsocketConnectionUID +} + +func (cwm *ConsumerWebsocketManager) ListenToMessages() { + var ( + messageType int + msg []byte + err error + ) + + type webSocketMsgWithType struct { + messageType int + msg []byte + } + + websocketConnWriteChan := make(chan webSocketMsgWithType) + + websocketConn := cwm.websocketConn + logger := cwm.rpcConsumerLogs + + webSocketCtx, cancelWebSocketCtx := context.WithCancel(context.Background()) + guid := utils.GenerateUniqueIdentifier() + webSocketCtx = utils.WithUniqueIdentifier(webSocketCtx, guid) + utils.LavaFormatDebug("consumer websocket manager started", utils.LogAttr("GUID", webSocketCtx)) + defer func() { + cancelWebSocketCtx() // In case there's a problem make sure to cancel the connection + utils.LavaFormatDebug("consumer websocket manager stopped", utils.LogAttr("GUID", webSocketCtx)) + }() + + go func() { + for msg := range websocketConnWriteChan { + select { + case <-webSocketCtx.Done(): + utils.LavaFormatTrace("websocket's context cancelled", utils.LogAttr("GUID", webSocketCtx)) + 
return
+			default:
+				err := cwm.websocketConn.WriteMessage(msg.messageType, msg.msg)
+				if err != nil {
+					utils.LavaFormatTrace("error writing msg to the websocket")
+					return
+				}
+			}
+		}
+	}()
+
+	for {
+		startTime := time.Now()
+		msgSeed := logger.GetMessageSeed()
+
+		utils.LavaFormatTrace("listening for new message from the websocket")
+
+		if messageType, msg, err = websocketConn.ReadMessage(); err != nil {
+			utils.LavaFormatTrace("error reading msg from the websocket, probably websocket was closed by the user", utils.LogAttr("err", err))
+			formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), err, msgSeed, msg, cwm.apiInterface, time.Since(startTime))
+			if formatterMsg != nil {
+				websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: formatterMsg}
+			}
+			break
+		}
+
+		dappID, ok := websocketConn.Locals("dapp-id").(string)
+		if !ok {
+			// Failed extracting the dappID; log it and, if a formatted error message was produced, send it back to the user
+			formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), nil, msgSeed, []byte("Unable to extract dappID"), cwm.apiInterface, time.Since(startTime))
+			if formatterMsg != nil {
+				websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: formatterMsg}
+			}
+		}
+
+		msgSeed = strconv.FormatUint(guid, 10)
+		userIp := websocketConn.RemoteAddr().String()
+
+		logFormattedMsg := string(msg)
+		if !cwm.cmdFlags.DebugRelays {
+			logFormattedMsg = utils.FormatLongString(logFormattedMsg, cwm.relayMsgLogMaxChars)
+		}
+
+		utils.LavaFormatDebug("ws in <<<",
+			utils.LogAttr("seed", msgSeed),
+			utils.LogAttr("GUID", webSocketCtx),
+			utils.LogAttr("msg", logFormattedMsg),
+			utils.LogAttr("dappID", dappID),
+		)
+
+		metricsData := metrics.NewRelayAnalytics(dappID, cwm.chainId, cwm.apiInterface)
+
+		chainMessage, directiveHeaders, relayRequestData, err := cwm.relaySender.ParseRelay(webSocketCtx, "", string(msg), cwm.connectionType, dappID, userIp, metricsData, nil)
+		if err != nil {
+			formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), utils.LavaFormatError("could not parse message", err), msgSeed, msg, cwm.apiInterface, time.Since(startTime))
+			if formatterMsg != nil {
+				websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: formatterMsg}
+			}
+			continue
+		}
+
+		// check whether it's a normal relay / unsubscribe / unsubscribe_all; otherwise it's a subscription flow.
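+		// e.g. assuming Ethereum JSON-RPC parse directives: eth_subscribe is tagged FUNCTION_TAG_SUBSCRIBE,
+		// eth_unsubscribe is tagged FUNCTION_TAG_UNSUBSCRIBE, and any other method (say, eth_blockNumber) is a normal relay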
+ if !IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { + if IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE) { + err := cwm.consumerWsSubscriptionManager.Unsubscribe(webSocketCtx, chainMessage, directiveHeaders, relayRequestData, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) + if err != nil { + utils.LavaFormatWarning("error unsubscribing from subscription", err, utils.LogAttr("GUID", webSocketCtx)) + if err == common.SubscriptionNotFoundError { + msgData, err := gojson.Marshal(common.JsonRpcSubscriptionNotFoundError) + if err != nil { + continue + } + + websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: msgData} + } + } + continue + } else if IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE_ALL) { + err := cwm.consumerWsSubscriptionManager.UnsubscribeAll(webSocketCtx, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) + if err != nil { + utils.LavaFormatWarning("error unsubscribing from all subscription", err, utils.LogAttr("GUID", webSocketCtx)) + } + continue + } else { + // Normal relay over websocket. (not subscription related) + relayResult, err := cwm.relaySender.SendParsedRelay(webSocketCtx, dappID, userIp, metricsData, chainMessage, directiveHeaders, relayRequestData) + if err != nil { + formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), utils.LavaFormatError("could not send parsed relay", err), msgSeed, msg, cwm.apiInterface, time.Since(startTime)) + if formatterMsg != nil { + websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: formatterMsg} + continue + } + } + + relayResultReply := relayResult.GetReply() + if relayResultReply != nil { + // No need to verify signature since this is already happening inside the SendParsedRelay flow + websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: relayResult.GetReply().Data} + continue + } + utils.LavaFormatError("Relay result is nil over websocket normal request flow, should not happen", err, utils.LogAttr("messageType", messageType)) + } + } + + // Subscription flow + inputFormatter, outputFormatter := formatter.FormatterForRelayRequestAndResponse(relayRequestData.ApiInterface) // we use this to preserve the original jsonrpc id + inputFormatter(relayRequestData.Data) // set the extracted jsonrpc id + + reply, subscriptionMsgsChan, err := cwm.consumerWsSubscriptionManager.StartSubscription(webSocketCtx, chainMessage, directiveHeaders, relayRequestData, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) + if err != nil { + utils.LavaFormatWarning("StartSubscription returned an error", err, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappID", dappID), + utils.LogAttr("userIp", userIp), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + ) + + formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), utils.LavaFormatError("could not start subscription", err), msgSeed, msg, cwm.apiInterface, time.Since(startTime)) + if formatterMsg != nil { + websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: formatterMsg} // No need to use outputFormatter here since we are sending an error + continue + } + + // Handle the case when the error is a method not found error + if common.APINotSupportedError.Is(err) { + msgData, err := gojson.Marshal(common.JsonRpcMethodNotFoundError) + if err != nil { + continue + } + + websocketConnWriteChan <- 
webSocketMsgWithType{messageType: messageType, msg: outputFormatter(msgData)} + continue + } + continue + } + + if subscriptionMsgsChan != nil { // if == nil, it means that we already have an active subscription running on this query + go func() { + utils.LavaFormatTrace("created go routine for new websocketSubMsgsChan", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappID", dappID), + utils.LogAttr("userIp", userIp), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + ) + + for subscriptionMsgReply := range subscriptionMsgsChan { + websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: outputFormatter(subscriptionMsgReply.Data)} + } + + utils.LavaFormatTrace("subscriptionMsgsChan was closed", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappID", dappID), + utils.LogAttr("userIp", userIp), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + ) + }() + } + + refererMatch, referrerMatchCastedSuccessfully := websocketConn.Locals(cwm.refererMatchString).(string) + if referrerMatchCastedSuccessfully && refererMatch != "" && cwm.refererData != nil { + go cwm.refererData.SendReferer(refererMatch, cwm.chainId, string(msg), websocketConn.RemoteAddr().String(), nil, websocketConn) + } + + go logger.AddMetricForWebSocket(metricsData, err, websocketConn) + + if reply != nil { + reply.Data = outputFormatter(reply.Data) // use that id for the reply + + websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: reply.Data} + + logger.LogRequestAndResponse("jsonrpc ws msg", false, "ws", websocketConn.LocalAddr().String(), string(msg), string(reply.Data), msgSeed, time.Since(startTime), nil) + } + } +} diff --git a/protocol/chainlib/consumer_ws_subscription_manager.go b/protocol/chainlib/consumer_ws_subscription_manager.go new file mode 100644 index 0000000000..6b993588cd --- /dev/null +++ b/protocol/chainlib/consumer_ws_subscription_manager.go @@ -0,0 +1,930 @@ +package chainlib + +import ( + "context" + "fmt" + "strconv" + "sync" + + gojson "github.com/goccy/go-json" + rpcclient "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" + "github.com/lavanet/lava/v2/protocol/common" + "github.com/lavanet/lava/v2/protocol/lavaprotocol" + "github.com/lavanet/lava/v2/protocol/lavasession" + "github.com/lavanet/lava/v2/protocol/metrics" + "github.com/lavanet/lava/v2/utils" + "github.com/lavanet/lava/v2/utils/protocopy" + pairingtypes "github.com/lavanet/lava/v2/x/pairing/types" + spectypes "github.com/lavanet/lava/v2/x/spec/types" +) + +type unsubscribeRelayData struct { + chainMessage ChainMessage + directiveHeaders map[string]string + relayRequestData *pairingtypes.RelayPrivateData +} + +type activeSubscriptionHolder struct { + firstSubscriptionReply *pairingtypes.RelayReply + subscriptionOriginalRequest *pairingtypes.RelayRequest + subscriptionOriginalRequestChainMessage ChainMessage + firstSubscriptionReplyAsJsonrpcMessage *rpcclient.JsonrpcMessage + replyServer pairingtypes.Relayer_RelaySubscribeClient + closeSubscriptionChan chan *unsubscribeRelayData + connectedDappKeys map[string]struct{} // key is dapp key + subscriptionId string +} + +// by using the broadcast manager, we make sure we don't have a race between read and writes and make sure we don't hang for ever +type pendingSubscriptionsBroadcastManager struct { + broadcastChannelList []chan bool +} + +func (psbm *pendingSubscriptionsBroadcastManager) broadcastToChannelList(value bool) { + for _, ch := range psbm.broadcastChannelList { + 
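+		// each registered channel is buffered (capacity 1), so this send won't block even if the waiter is not listening yet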
utils.LavaFormatTrace("broadcastToChannelList Notified pending subscriptions", utils.LogAttr("success", value)) + ch <- value + } +} + +type ConsumerWSSubscriptionManager struct { + connectedDapps map[string]map[string]*common.SafeChannelSender[*pairingtypes.RelayReply] // first key is dapp key, second key is hashed params + activeSubscriptions map[string]*activeSubscriptionHolder // key is params hash + relaySender RelaySender + consumerSessionManager *lavasession.ConsumerSessionManager + chainParser ChainParser + refererData *RefererData + connectionType string + activeSubscriptionProvidersStorage *lavasession.ActiveSubscriptionProvidersStorage + currentlyPendingSubscriptions map[string]*pendingSubscriptionsBroadcastManager + lock sync.RWMutex +} + +func NewConsumerWSSubscriptionManager( + consumerSessionManager *lavasession.ConsumerSessionManager, + relaySender RelaySender, + refererData *RefererData, + connectionType string, + chainParser ChainParser, + activeSubscriptionProvidersStorage *lavasession.ActiveSubscriptionProvidersStorage, +) *ConsumerWSSubscriptionManager { + return &ConsumerWSSubscriptionManager{ + connectedDapps: make(map[string]map[string]*common.SafeChannelSender[*pairingtypes.RelayReply]), + activeSubscriptions: make(map[string]*activeSubscriptionHolder), + currentlyPendingSubscriptions: make(map[string]*pendingSubscriptionsBroadcastManager), + consumerSessionManager: consumerSessionManager, + chainParser: chainParser, + refererData: refererData, + relaySender: relaySender, + connectionType: connectionType, + activeSubscriptionProvidersStorage: activeSubscriptionProvidersStorage, + } +} + +// must be called while locked! +// checking whether hashed params exist in storage, if it does return the subscription stream and indicate it was found. +// otherwise return false +func (cwsm *ConsumerWSSubscriptionManager) checkForActiveSubscriptionAndConnect(webSocketCtx context.Context, hashedParams string, chainMessage ChainMessage, dappKey string, websocketRepliesSafeChannelSender *common.SafeChannelSender[*pairingtypes.RelayReply]) (*pairingtypes.RelayReply, bool) { + activeSubscription, found := cwsm.activeSubscriptions[hashedParams] + if found { + utils.LavaFormatTrace("found active subscription for given params", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + ) + + if _, ok := activeSubscription.connectedDappKeys[dappKey]; ok { + utils.LavaFormatTrace("found active subscription for given params and dappKey", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + ) + + return activeSubscription.firstSubscriptionReply, true // found and already active + } + + // Add to existing subscription + cwsm.connectDappWithSubscription(dappKey, websocketRepliesSafeChannelSender, hashedParams) + + return activeSubscription.firstSubscriptionReply, false // found and not active, register new. 
+	}
+	// not found, need to apply a new message
+	return nil, false
+}
+
+func (cwsm *ConsumerWSSubscriptionManager) failedPendingSubscription(hashedParams string) {
+	cwsm.lock.Lock()
+	defer cwsm.lock.Unlock()
+	pendingSubscriptionChannel, ok := cwsm.currentlyPendingSubscriptions[hashedParams]
+	if !ok {
+		utils.LavaFormatError("failed fetching hashed params in failedPendingSubscription", nil, utils.LogAttr("hash", hashedParams), utils.LogAttr("cwsm.currentlyPendingSubscriptions", cwsm.currentlyPendingSubscriptions))
+	} else {
+		pendingSubscriptionChannel.broadcastToChannelList(false)
+		delete(cwsm.currentlyPendingSubscriptions, hashedParams) // remove the pending entry
+	}
+}
+
+// must be called under a lock.
+func (cwsm *ConsumerWSSubscriptionManager) successfulPendingSubscription(hashedParams string) {
+	pendingSubscriptionChannel, ok := cwsm.currentlyPendingSubscriptions[hashedParams]
+	if !ok {
+		utils.LavaFormatError("failed fetching hashed params in successfulPendingSubscription", nil, utils.LogAttr("hash", hashedParams), utils.LogAttr("cwsm.currentlyPendingSubscriptions", cwsm.currentlyPendingSubscriptions))
+	} else {
+		pendingSubscriptionChannel.broadcastToChannelList(true)
+		delete(cwsm.currentlyPendingSubscriptions, hashedParams) // remove the pending entry
+	}
+}
+
+func (cwsm *ConsumerWSSubscriptionManager) checkAndAddPendingSubscriptionsWithLock(hashedParams string) (chan bool, bool) {
+	cwsm.lock.Lock()
+	defer cwsm.lock.Unlock()
+	pendingSubscriptionBroadcastManager, ok := cwsm.currentlyPendingSubscriptions[hashedParams]
+	if !ok {
+		// no pending subscription was found for these hashed params, so we can create a new subscription
+		utils.LavaFormatTrace("No pending subscription for incoming hashed params found", utils.LogAttr("params", hashedParams))
+		// create a pending subscription broadcast manager so other users can sync on the same relay.
+		cwsm.currentlyPendingSubscriptions[hashedParams] = &pendingSubscriptionsBroadcastManager{}
+		return nil, ok
+	}
+	utils.LavaFormatTrace("found subscription for incoming hashed params, registering our channel", utils.LogAttr("params", hashedParams))
+	// by creating a buffered channel we make sure that we won't miss the update between the time we register and the time we listen to the channel
+	listenChan := make(chan bool, 1)
+	pendingSubscriptionBroadcastManager.broadcastChannelList = append(pendingSubscriptionBroadcastManager.broadcastChannelList, listenChan)
+	return listenChan, ok
+}
+
+func (cwsm *ConsumerWSSubscriptionManager) checkForActiveSubscriptionWithLock(
+	webSocketCtx context.Context,
+	hashedParams string,
+	chainMessage ChainMessage,
+	dappKey string,
+	websocketRepliesSafeChannelSender *common.SafeChannelSender[*pairingtypes.RelayReply],
+	closeWebsocketRepliesChannel func(),
+) (*pairingtypes.RelayReply, bool) {
+	cwsm.lock.Lock()
+	defer cwsm.lock.Unlock()
+
+	firstSubscriptionReply, alreadyActiveSubscription := cwsm.checkForActiveSubscriptionAndConnect(
+		webSocketCtx, hashedParams, chainMessage, dappKey, websocketRepliesSafeChannelSender,
+	)
+
+	if firstSubscriptionReply != nil {
+		if alreadyActiveSubscription { // same dapp key, no need for a new channel
+			closeWebsocketRepliesChannel()
+			return firstSubscriptionReply, false
+		}
+		// Added to an existing subscription with a new dapp key.
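+		// returning true tells the caller to keep the replies channel open for this dapp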
+ return firstSubscriptionReply, true + } + + // if we reached here, the subscription is currently not registered, we will need to check again later when we apply the subscription, and + // handle the case where two identical subscriptions were launched at the same time. + return nil, false +} + +func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( + webSocketCtx context.Context, + chainMessage ChainMessage, + directiveHeaders map[string]string, + relayRequestData *pairingtypes.RelayPrivateData, + dappID string, + consumerIp string, + webSocketConnectionUniqueId string, + metricsData *metrics.RelayMetrics, +) (firstReply *pairingtypes.RelayReply, repliesChan <-chan *pairingtypes.RelayReply, err error) { + hashedParams, _, err := cwsm.getHashedParams(chainMessage) + if err != nil { + return nil, nil, utils.LavaFormatError("could not marshal params", err) + } + + dappKey := cwsm.CreateWebSocketConnectionUniqueKey(dappID, consumerIp, webSocketConnectionUniqueId) + + utils.LavaFormatTrace("request to start subscription", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("connectedDapps", cwsm.connectedDapps), + ) + + websocketRepliesChan := make(chan *pairingtypes.RelayReply) + websocketRepliesSafeChannelSender := common.NewSafeChannelSender(webSocketCtx, websocketRepliesChan) + + closeWebsocketRepliesChan := make(chan struct{}) + closeWebsocketRepliesChannel := func() { + select { + case closeWebsocketRepliesChan <- struct{}{}: + default: + } + } + + // called after send relay failure or parsing failure afterwards + onSubscriptionFailure := func() { + cwsm.failedPendingSubscription(hashedParams) + closeWebsocketRepliesChannel() + } + + go func() { + <-closeWebsocketRepliesChan + utils.LavaFormatTrace("requested to close websocketRepliesChan", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + ) + + websocketRepliesSafeChannelSender.Close() + }() + + // Remove the websocket from the active subscriptions, when the websocket is closed + go func() { + <-webSocketCtx.Done() + utils.LavaFormatTrace("websocket context is done, removing websocket from active subscriptions", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + ) + + cwsm.lock.Lock() + defer cwsm.lock.Unlock() + + if _, ok := cwsm.connectedDapps[dappKey]; ok { + // The websocket can be closed before the first reply is received, so we need to check if the dapp was even added to the connectedDapps map + cwsm.verifyAndDisconnectDappFromSubscription(webSocketCtx, dappKey, hashedParams, nil) + } + closeWebsocketRepliesChannel() + }() + + // Validated there are no active subscriptions that we can use. 
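+	// (the check below attaches to an existing active subscription if one already exists, without sending a new relay)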
+	firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, chainMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel)
+	if firstSubscriptionReply != nil {
+		if returnWebsocketRepliesChan {
+			return firstSubscriptionReply, websocketRepliesChan, nil
+		}
+		return firstSubscriptionReply, nil, nil
+	}
+
+	// This for loop breaks when there is a successful queue lock, allowing us to avoid racing new subscription creation when
+	// there is a failed subscription. The loop breaks for the first routine that manages to lock and create the pendingSubscriptionsBroadcastManager
+	for {
+		// In case there are no active subscriptions, check for pending subscriptions with the same hashed params,
+		// avoiding opening the same subscription twice.
+		pendingSubscriptionChannel, foundPendingSubscription := cwsm.checkAndAddPendingSubscriptionsWithLock(hashedParams)
+		if foundPendingSubscription {
+			utils.LavaFormatTrace("Found pending subscription, waiting for it to complete")
+			// this is a buffered channel, it won't get stuck even if it is written to before the time we listen
+			res := <-pendingSubscriptionChannel
+			utils.LavaFormatTrace("Finished pending for subscription, have results", utils.LogAttr("success", res))
+			// Check res is valid; if not, fall through to the logs and try again with a new client.
+			if res {
+				firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, chainMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel)
+				if firstSubscriptionReply != nil {
+					if returnWebsocketRepliesChan {
+						return firstSubscriptionReply, websocketRepliesChan, nil
+					}
+					return firstSubscriptionReply, nil, nil
+				}
+				// Since res == true indicated a successful subscription, we should find an active subscription here.
+				// If we fail to find it, it might have suddenly stopped; we log a warning and try with a new client.
+				utils.LavaFormatWarning("failed getting a result when channel indicated we got a successful relay", nil)
+			}
+			// Failed the subscription attempt, will retry using the current relay.
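+			// (failedPendingSubscription already removed the pending entry, so the next loop iteration registers a fresh one)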
+ utils.LavaFormatDebug("Failed the subscription attempt, retrying with the incoming message", utils.LogAttr("hash", hashedParams)) + } else { + utils.LavaFormatDebug("No Pending subscriptions, creating a new one", utils.LogAttr("hash", hashedParams)) + break + } + } + + utils.LavaFormatTrace("could not find active subscription for given params, creating new one", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + ) + + relayResult, err := cwsm.relaySender.SendParsedRelay(webSocketCtx, dappID, consumerIp, metricsData, chainMessage, directiveHeaders, relayRequestData) + if err != nil { + onSubscriptionFailure() + return nil, nil, utils.LavaFormatError("could not send subscription relay", err) + } + + utils.LavaFormatTrace("got relay result from SendParsedRelay", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("relayResult", relayResult), + ) + + replyServer := relayResult.GetReplyServer() + if replyServer == nil { + // This code should never be reached, but just in case + onSubscriptionFailure() + return nil, nil, utils.LavaFormatError("reply server is nil, probably an error with the subscription initiation", nil, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + ) + } + + reply := *relayResult.Reply + if reply.Data == nil { + onSubscriptionFailure() + return nil, nil, utils.LavaFormatError("Reply data is nil", nil, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + ) + } + + copiedRequest := &pairingtypes.RelayRequest{} + err = protocopy.DeepCopyProtoObject(relayResult.Request, copiedRequest) + if err != nil { + onSubscriptionFailure() + return nil, nil, utils.LavaFormatError("could not copy relay request", err, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + ) + } + + err = cwsm.verifySubscriptionMessage(hashedParams, chainMessage, relayResult.Request, &reply, relayResult.ProviderInfo.ProviderAddress) + if err != nil { + onSubscriptionFailure() + return nil, nil, utils.LavaFormatError("Failed VerifyRelayReply on subscription message", err, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("reply", string(reply.Data)), + ) + } + + // Parse the reply + var replyJsonrpcMessage rpcclient.JsonrpcMessage + err = gojson.Unmarshal(reply.Data, &replyJsonrpcMessage) + if err != nil { + onSubscriptionFailure() + return nil, nil, utils.LavaFormatError("could not parse reply into json", err, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + 
utils.LogAttr("reply", reply.Data), + ) + } + + utils.LavaFormatTrace("Adding new subscription", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("dappID", dappID), + utils.LogAttr("consumerIp", consumerIp), + ) + + cwsm.lock.Lock() + defer cwsm.lock.Unlock() + + subscriptionId := chainMessage.SubscriptionIdExtractor(&replyJsonrpcMessage) + subscriptionId = common.UnSquareBracket(subscriptionId) + if common.IsQuoted(subscriptionId) { + subscriptionId, _ = strconv.Unquote(subscriptionId) + } + + // we don't have a subscription of this hashedParams stored, create a new one. + closeSubscriptionChan := make(chan *unsubscribeRelayData) + cwsm.activeSubscriptions[hashedParams] = &activeSubscriptionHolder{ + firstSubscriptionReply: &reply, + firstSubscriptionReplyAsJsonrpcMessage: &replyJsonrpcMessage, + replyServer: replyServer, + subscriptionOriginalRequest: copiedRequest, + subscriptionOriginalRequestChainMessage: chainMessage, + closeSubscriptionChan: closeSubscriptionChan, + connectedDappKeys: map[string]struct{}{dappKey: {}}, + subscriptionId: subscriptionId, + } + + providerAddr := relayResult.ProviderInfo.ProviderAddress + cwsm.activeSubscriptionProvidersStorage.AddProvider(providerAddr) + cwsm.connectDappWithSubscription(dappKey, websocketRepliesSafeChannelSender, hashedParams) + // trigger success for other pending subscriptions + cwsm.successfulPendingSubscription(hashedParams) + // Need to be run once for subscription + go cwsm.listenForSubscriptionMessages(webSocketCtx, dappID, consumerIp, replyServer, hashedParams, providerAddr, metricsData, closeSubscriptionChan) + + return &reply, websocketRepliesChan, nil +} + +func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( + webSocketCtx context.Context, + dappID string, + userIp string, + replyServer pairingtypes.Relayer_RelaySubscribeClient, + hashedParams string, + providerAddr string, + metricsData *metrics.RelayMetrics, + closeSubscriptionChan chan *unsubscribeRelayData, +) { + var unsubscribeData *unsubscribeRelayData + + defer func() { + // Only gets here when there is an issue with the connection to the provider or the connection's context is canceled + // Then, we close all active connections with dapps + + cwsm.lock.Lock() + defer cwsm.lock.Unlock() + + utils.LavaFormatTrace("closing all connected dapps for closed subscription connection", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + cwsm.activeSubscriptions[hashedParams].connectedDappKeys = make(map[string]struct{}) // disconnect all dapps at once from active subscription + + // Close all remaining active connections + for _, connectedDapp := range cwsm.connectedDapps { + if _, ok := connectedDapp[hashedParams]; ok { + connectedDapp[hashedParams].Close() + delete(connectedDapp, hashedParams) + } + } + + // we run the unsubscribe flow in an inner function so it wont prevent us from removing the activeSubscriptions at the end. 
+ func() { + var err error + var chainMessage ChainMessage + var directiveHeaders map[string]string + var relayRequestData *pairingtypes.RelayPrivateData + if unsubscribeData != nil { + // This unsubscribe request was initiated by the user + utils.LavaFormatTrace("unsubscribe request was made by the user", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + chainMessage = unsubscribeData.chainMessage + directiveHeaders = unsubscribeData.directiveHeaders + relayRequestData = unsubscribeData.relayRequestData + } else { + // This unsubscribe request was initiated by us + utils.LavaFormatTrace("unsubscribe request was made automatically", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + chainMessage, directiveHeaders, relayRequestData, err = cwsm.craftUnsubscribeMessage(hashedParams, dappID, userIp, metricsData) + if err != nil { + utils.LavaFormatError("could not craft unsubscribe message", err, utils.LogAttr("GUID", webSocketCtx)) + return + } + + stringJson, err := gojson.Marshal(chainMessage.GetRPCMessage()) + if err != nil { + utils.LavaFormatError("could not marshal chain message", err, utils.LogAttr("GUID", webSocketCtx)) + return + } + + utils.LavaFormatTrace("crafted unsubscribe message to send to the provider", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("chainMessage", string(stringJson)), + ) + } + + unsubscribeRelayCtx := utils.WithUniqueIdentifier(context.Background(), utils.GenerateUniqueIdentifier()) + err = cwsm.sendUnsubscribeMessage(unsubscribeRelayCtx, dappID, userIp, chainMessage, directiveHeaders, relayRequestData, metricsData) + if err != nil { + utils.LavaFormatError("could not send unsubscribe message due to a relay error", + err, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("relayRequestData", relayRequestData), + utils.LogAttr("dappID", dappID), + utils.LogAttr("userIp", userIp), + utils.LogAttr("api", chainMessage.GetApi().Name), + utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + ) + } else { + utils.LavaFormatTrace("success sending unsubscribe message, deleting hashed params from activeSubscriptions", + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("chainMessage", cwsm.activeSubscriptions), + ) + } + }() + + delete(cwsm.activeSubscriptions, hashedParams) + utils.LavaFormatTrace("after delete") + + cwsm.activeSubscriptionProvidersStorage.RemoveProvider(providerAddr) + utils.LavaFormatTrace("after remove") + cwsm.relaySender.CancelSubscriptionContext(hashedParams) + utils.LavaFormatTrace("after cancel") + }() + + for { + select { + case unsubscribeData = <-closeSubscriptionChan: + utils.LavaFormatTrace("requested to close subscription connection", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + return + case <-replyServer.Context().Done(): + utils.LavaFormatTrace("reply server context canceled", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + return + default: + var reply pairingtypes.RelayReply + err := replyServer.RecvMsg(&reply) + if err != nil { + // The connection was closed by the provider + utils.LavaFormatTrace("error reading from subscription stream", utils.LogAttr("original error", err.Error())) + return + } + err = cwsm.handleIncomingSubscriptionNodeMessage(hashedParams, 
&reply, providerAddr) + if err != nil { + utils.LavaFormatError("failed handling subscription message", err, + utils.LogAttr("hashedParams", hashedParams), + utils.LogAttr("reply", reply), + ) + return + } + } + } +} + +func (cwsm *ConsumerWSSubscriptionManager) verifySubscriptionMessage(hashedParams string, chainMessage ChainMessage, request *pairingtypes.RelayRequest, reply *pairingtypes.RelayReply, providerAddr string) error { + lavaprotocol.UpdateRequestedBlock(request.RelayData, reply) // update relay request requestedBlock to the provided one in case it was arbitrary + filteredHeaders, _, ignoredHeaders := cwsm.chainParser.HandleHeaders(reply.Metadata, chainMessage.GetApiCollection(), spectypes.Header_pass_reply) + reply.Metadata = filteredHeaders + err := lavaprotocol.VerifyRelayReply(context.Background(), reply, request, providerAddr) + if err != nil { + return utils.LavaFormatError("Failed VerifyRelayReply on subscription message", err, + utils.LogAttr("subscriptionMsg", reply.Data), + utils.LogAttr("hashedParams", hashedParams), + utils.LogAttr("originalRequest", request), + ) + } + + reply.Metadata = append(reply.Metadata, ignoredHeaders...) + return nil +} + +func (cwsm *ConsumerWSSubscriptionManager) handleIncomingSubscriptionNodeMessage(hashedParams string, subscriptionRelayReplyMsg *pairingtypes.RelayReply, providerAddr string) error { + cwsm.lock.RLock() + defer cwsm.lock.RUnlock() + + activeSubscription := cwsm.activeSubscriptions[hashedParams] + // we need to copy the original message because the verify changes the requested block every time. + copiedRequest := &pairingtypes.RelayRequest{} + err := protocopy.DeepCopyProtoObject(activeSubscription.subscriptionOriginalRequest, copiedRequest) + if err != nil { + return utils.LavaFormatError("could not copy relay request", err, + utils.LogAttr("hashedParams", hashedParams), + utils.LogAttr("subscriptionMsg", subscriptionRelayReplyMsg.Data), + utils.LogAttr("providerAddr", providerAddr), + ) + } + + chainMessage := activeSubscription.subscriptionOriginalRequestChainMessage + err = cwsm.verifySubscriptionMessage(hashedParams, chainMessage, copiedRequest, subscriptionRelayReplyMsg, providerAddr) + if err != nil { + // Critical error, we need to close the connection + return utils.LavaFormatError("Failed VerifyRelayReply on subscription message", err, + utils.LogAttr("hashedParams", hashedParams), + utils.LogAttr("subscriptionMsg", subscriptionRelayReplyMsg.Data), + utils.LogAttr("providerAddr", providerAddr), + ) + } + + // message is valid, we can now distribute the message to all active listening users. 
+ for connectedDappKey := range activeSubscription.connectedDappKeys { + if _, ok := cwsm.connectedDapps[connectedDappKey]; !ok { + utils.LavaFormatError("connected dapp not found", nil, + utils.LogAttr("connectedDappKey", connectedDappKey), + utils.LogAttr("hashedParams", hashedParams), + utils.LogAttr("activeSubscriptions[hashedParams].connectedDapps", activeSubscription.connectedDappKeys), + utils.LogAttr("connectedDapps", cwsm.connectedDapps), + ) + continue + } + + if _, ok := cwsm.connectedDapps[connectedDappKey][hashedParams]; !ok { + utils.LavaFormatError("dapp is not connected to subscription", nil, + utils.LogAttr("connectedDappKey", connectedDappKey), + utils.LogAttr("hashedParams", hashedParams), + utils.LogAttr("activeSubscriptions[hashedParams].connectedDapps", activeSubscription.connectedDappKeys), + utils.LogAttr("connectedDapps[connectedDappKey]", cwsm.connectedDapps[connectedDappKey]), + ) + continue + } + // set consistency seen block + cwsm.relaySender.SetConsistencySeenBlock(subscriptionRelayReplyMsg.LatestBlock, connectedDappKey) + // send the reply to the user + cwsm.connectedDapps[connectedDappKey][hashedParams].Send(subscriptionRelayReplyMsg) + } + + return nil +} + +func (cwsm *ConsumerWSSubscriptionManager) getHashedParams(chainMessage ChainMessageForSend) (hashedParams string, params []byte, err error) { + params, err = gojson.Marshal(chainMessage.GetRPCMessage().GetParams()) + if err != nil { + return "", nil, utils.LavaFormatError("could not marshal params", err) + } + + hashedParams = rpcclient.CreateHashFromParams(params) + + return hashedParams, params, nil +} + +func (cwsm *ConsumerWSSubscriptionManager) Unsubscribe(webSocketCtx context.Context, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData, dappID, consumerIp string, webSocketConnectionUniqueId string, metricsData *metrics.RelayMetrics) error { + utils.LavaFormatTrace("want to unsubscribe", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappID", dappID), + utils.LogAttr("consumerIp", consumerIp), + utils.LogAttr("webSocketConnectionUniqueId", webSocketConnectionUniqueId), + ) + + dappKey := cwsm.CreateWebSocketConnectionUniqueKey(dappID, consumerIp, webSocketConnectionUniqueId) + + cwsm.lock.Lock() + defer cwsm.lock.Unlock() + hashedParams, err := cwsm.findActiveSubscriptionHashedParamsFromChainMessage(chainMessage) + if err != nil { + return err + } + return cwsm.verifyAndDisconnectDappFromSubscription(webSocketCtx, dappKey, hashedParams, func() (*unsubscribeRelayData, error) { + return &unsubscribeRelayData{chainMessage, directiveHeaders, relayRequestData}, nil + }) +} + +func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, dappID, consumerIp string, metricsData *metrics.RelayMetrics) (ChainMessage, map[string]string, *pairingtypes.RelayPrivateData, error) { + request := cwsm.activeSubscriptions[hashedParams].subscriptionOriginalRequestChainMessage + subscriptionId := cwsm.activeSubscriptions[hashedParams].subscriptionId + + // Craft the message data from function template + var unsubscribeRequestData string + var found bool + for _, currParseDirective := range request.GetApiCollection().ParseDirectives { + if currParseDirective.FunctionTag == spectypes.FUNCTION_TAG_UNSUBSCRIBE { + unsubscribeRequestData = fmt.Sprintf(currParseDirective.FunctionTemplate, subscriptionId) + found = true + break + } + } + + if !found { + return nil, nil, nil, utils.LavaFormatError("could not find unsubscribe parse 
directive for given chain message", nil, + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("subscriptionId", subscriptionId), + ) + } + + if unsubscribeRequestData == "" { + return nil, nil, nil, utils.LavaFormatError("unsubscribe request data is empty", nil, + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("subscriptionId", subscriptionId), + ) + } + + // Craft the unsubscribe chain message + ctx := context.Background() + chainMessage, directiveHeaders, relayRequestData, err := cwsm.relaySender.ParseRelay(ctx, "", unsubscribeRequestData, cwsm.connectionType, dappID, consumerIp, metricsData, nil) + if err != nil { + return nil, nil, nil, utils.LavaFormatError("could not craft unsubscribe chain message", err, + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("subscriptionId", subscriptionId), + utils.LogAttr("unsubscribeRequestData", unsubscribeRequestData), + utils.LogAttr("cwsm.connectionType", cwsm.connectionType), + ) + } + + return chainMessage, directiveHeaders, relayRequestData, nil +} + +func (cwsm *ConsumerWSSubscriptionManager) sendUnsubscribeMessage(ctx context.Context, dappID, consumerIp string, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData, metricsData *metrics.RelayMetrics) error { + // Send the crafted unsubscribe relay + utils.LavaFormatTrace("sending unsubscribe relay", + utils.LogAttr("GUID", ctx), + utils.LogAttr("dappID", dappID), + utils.LogAttr("consumerIp", consumerIp), + ) + + _, err := cwsm.relaySender.SendParsedRelay(ctx, dappID, consumerIp, metricsData, chainMessage, directiveHeaders, relayRequestData) + if err != nil { + return utils.LavaFormatError("could not send unsubscribe relay", err) + } + + return nil +} + +func (cwsm *ConsumerWSSubscriptionManager) connectDappWithSubscription(dappKey string, webSocketChan *common.SafeChannelSender[*pairingtypes.RelayReply], hashedParams string) { + // Must be called under a lock + + // Validate hashedParams is in active subscriptions. 
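+	// (a miss here indicates a bookkeeping bug, since dapps are only connected to subscriptions that were just created or found active)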
+ if _, ok := cwsm.activeSubscriptions[hashedParams]; !ok { + utils.LavaFormatError("Failed finding hashed params in connectDappWithSubscription, should never happen", nil, utils.LogAttr("hashedParams", hashedParams), utils.LogAttr("cwsm.activeSubscriptions", cwsm.activeSubscriptions)) + return + } + cwsm.activeSubscriptions[hashedParams].connectedDappKeys[dappKey] = struct{}{} + if _, ok := cwsm.connectedDapps[dappKey]; !ok { + cwsm.connectedDapps[dappKey] = make(map[string]*common.SafeChannelSender[*pairingtypes.RelayReply]) + } + cwsm.connectedDapps[dappKey][hashedParams] = webSocketChan +} + +func (cwsm *ConsumerWSSubscriptionManager) CreateWebSocketConnectionUniqueKey(dappID, consumerIp, webSocketConnectionUniqueId string) string { + return cwsm.relaySender.CreateDappKey(dappID, consumerIp) + "__" + webSocketConnectionUniqueId +} + +func (cwsm *ConsumerWSSubscriptionManager) UnsubscribeAll(webSocketCtx context.Context, dappID, consumerIp string, webSocketConnectionUniqueId string, metricsData *metrics.RelayMetrics) error { + utils.LavaFormatTrace("want to unsubscribe all", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappID", dappID), + utils.LogAttr("consumerIp", consumerIp), + ) + + dappKey := cwsm.CreateWebSocketConnectionUniqueKey(dappID, consumerIp, webSocketConnectionUniqueId) + + cwsm.lock.Lock() + defer cwsm.lock.Unlock() + + // Look for active connection + if _, ok := cwsm.connectedDapps[dappKey]; !ok { + return utils.LavaFormatDebug("webSocket has no active subscriptions", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappID", dappID), + utils.LogAttr("consumerIp", consumerIp), + ) + } + + for hashedParams := range cwsm.connectedDapps[dappKey] { + utils.LavaFormatTrace("disconnecting dapp from subscription", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappID", dappID), + utils.LogAttr("consumerIp", consumerIp), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + unsubscribeRelayGetter := func() (*unsubscribeRelayData, error) { + chainMessage, directiveHeaders, relayRequestData, err := cwsm.craftUnsubscribeMessage(hashedParams, dappID, consumerIp, metricsData) + if err != nil { + return nil, err + } + + return &unsubscribeRelayData{chainMessage, directiveHeaders, relayRequestData}, nil + } + + cwsm.verifyAndDisconnectDappFromSubscription(webSocketCtx, dappKey, hashedParams, unsubscribeRelayGetter) + } + + return nil +} + +func (cwsm *ConsumerWSSubscriptionManager) findActiveSubscriptionHashedParamsFromChainMessage(chainMessage ChainMessage) (string, error) { + // Must be called under lock + + // Extract the subscription id from the chain message + unsubscribeRequestParams, err := gojson.Marshal(chainMessage.GetRPCMessage().GetParams()) + if err != nil { + return "", utils.LavaFormatError("could not marshal params", err, utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams())) + } + + unsubscribeRequestParamsString := string(unsubscribeRequestParams) + + // In JsonRPC, the subscription id is a string, but it is sent in an array + // In Tendermint, the subscription id is the query params, and sent as an object, so skipped + unsubscribeRequestParamsString = common.UnSquareBracket(unsubscribeRequestParamsString) + + if common.IsQuoted(unsubscribeRequestParamsString) { + unsubscribeRequestParamsString, err = strconv.Unquote(unsubscribeRequestParamsString) + if err != nil { + return "", utils.LavaFormatError("could not unquote params", err) + } + } + + for hashesParams, activeSubscription := range 
cwsm.activeSubscriptions { + if activeSubscription.subscriptionId == unsubscribeRequestParamsString { + return hashesParams, nil + } + } + + utils.LavaFormatDebug("could not find active subscription for given params", utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams())) + + return "", common.SubscriptionNotFoundError +} + +func (cwsm *ConsumerWSSubscriptionManager) verifyAndDisconnectDappFromSubscription( + webSocketCtx context.Context, + dappKey string, + hashedParams string, + unsubscribeRelayDataGetter func() (*unsubscribeRelayData, error), +) error { + // Must be called under lock + if _, ok := cwsm.connectedDapps[dappKey]; !ok { + utils.LavaFormatDebug("webSocket has no active subscriptions", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + return nil + } + + if _, ok := cwsm.connectedDapps[dappKey][hashedParams]; !ok { + utils.LavaFormatDebug("no active subscription found for given dapp", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("cwsm.connectedDapps", cwsm.connectedDapps), + ) + + return nil + } + + if _, ok := cwsm.activeSubscriptions[hashedParams]; !ok { + utils.LavaFormatError("no active subscription found, but the subscription is found in connectedDapps, this should never happen", nil, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + return common.SubscriptionNotFoundError + } + + if _, ok := cwsm.activeSubscriptions[hashedParams].connectedDappKeys[dappKey]; !ok { + utils.LavaFormatError("active subscription found, but the dappKey is not found in it's connectedDapps, this should never happen", nil, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + return common.SubscriptionNotFoundError + } + + cwsm.connectedDapps[dappKey][hashedParams].Close() // close the subscription msgs channel + + delete(cwsm.connectedDapps[dappKey], hashedParams) + utils.LavaFormatTrace("deleted hashedParams from connected dapp's active subscriptions", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("connectedDappActiveSubs", cwsm.connectedDapps[dappKey]), + ) + + if len(cwsm.connectedDapps[dappKey]) == 0 { + delete(cwsm.connectedDapps, dappKey) + utils.LavaFormatTrace("deleted dappKey from connected dapps", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappKey", dappKey), + ) + } + + delete(cwsm.activeSubscriptions[hashedParams].connectedDappKeys, dappKey) + utils.LavaFormatTrace("deleted dappKey from active subscriptions", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("activeSubConnectedDapps", cwsm.activeSubscriptions[hashedParams].connectedDappKeys), + ) + + if len(cwsm.activeSubscriptions[hashedParams].connectedDappKeys) == 0 { + // No more dapps are connected, close the subscription with provider + utils.LavaFormatTrace("no more dapps are connected to subscription, closing subscription", + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + var 
unsubscribeData *unsubscribeRelayData + + if unsubscribeRelayDataGetter != nil { + var err error + unsubscribeData, err = unsubscribeRelayDataGetter() + if err != nil { + return utils.LavaFormatError("got error from getUnsubscribeRelay function", err, + utils.LogAttr("GUID", webSocketCtx), + utils.LogAttr("dappKey", dappKey), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + } + } + + // Close subscription with provider + go func() { + // In a go routine because the reading routine is also locking on new messages from the node + // So we need to release the lock here, and let the last message be sent, and then the channel will be released + cwsm.activeSubscriptions[hashedParams].closeSubscriptionChan <- unsubscribeData + }() + } + + return nil +} diff --git a/protocol/chainlib/consumer_ws_subscription_manager_test.go b/protocol/chainlib/consumer_ws_subscription_manager_test.go new file mode 100644 index 0000000000..02de8604e5 --- /dev/null +++ b/protocol/chainlib/consumer_ws_subscription_manager_test.go @@ -0,0 +1,714 @@ +package chainlib + +import ( + "context" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" + "github.com/lavanet/lava/v2/protocol/common" + "github.com/lavanet/lava/v2/protocol/lavaprotocol" + "github.com/lavanet/lava/v2/protocol/lavasession" + "github.com/lavanet/lava/v2/protocol/metrics" + "github.com/lavanet/lava/v2/protocol/provideroptimizer" + "github.com/lavanet/lava/v2/utils" + "github.com/lavanet/lava/v2/utils/rand" + pairingtypes "github.com/lavanet/lava/v2/x/pairing/types" + spectypes "github.com/lavanet/lava/v2/x/spec/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" +) + +const ( + numberOfParallelSubscriptions = 10 + uniqueId = "1234" +) + +func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *testing.T) { + playbook := []struct { + name string + specId string + apiInterface string + connectionType string + subscriptionRequestData1 []byte + subscriptionFirstReply1 []byte + subscriptionRequestData2 []byte + subscriptionFirstReply2 []byte + }{ + { + name: "TendermintRPC", + specId: "LAV1", + apiInterface: spectypes.APIInterfaceTendermintRPC, + connectionType: "", + subscriptionRequestData1: []byte(`{"jsonrpc":"2.0","id":3,"method":"subscribe","params":{"query":"tm.event='NewBlock'"}}`), + subscriptionFirstReply1: []byte(`{"jsonrpc":"2.0","id":3,"result":{}}`), + subscriptionRequestData2: []byte(`{"jsonrpc":"2.0","id":4,"method":"subscribe","params":{"query":"tm.event= 'NewBlock'"}}`), + subscriptionFirstReply2: []byte(`{"jsonrpc":"2.0","id":4,"result":{}}`), + }, + } + + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + ts := SetupForTests(t, 1, play.specId, "../../") + + dapp := "dapp" + ip := "127.0.0.1" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + chainParser, _, _, _, _, err := CreateChainLibMocks(ts.Ctx, play.specId, play.apiInterface, nil, nil, "../../", nil) + require.NoError(t, err) + + chainMessage1, err := chainParser.ParseMsg("", play.subscriptionRequestData1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) + require.NoError(t, err) + + relaySender := NewMockRelaySender(ctrl) + relaySender. + EXPECT(). + CreateDappKey(gomock.Any(), gomock.Any()). + DoAndReturn(func(dappID, consumerIp string) string { + return dappID + consumerIp + }). + AnyTimes() + relaySender. + EXPECT(). 
+ SetConsistencySeenBlock(gomock.Any(), gomock.Any()). + AnyTimes() + + relaySender. + EXPECT(). + ParseRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(chainMessage1, nil, nil, nil). + AnyTimes() + + mockRelayerClient1 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) + + relayResult1 := &common.RelayResult{ + ReplyServer: mockRelayerClient1, + ProviderInfo: common.ProviderInfo{ + ProviderAddress: ts.Providers[0].Addr.String(), + }, + Reply: &pairingtypes.RelayReply{ + Data: play.subscriptionFirstReply1, + LatestBlock: 1, + }, + Request: &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: play.subscriptionRequestData1, + }, + RelaySession: &pairingtypes.RelaySession{}, + }, + } + + relayResult1.Reply, err = lavaprotocol.SignRelayResponse(ts.Consumer.Addr, *relayResult1.Request, ts.Providers[0].SK, relayResult1.Reply, true) + require.NoError(t, err) + + mockRelayerClient1. + EXPECT(). + Context(). + Return(context.Background()). + AnyTimes() + + mockRelayerClient1. + EXPECT(). + RecvMsg(gomock.Any()). + DoAndReturn(func(msg interface{}) error { + relayReply, ok := msg.(*pairingtypes.RelayReply) + require.True(t, ok) + + *relayReply = *relayResult1.Reply + return nil + }). + AnyTimes() + + relaySender. + EXPECT(). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(relayResult1, nil). + Times(1) // Should call SendParsedRelay, because it is the first time we subscribe + + consumerSessionManager := CreateConsumerSessionManager(play.specId, play.apiInterface, ts.Consumer.Addr.String()) + + // Create a new ConsumerWSSubscriptionManager + manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage()) + uniqueIdentifiers := make([]string, numberOfParallelSubscriptions) + wg := sync.WaitGroup{} + wg.Add(numberOfParallelSubscriptions) + // Start a new subscription for the first time, called SendParsedRelay once while in parallel calling 10 times subscribe with the same message + // expected result is to have SendParsedRelay only once and 9 other messages waiting the broadcast. 
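+			// all goroutines below share chainMessage1, so they hash to the same params and sync on the same pending-subscription entry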
+ for i := 0; i < numberOfParallelSubscriptions; i++ { + uniqueIdentifiers[i] = strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10) + // sending + go func(index int) { + ctx := utils.WithUniqueIdentifier(ts.Ctx, utils.GenerateUniqueIdentifier()) + var repliesChan <-chan *pairingtypes.RelayReply + var firstReply *pairingtypes.RelayReply + firstReply, repliesChan, err = manager.StartSubscription(ctx, chainMessage1, nil, nil, dapp, ip, uniqueIdentifiers[index], nil) + go func() { + for subMsg := range repliesChan { + utils.LavaFormatInfo("got reply for index", utils.LogAttr("index", index)) + require.Equal(t, string(play.subscriptionFirstReply1), string(subMsg.Data)) + } + }() + assert.NoError(t, err) + assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) + assert.NotNil(t, repliesChan) + wg.Done() + }(i) + } + wg.Wait() + + // now we have numberOfParallelSubscriptions subscriptions currently running + require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions) + // remove one + err = manager.Unsubscribe(ts.Ctx, chainMessage1, nil, nil, dapp, ip, uniqueIdentifiers[0], nil) + require.NoError(t, err) + // now we have numberOfParallelSubscriptions - 1 + require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions-1) + // check we still have an active subscription. + require.Len(t, manager.activeSubscriptions, 1) + + // same flow for unsubscribe all + err = manager.UnsubscribeAll(ts.Ctx, dapp, ip, uniqueIdentifiers[1], nil) + require.NoError(t, err) + // now we have numberOfParallelSubscriptions - 2 + require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions-2) + // check we still have an active subscription. + require.Len(t, manager.activeSubscriptions, 1) + }) + } +} + +func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { + playbook := []struct { + name string + specId string + apiInterface string + connectionType string + subscriptionRequestData1 []byte + subscriptionFirstReply1 []byte + subscriptionRequestData2 []byte + subscriptionFirstReply2 []byte + }{ + { + name: "TendermintRPC", + specId: "LAV1", + apiInterface: spectypes.APIInterfaceTendermintRPC, + connectionType: "", + subscriptionRequestData1: []byte(`{"jsonrpc":"2.0","id":3,"method":"subscribe","params":{"query":"tm.event='NewBlock'"}}`), + subscriptionFirstReply1: []byte(`{"jsonrpc":"2.0","id":3,"result":{}}`), + subscriptionRequestData2: []byte(`{"jsonrpc":"2.0","id":4,"method":"subscribe","params":{"query":"tm.event= 'NewBlock'"}}`), + subscriptionFirstReply2: []byte(`{"jsonrpc":"2.0","id":4,"result":{}}`), + }, + } + + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + ts := SetupForTests(t, 1, play.specId, "../../") + + dapp := "dapp" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + chainParser, _, _, _, _, err := CreateChainLibMocks(ts.Ctx, play.specId, play.apiInterface, nil, nil, "../../", nil) + require.NoError(t, err) + + chainMessage1, err := chainParser.ParseMsg("", play.subscriptionRequestData1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) + require.NoError(t, err) + + relaySender := NewMockRelaySender(ctrl) + relaySender. + EXPECT(). + CreateDappKey(gomock.Any(), gomock.Any()). + DoAndReturn(func(dappID, consumerIp string) string { + return dappID + consumerIp + }). + AnyTimes() + relaySender. + EXPECT(). + SetConsistencySeenBlock(gomock.Any(), gomock.Any()). + AnyTimes() + + relaySender. + EXPECT(). 
+ ParseRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(chainMessage1, nil, nil, nil). + AnyTimes() + + mockRelayerClient1 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) + + relayResult1 := &common.RelayResult{ + ReplyServer: mockRelayerClient1, + ProviderInfo: common.ProviderInfo{ + ProviderAddress: ts.Providers[0].Addr.String(), + }, + Reply: &pairingtypes.RelayReply{ + Data: play.subscriptionFirstReply1, + LatestBlock: 1, + }, + Request: &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: play.subscriptionRequestData1, + }, + RelaySession: &pairingtypes.RelaySession{}, + }, + } + + relayResult1.Reply, err = lavaprotocol.SignRelayResponse(ts.Consumer.Addr, *relayResult1.Request, ts.Providers[0].SK, relayResult1.Reply, true) + require.NoError(t, err) + + mockRelayerClient1. + EXPECT(). + Context(). + Return(context.Background()). + AnyTimes() + + mockRelayerClient1. + EXPECT(). + RecvMsg(gomock.Any()). + DoAndReturn(func(msg interface{}) error { + relayReply, ok := msg.(*pairingtypes.RelayReply) + require.True(t, ok) + + *relayReply = *relayResult1.Reply + return nil + }). + AnyTimes() + + relaySender. + EXPECT(). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(relayResult1, nil). + Times(1) // Should call SendParsedRelay, because it is the first time we subscribe + + consumerSessionManager := CreateConsumerSessionManager(play.specId, play.apiInterface, ts.Consumer.Addr.String()) + + // Create a new ConsumerWSSubscriptionManager + manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage()) + + wg := sync.WaitGroup{} + wg.Add(10) + // Start a new subscription for the first time, called SendParsedRelay once while in parallel calling 10 times subscribe with the same message + // expected result is to have SendParsedRelay only once and 9 other messages waiting the broadcast. + for i := 0; i < 10; i++ { + // sending + go func(index int) { + ctx := utils.WithUniqueIdentifier(ts.Ctx, utils.GenerateUniqueIdentifier()) + var repliesChan <-chan *pairingtypes.RelayReply + var firstReply *pairingtypes.RelayReply + firstReply, repliesChan, err = manager.StartSubscription(ctx, chainMessage1, nil, nil, dapp+strconv.Itoa(index), ts.Consumer.Addr.String(), uniqueId, nil) + go func() { + for subMsg := range repliesChan { + require.Equal(t, string(play.subscriptionFirstReply1), string(subMsg.Data)) + } + }() + assert.NoError(t, err) + assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) + assert.NotNil(t, repliesChan) + wg.Done() + }(i) + } + wg.Wait() + }) + } +} + +func TestConsumerWSSubscriptionManager(t *testing.T) { + // This test does the following: + // 1. Create a new ConsumerWSSubscriptionManager + // 2. Start a new subscription for the first time -> should call SendParsedRelay once + // 3. Start a subscription again, same params, same dappKey -> should not call SendParsedRelay + // 4. Start a subscription again, same params, different dappKey -> should not call SendParsedRelay + // 5. Start a new subscription, different params, same dappKey -> should call SendParsedRelay + // 6. Start a subscription again, different params, different dappKey -> should call SendParsedRelay + // 7. 
Unsubscribe from the first subscription -> should call CancelSubscriptionContext and SendParsedRelay + // 8. Unsubscribe from the second subscription -> should call CancelSubscriptionContext and SendParsedRelay + playbook := []struct { + name string + specId string + apiInterface string + connectionType string + subscriptionRequestData1 []byte + subscriptionId1 string + subscriptionFirstReply1 []byte + unsubscribeMessage1 []byte + subscriptionRequestData2 []byte + subscriptionId2 string + subscriptionFirstReply2 []byte + unsubscribeMessage2 []byte + }{ + { + name: "Lava_TendermintRPC", + specId: "LAV1", + apiInterface: spectypes.APIInterfaceTendermintRPC, + connectionType: "", + + subscriptionRequestData1: []byte(`{"jsonrpc":"2.0","id":3,"method":"subscribe","params":{"query":"tm.event='NewBlock'"}}`), + subscriptionId1: `{"query":"tm.event='NewBlock'"}`, + subscriptionFirstReply1: []byte(`{"jsonrpc":"2.0","id":3,"result":{}}`), + unsubscribeMessage1: []byte(`{"jsonrpc":"2.0","method":"unsubscribe","params":{"query":"tm.event='NewBlock'"},"id":1}`), + + subscriptionRequestData2: []byte(`{"jsonrpc":"2.0","id":4,"method":"subscribe","params":{"query":"tm.event= 'NewBlock'"}}`), + subscriptionId2: `{"query":"tm.event= 'NewBlock'"}`, + subscriptionFirstReply2: []byte(`{"jsonrpc":"2.0","id":4,"result":{}}`), + unsubscribeMessage2: []byte(`{"jsonrpc":"2.0","method":"unsubscribe","params":{"query":"tm.event= 'NewBlock'"},"id":1}`), + }, + { + name: "Ethereum_JsonRPC", + specId: "ETH1", + apiInterface: spectypes.APIInterfaceJsonRPC, + connectionType: "POST", + + subscriptionRequestData1: []byte(`{"jsonrpc":"2.0","id":5,"method":"eth_subscribe","params":["newHeads"]}`), + subscriptionId1: "0x1234567890", + subscriptionFirstReply1: []byte(`{"jsonrpc":"2.0","id":5,"result":["0x1234567890"]}`), + unsubscribeMessage1: []byte(`{"jsonrpc":"2.0","method":"eth_unsubscribe","params":["0x1234567890"],"id":1}`), + + subscriptionRequestData2: []byte(`{"jsonrpc":"2.0","id":6,"method":"eth_subscribe","params":["logs"]}`), + subscriptionId2: "0x2134567890", + subscriptionFirstReply2: []byte(`{"jsonrpc":"2.0","id":6,"result":["0x2134567890"]}`), + unsubscribeMessage2: []byte(`{"jsonrpc":"2.0","method":"eth_unsubscribe","params":["0x2134567890"],"id":1}`), + }, + } + + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + ts := SetupForTests(t, 1, play.specId, "../../") + unsubscribeMessageWg := sync.WaitGroup{} + ctx, cancel := context.WithCancel(ts.Ctx) + defer func() { + cancel() + unsubscribeMessageWg.Wait() + }() + + listenForExpectedMessages := func(ctx context.Context, repliesChan <-chan *pairingtypes.RelayReply, expectedMsg string) { + select { + case <-time.After(5 * time.Second): + require.Fail(t, "Timeout waiting for messages", "Expected message: %s", expectedMsg) + return + case subMsg := <-repliesChan: + require.Equal(t, expectedMsg, string(subMsg.Data)) + case <-ctx.Done(): + return + } + } + + expectNoMoreMessages := func(ctx context.Context, repliesChan <-chan *pairingtypes.RelayReply) { + msgCounter := 0 + select { + case <-ctx.Done(): + return + case <-repliesChan: + msgCounter++ + if msgCounter > 2 { + require.Fail(t, "Unexpected message received") + } + } + } + + dapp1 := "dapp1" + dapp2 := "dapp2" + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + chainParser, _, _, _, _, err := CreateChainLibMocks(ts.Ctx, play.specId, play.apiInterface, nil, nil, "../../", nil) + require.NoError(t, err) + + unsubscribeChainMessage1, err := chainParser.ParseMsg("", 
play.unsubscribeMessage1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0})
+ require.NoError(t, err)
+
+ subscribeChainMessage1, err := chainParser.ParseMsg("", play.subscriptionRequestData1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0})
+ require.NoError(t, err)
+
+ relaySender := NewMockRelaySender(ctrl)
+ relaySender.
+ EXPECT().
+ SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool {
+ relayPrivateData, ok := x.(*pairingtypes.RelayPrivateData)
+ if !ok || relayPrivateData == nil {
+ return false
+ }
+
+ if strings.Contains(string(relayPrivateData.Data), "unsubscribe") {
+ unsubscribeMessageWg.Done()
+ }
+
+ // Always return false, because we don't want to use this mock's default return for the calls
+ return false
+ })).
+ AnyTimes()
+
+ relaySender.
+ EXPECT().
+ CreateDappKey(gomock.Any(), gomock.Any()).
+ DoAndReturn(func(dappID, consumerIp string) string {
+ return dappID + consumerIp
+ }).
+ AnyTimes()
+
+ relaySender.
+ EXPECT().
+ SetConsistencySeenBlock(gomock.Any(), gomock.Any()).
+ AnyTimes()
+
+ relaySender.
+ EXPECT().
+ ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool {
+ reqData, ok := x.(string)
+ require.True(t, ok)
+ areEqual := reqData == string(play.unsubscribeMessage1)
+ return areEqual
+ }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(unsubscribeChainMessage1, nil, &pairingtypes.RelayPrivateData{
+ Data: play.unsubscribeMessage1,
+ }, nil).
+ AnyTimes()
+
+ relaySender.
+ EXPECT().
+ ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool {
+ reqData, ok := x.(string)
+ require.True(t, ok)
+ areEqual := reqData == string(play.subscriptionRequestData1)
+ return areEqual
+ }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(subscribeChainMessage1, nil, nil, nil).
+ AnyTimes()
+
+ mockRelayerClient1 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl)
+
+ relayResult1 := &common.RelayResult{
+ ReplyServer: mockRelayerClient1,
+ ProviderInfo: common.ProviderInfo{
+ ProviderAddress: ts.Providers[0].Addr.String(),
+ },
+ Reply: &pairingtypes.RelayReply{
+ Data: play.subscriptionFirstReply1,
+ LatestBlock: 1,
+ },
+ Request: &pairingtypes.RelayRequest{
+ RelayData: &pairingtypes.RelayPrivateData{
+ Data: play.subscriptionRequestData1,
+ },
+ RelaySession: &pairingtypes.RelaySession{},
+ },
+ }
+
+ relayResult1.Reply, err = lavaprotocol.SignRelayResponse(ts.Consumer.Addr, *relayResult1.Request, ts.Providers[0].SK, relayResult1.Reply, true)
+ require.NoError(t, err)
+
+ mockRelayerClient1.
+ EXPECT().
+ Context().
+ Return(context.Background()).
+ AnyTimes()
+
+ mockRelayerClient1.
+ EXPECT().
+ RecvMsg(gomock.Any()).
+ DoAndReturn(func(msg interface{}) error {
+ relayReply, ok := msg.(*pairingtypes.RelayReply)
+ require.True(t, ok)
+
+ *relayReply = *relayResult1.Reply
+ return nil
+ }).
+ AnyTimes()
+
+ relaySender.
+ EXPECT().
+ SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(relayResult1, nil).
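The RecvMsg expectation set up above is what makes the mocked stream produce messages: each call copies the canned, signed reply into the caller's message. A self-contained sketch of that fake-stream pattern (RelayReply here is a stand-in struct, not the real pairingtypes type):

```go
package main

import "fmt"

// RelayReply stands in for pairingtypes.RelayReply in this sketch.
type RelayReply struct{ Data []byte }

// fakeStream mimics the mocked RelaySubscribeClient: RecvMsg copies a
// prepared reply into the caller-provided message, the same way the
// DoAndReturn body above does with *relayReply = *relayResult1.Reply.
type fakeStream struct{ canned RelayReply }

func (s *fakeStream) RecvMsg(msg interface{}) error {
	reply, ok := msg.(*RelayReply)
	if !ok {
		return fmt.Errorf("unexpected message type %T", msg)
	}
	*reply = s.canned // shallow copy into the caller's struct
	return nil
}

func main() {
	s := &fakeStream{canned: RelayReply{Data: []byte(`{"jsonrpc":"2.0","id":3,"result":{}}`)}}
	var out RelayReply
	if err := s.RecvMsg(&out); err != nil {
		panic(err)
	}
	fmt.Println(string(out.Data))
}
```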
+ Times(1) // Should call SendParsedRelay, because it is the first time we subscribe + + consumerSessionManager := CreateConsumerSessionManager(play.specId, play.apiInterface, ts.Consumer.Addr.String()) + + // Create a new ConsumerWSSubscriptionManager + manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage()) + + // Start a new subscription for the first time, called SendParsedRelay once + ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) + firstReply, repliesChan1, err := manager.StartSubscription(ctx, subscribeChainMessage1, nil, nil, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + assert.NoError(t, err) + unsubscribeMessageWg.Add(1) + assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) + assert.NotNil(t, repliesChan1) + + listenForExpectedMessages(ctx, repliesChan1, string(play.subscriptionFirstReply1)) + + relaySender. + EXPECT(). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(relayResult1, nil). + Times(0) // Should not call SendParsedRelay, because it is already subscribed + + // Start a subscription again, same params, same dappKey, should not call SendParsedRelay + ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) + firstReply, repliesChan2, err := manager.StartSubscription(ctx, subscribeChainMessage1, nil, nil, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + assert.NoError(t, err) + assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) + assert.Nil(t, repliesChan2) // Same subscription, same dappKey, no need for a new channel + + listenForExpectedMessages(ctx, repliesChan1, string(play.subscriptionFirstReply1)) + + // Start a subscription again, same params, different dappKey, should not call SendParsedRelay + ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) + firstReply, repliesChan3, err := manager.StartSubscription(ctx, subscribeChainMessage1, nil, nil, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) + assert.NoError(t, err) + assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) + assert.NotNil(t, repliesChan3) // Same subscription, but different dappKey, so will create new channel + + listenForExpectedMessages(ctx, repliesChan1, string(play.subscriptionFirstReply1)) + listenForExpectedMessages(ctx, repliesChan3, string(play.subscriptionFirstReply1)) + + // Prepare for the next subscription + unsubscribeChainMessage2, err := chainParser.ParseMsg("", play.unsubscribeMessage2, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) + require.NoError(t, err) + + subscribeChainMessage2, err := chainParser.ParseMsg("", play.subscriptionRequestData2, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) + require.NoError(t, err) + + relaySender. + EXPECT(). + ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { + reqData, ok := x.(string) + require.True(t, ok) + areEqual := reqData == string(play.unsubscribeMessage2) + return areEqual + }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(unsubscribeChainMessage2, nil, &pairingtypes.RelayPrivateData{ + Data: play.unsubscribeMessage2, + }, nil). + AnyTimes() + + relaySender. + EXPECT(). 
+ ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool {
+ reqData, ok := x.(string)
+ require.True(t, ok)
+ areEqual := reqData == string(play.subscriptionRequestData2)
+ return areEqual
+ }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(subscribeChainMessage2, nil, nil, nil).
+ AnyTimes()
+
+ mockRelayerClient2 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl)
+ mockRelayerClient2.
+ EXPECT().
+ Context().
+ Return(context.Background()).
+ AnyTimes()
+
+ relayResult2 := &common.RelayResult{
+ ReplyServer: mockRelayerClient2,
+ ProviderInfo: common.ProviderInfo{
+ ProviderAddress: ts.Providers[0].Addr.String(),
+ },
+ Reply: &pairingtypes.RelayReply{
+ Data: play.subscriptionFirstReply2,
+ },
+ Request: &pairingtypes.RelayRequest{
+ RelayData: &pairingtypes.RelayPrivateData{
+ Data: play.subscriptionRequestData2,
+ },
+ RelaySession: &pairingtypes.RelaySession{},
+ },
+ }
+
+ relayResult2.Reply, err = lavaprotocol.SignRelayResponse(ts.Consumer.Addr, *relayResult2.Request, ts.Providers[0].SK, relayResult2.Reply, true)
+ require.NoError(t, err)
+
+ mockRelayerClient2.
+ EXPECT().
+ RecvMsg(gomock.Any()).
+ DoAndReturn(func(msg interface{}) error {
+ relayReply, ok := msg.(*pairingtypes.RelayReply)
+ require.True(t, ok)
+
+ *relayReply = *relayResult2.Reply
+ return nil
+ }).
+ AnyTimes()
+
+ relaySender.
+ EXPECT().
+ SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(relayResult2, nil).
+ Times(1) // Should call SendParsedRelay, because it is the first time we subscribe
+
+ // Start a subscription again, different params, same dappKey, should call SendParsedRelay
+ ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())
+ firstReply, repliesChan4, err := manager.StartSubscription(ctx, subscribeChainMessage2, nil, nil, dapp1, ts.Consumer.Addr.String(), uniqueId, nil)
+ assert.NoError(t, err)
+ unsubscribeMessageWg.Add(1)
+ assert.Equal(t, string(play.subscriptionFirstReply2), string(firstReply.Data))
+ assert.NotNil(t, repliesChan4) // New subscription, new channel
+
+ listenForExpectedMessages(ctx, repliesChan1, string(play.subscriptionFirstReply1))
+ listenForExpectedMessages(ctx, repliesChan3, string(play.subscriptionFirstReply1))
+ listenForExpectedMessages(ctx, repliesChan4, string(play.subscriptionFirstReply2))
+
+ // Prepare for unsubscribe from the first subscription
+ relaySender.
+ EXPECT().
+ SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(relayResult1, nil).
+ Times(0) // Should not call SendParsedRelay, because dapp1 is still subscribed to the first subscription
+
+ ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())
+ err = manager.Unsubscribe(ctx, unsubscribeChainMessage1, nil, relayResult1.Request.RelayData, dapp2, ts.Consumer.Addr.String(), uniqueId, nil)
+ require.NoError(t, err)
+
+ listenForExpectedMessages(ctx, repliesChan1, string(play.subscriptionFirstReply1))
+ expectNoMoreMessages(ctx, repliesChan3)
+ listenForExpectedMessages(ctx, repliesChan4, string(play.subscriptionFirstReply2))
+
+ wg := sync.WaitGroup{}
+ wg.Add(2)
+
+ relaySender.
+ EXPECT().
+ CancelSubscriptionContext(gomock.Any()).
+ AnyTimes()
+
+ // Prepare for unsubscribe from the second subscription
+ relaySender.
+ EXPECT().
+ SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ DoAndReturn(func(ctx context.Context, dappID string, consumerIp string, analytics *metrics.RelayMetrics, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData) (relayResult *common.RelayResult, errRet error) { + wg.Done() + return relayResult2, nil + }). + Times(2) // Should call SendParsedRelay, because it unsubscribed + + ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) + err = manager.UnsubscribeAll(ctx, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + require.NoError(t, err) + + expectNoMoreMessages(ctx, repliesChan1) + expectNoMoreMessages(ctx, repliesChan3) + expectNoMoreMessages(ctx, repliesChan4) + + // Because the SendParsedRelay is called in a goroutine, we need to wait for it to finish + wg.Wait() + }) + } +} + +func CreateConsumerSessionManager(chainID, apiInterface, consumerPublicAddress string) *lavasession.ConsumerSessionManager { + rand.InitRandomSeed() + baseLatency := common.AverageWorldLatency / 2 // we want performance to be half our timeout or better + return lavasession.NewConsumerSessionManager( + &lavasession.RPCEndpoint{NetworkAddress: "stub", ChainID: chainID, ApiInterface: apiInterface, TLSEnabled: false, HealthCheckPath: "/", Geolocation: 0}, + provideroptimizer.NewProviderOptimizer(provideroptimizer.STRATEGY_BALANCED, 0, baseLatency, 1), + nil, nil, consumerPublicAddress, + lavasession.NewActiveSubscriptionProvidersStorage(), + ) +} diff --git a/protocol/chainlib/grpc.go b/protocol/chainlib/grpc.go index f97db4cf1a..5465ddaee4 100644 --- a/protocol/chainlib/grpc.go +++ b/protocol/chainlib/grpc.go @@ -205,6 +205,7 @@ func (*GrpcChainParser) newChainMessage(api *spectypes.Api, requestedBlock int64 latestRequestedBlock: requestedBlock, apiCollection: apiCollection, resultErrorParsingMethod: grpcMessage.CheckResponseError, + parseDirective: GetParseDirective(api, apiCollection), } return nodeMsg } diff --git a/protocol/chainlib/grpc_test.go b/protocol/chainlib/grpc_test.go index cc92d5a371..934cf82809 100644 --- a/protocol/chainlib/grpc_test.go +++ b/protocol/chainlib/grpc_test.go @@ -142,7 +142,7 @@ func TestGrpcChainProxy(t *testing.T) { // Handle the incoming request and provide the desired response wasCalled = true }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceGrpc, serverHandle, "../../", nil) + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceGrpc, serverHandle, nil, "../../", nil) require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) @@ -169,15 +169,16 @@ func TestParsingRequestedBlocksHeadersGrpc(t *testing.T) { w.WriteHeader(244591) } }) - chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceGrpc, serverHandler, "../../", nil) + chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceGrpc, serverHandler, nil, "../../", nil) require.NoError(t, err) defer func() { if closeServer != nil { closeServer() } }() - parsingForCrafting, collectionData, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) + parsingForCrafting, apiCollection, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) require.True(t, ok) + collectionData := apiCollection.CollectionData headerParsingDirective, _, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_SET_LATEST_IN_METADATA) 
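The grpc_test hunks above track the signature change of GetParsingByTag, which now returns the full ApiCollection instead of only its CollectionData; call sites add one line to recover the old variable. A minimal sketch of the migration with simplified stand-in types (the returned values are placeholders):

```go
package main

import "fmt"

// Simplified stand-ins for the spectypes structs (assumption: the real
// ApiCollection carries more than just CollectionData, e.g. parse directives).
type CollectionData struct {
	ApiInterface string
	InternalPath string
}

type ApiCollection struct {
	CollectionData CollectionData
	Enabled        bool
}

// New shape: return the richer *ApiCollection so callers that need more
// than CollectionData don't need a second lookup.
func getParsingByTag(tag string) (apiName string, apiCollection *ApiCollection, ok bool) {
	return "placeholder_api_name", &ApiCollection{
		CollectionData: CollectionData{ApiInterface: "grpc"},
		Enabled:        true,
	}, true
}

func main() {
	parsingForCrafting, apiCollection, ok := getParsingByTag("GET_BLOCKNUM")
	// Old call sites received collectionData directly; now it is one field away:
	collectionData := apiCollection.CollectionData
	fmt.Println(parsingForCrafting, collectionData.ApiInterface, ok)
}
```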
callbackHeaderNameToCheck = headerParsingDirective.GetApiName() // this causes the callback to modify the response to simulate a real behavior require.True(t, ok) @@ -237,15 +238,16 @@ func TestSettingBlocksHeadersGrpc(t *testing.T) { w.WriteHeader(244591) } }) - chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceGrpc, serverHandler, "../../", nil) + chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceGrpc, serverHandler, nil, "../../", nil) require.NoError(t, err) defer func() { if closeServer != nil { closeServer() } }() - parsingForCrafting, collectionData, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) + parsingForCrafting, apiCollection, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) require.True(t, ok) + collectionData := apiCollection.CollectionData headerParsingDirective, _, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_SET_LATEST_IN_METADATA) callbackHeaderNameToCheck = headerParsingDirective.GetApiName() // this causes the callback to modify the response to simulate a real behavior require.True(t, ok) diff --git a/protocol/chainlib/jsonRPC.go b/protocol/chainlib/jsonRPC.go index 400a30edd5..50d02375d3 100644 --- a/protocol/chainlib/jsonRPC.go +++ b/protocol/chainlib/jsonRPC.go @@ -114,7 +114,13 @@ func (apip *JsonRPCChainParser) ParseMsg(url string, data []byte, connectionType // Check api is supported and save it in nodeMsg apiCont, err := apip.getSupportedApi(msg.Method, connectionType, internalPath) if err != nil { - utils.LavaFormatDebug("getSupportedApi jsonrpc failed", utils.LogAttr("method", msg.Method), utils.LogAttr("error", err)) + utils.LavaFormatDebug("getSupportedApi jsonrpc failed", + utils.LogAttr("method", msg.Method), + utils.LogAttr("connectionType", connectionType), + utils.LogAttr("internalPath", internalPath), + utils.LogAttr("error", err), + ) + return nil, err } @@ -213,6 +219,7 @@ func (*JsonRPCChainParser) newBatchChainMessage(serviceApi *spectypes.Api, reque msg: &batchMessage, earliestRequestedBlock: earliestRequestedBlock, resultErrorParsingMethod: rpcInterfaceMessages.CheckResponseErrorForJsonRpcBatch, + parseDirective: nil, } return nodeMsg, err } @@ -224,6 +231,7 @@ func (*JsonRPCChainParser) newChainMessage(serviceApi *spectypes.Api, requestedB latestRequestedBlock: requestedBlock, msg: msg, resultErrorParsingMethod: msg.CheckResponseError, + parseDirective: GetParseDirective(serviceApi, apiCollection), } return nodeMsg } @@ -287,11 +295,12 @@ func (apip *JsonRPCChainParser) ChainBlockStats() (allowedBlockLagForQosSync int } type JsonRPCChainListener struct { - endpoint *lavasession.RPCEndpoint - relaySender RelaySender - healthReporter HealthReporter - logger *metrics.RPCConsumerLogs - refererData *RefererData + endpoint *lavasession.RPCEndpoint + relaySender RelaySender + healthReporter HealthReporter + logger *metrics.RPCConsumerLogs + refererData *RefererData + consumerWsSubscriptionManager *ConsumerWSSubscriptionManager } // NewJrpcChainListener creates a new instance of JsonRPCChainListener @@ -299,6 +308,7 @@ func NewJrpcChainListener(ctx context.Context, listenEndpoint *lavasession.RPCEn relaySender RelaySender, healthReporter HealthReporter, rpcConsumerLogs *metrics.RPCConsumerLogs, refererData *RefererData, + consumerWsSubscriptionManager *ConsumerWSSubscriptionManager, ) (chainListener *JsonRPCChainListener) { // Create a new instance of JsonRPCChainListener 
chainListener = &JsonRPCChainListener{ @@ -307,6 +317,7 @@ func NewJrpcChainListener(ctx context.Context, listenEndpoint *lavasession.RPCEn healthReporter, rpcConsumerLogs, refererData, + consumerWsSubscriptionManager, } return chainListener @@ -335,91 +346,26 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context, cmdFlags common.Con chainID := apil.endpoint.ChainID apiInterface := apil.endpoint.ApiInterface - webSocketCallback := websocket.New(func(websockConn *websocket.Conn) { - var ( - messageType int - msg []byte - err error - ) - startTime := time.Now() - msgSeed := apil.logger.GetMessageSeed() - for { - if messageType, msg, err = websockConn.ReadMessage(); err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websockConn, messageType, err, msgSeed, msg, spectypes.APIInterfaceJsonRPC, time.Since(startTime)) - break - } - dappID, ok := websockConn.Locals("dapp-id").(string) - if !ok { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websockConn, messageType, nil, msgSeed, []byte("Unable to extract dappID"), spectypes.APIInterfaceJsonRPC, time.Since(startTime)) - } - refererMatch, ok := websockConn.Locals(refererMatchString).(string) - ctx, cancel := context.WithCancel(context.Background()) - guid := utils.GenerateUniqueIdentifier() - ctx = utils.WithUniqueIdentifier(ctx, guid) - msgSeed = strconv.FormatUint(guid, 10) - defer cancel() // incase there's a problem make sure to cancel the connection - - logFormattedMsg := string(msg) - if !cmdFlags.DebugRelays { - logFormattedMsg = utils.FormatLongString(logFormattedMsg, relayMsgLogMaxChars) - } - - utils.LavaFormatDebug("ws in <<<", - utils.LogAttr("seed", msgSeed), - utils.LogAttr("GUID", ctx), - utils.LogAttr("msg", logFormattedMsg), - utils.LogAttr("dappID", dappID), - ) - metricsData := metrics.NewRelayAnalytics(dappID, chainID, apiInterface) - relayResult, err := apil.relaySender.SendRelay(ctx, "", string(msg), http.MethodPost, dappID, websockConn.RemoteAddr().String(), metricsData, nil) - if ok && refererMatch != "" && apil.refererData != nil && err == nil { - go apil.refererData.SendReferer(refererMatch, chainID, string(msg), websockConn.RemoteAddr().String(), nil, websockConn) - } - reply := relayResult.GetReply() - replyServer := relayResult.GetReplyServer() - go apil.logger.AddMetricForWebSocket(metricsData, err, websockConn) - if err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websockConn, messageType, err, msgSeed, msg, spectypes.APIInterfaceJsonRPC, time.Since(startTime)) - continue - } - // If subscribe the first reply would contain the RPC ID that can be used for disconnect. 
- if replyServer != nil { - var reply pairingtypes.RelayReply - err = (*replyServer).RecvMsg(&reply) // this reply contains the RPC ID - if err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websockConn, messageType, err, msgSeed, msg, spectypes.APIInterfaceJsonRPC, time.Since(startTime)) - continue - } + webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) { + utils.LavaFormatDebug("jsonrpc websocket opened", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String())) + defer utils.LavaFormatDebug("jsonrpc websocket closed", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String())) + + consumerWebsocketManager := NewConsumerWebsocketManager(ConsumerWebsocketManagerOptions{ + WebsocketConn: websocketConn, + RpcConsumerLogs: apil.logger, + RefererMatchString: refererMatchString, + CmdFlags: cmdFlags, + RelayMsgLogMaxChars: relayMsgLogMaxChars, + ChainID: chainID, + ApiInterface: apiInterface, + ConnectionType: fiber.MethodPost, // We use it for the ParseMsg method, which needs to know the connection type to find the method in the spec + RefererData: apil.refererData, + RelaySender: apil.relaySender, + ConsumerWsSubscriptionManager: apil.consumerWsSubscriptionManager, + WebsocketConnectionUID: strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10), + }) - if err = websockConn.WriteMessage(messageType, reply.Data); err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websockConn, messageType, err, msgSeed, msg, spectypes.APIInterfaceJsonRPC, time.Since(startTime)) - continue - } - apil.logger.LogRequestAndResponse("jsonrpc ws msg", false, "ws", websockConn.LocalAddr().String(), string(msg), string(reply.Data), msgSeed, time.Since(startTime), nil) - for { - err = (*replyServer).RecvMsg(&reply) - if err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websockConn, messageType, err, msgSeed, msg, spectypes.APIInterfaceJsonRPC, time.Since(startTime)) - break - } - - // If portal cant write to the client - if err = websockConn.WriteMessage(messageType, reply.Data); err != nil { - cancel() - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websockConn, messageType, err, msgSeed, msg, spectypes.APIInterfaceJsonRPC, time.Since(startTime)) - // break - } - - apil.logger.LogRequestAndResponse("jsonrpc ws msg", false, "ws", websockConn.LocalAddr().String(), string(msg), string(reply.Data), msgSeed, time.Since(startTime), nil) - } - } else { - if err = websockConn.WriteMessage(messageType, reply.Data); err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websockConn, messageType, err, msgSeed, msg, spectypes.APIInterfaceJsonRPC, time.Since(startTime)) - continue - } - apil.logger.LogRequestAndResponse("jsonrpc ws msg", false, "ws", websockConn.LocalAddr().String(), string(msg), string(reply.Data), msgSeed, time.Since(startTime), nil) - } - } + consumerWebsocketManager.ListenToMessages() }) websocketCallbackWithDappID := constructFiberCallbackWithHeaderAndParameterExtraction(webSocketCallback, apil.logger.StoreMetricData) app.Get("/ws", websocketCallbackWithDappID) @@ -542,10 +488,17 @@ func NewJrpcChainProxy(ctx context.Context, nConns uint, rpcProviderEndpoint lav _, averageBlockTime, _, _ := chainParser.ChainBlockStats() nodeUrl := rpcProviderEndpoint.NodeUrls[0] cp := &JrpcChainProxy{ - BaseChainProxy: BaseChainProxy{averageBlockTime: averageBlockTime, NodeUrl: nodeUrl, ErrorHandler: &JsonRPCErrorHandler{}, ChainID: rpcProviderEndpoint.ChainID}, - conn: map[string]*chainproxy.Connector{}, + BaseChainProxy: BaseChainProxy{ + 
averageBlockTime: averageBlockTime, + NodeUrl: nodeUrl, + ErrorHandler: &JsonRPCErrorHandler{}, + ChainID: rpcProviderEndpoint.ChainID, + }, + conn: map[string]*chainproxy.Connector{}, } - verifyRPCEndpoint(nodeUrl.Url) + + validateEndpoints(rpcProviderEndpoint.NodeUrls, spectypes.APIInterfaceJsonRPC) + internalPaths := map[string]struct{}{} jsonRPCChainParser, ok := chainParser.(*JsonRPCChainParser) if ok { @@ -591,6 +544,7 @@ func (cp *JrpcChainProxy) start(ctx context.Context, nConns uint, nodeUrl common if err != nil { return err } + cp.conn[path] = conn if cp.conn == nil { return errors.New("g_conn == nil") @@ -666,14 +620,15 @@ func (cp *JrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, return reply, "", nil, err } internalPath := chainMessage.GetApiCollection().CollectionData.InternalPath - rpc, err := cp.conn[internalPath].GetRpc(ctx, true) + connector := cp.conn[internalPath] + rpc, err := connector.GetRpc(ctx, true) if err != nil { return nil, "", nil, err } - defer cp.conn[internalPath].ReturnRpc(rpc) + defer connector.ReturnRpc(rpc) // appending hashed url - grpc.SetTrailer(ctx, metadata.Pairs(RPCProviderNodeAddressHash, cp.conn[internalPath].GetUrlHash())) + grpc.SetTrailer(ctx, metadata.Pairs(RPCProviderNodeAddressHash, connector.GetUrlHash())) // Call our node var rpcMessage *rpcclient.JsonrpcMessage @@ -687,8 +642,9 @@ func (cp *JrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, defer rpc.SetHeader(metadata.Name, "") } } + var nodeErr error if ch != nil { - sub, rpcMessage, err = rpc.Subscribe(context.Background(), nodeMessage.ID, nodeMessage.Method, ch, nodeMessage.Params) + sub, rpcMessage, nodeErr = rpc.Subscribe(context.Background(), nodeMessage.ID, nodeMessage.Method, ch, nodeMessage.Params) } else { // we use the minimum timeout between the two, spec or context. 
to prevent the provider from hanging // we don't use the context alone so the provider won't be hanging forever by an attack @@ -696,7 +652,7 @@ func (cp *JrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, defer cancel() cp.NodeUrl.SetIpForwardingIfNecessary(ctx, rpc.SetHeader) - rpcMessage, err = rpc.CallContext(connectCtx, nodeMessage.ID, nodeMessage.Method, nodeMessage.Params, true, nodeMessage.GetDisableErrorHandling()) + rpcMessage, nodeErr = rpc.CallContext(connectCtx, nodeMessage.ID, nodeMessage.Method, nodeMessage.Params, true, nodeMessage.GetDisableErrorHandling()) if err != nil { // here we are getting an error for every code that is not 200-300 if common.StatusCodeError504.Is(err) || common.StatusCodeError429.Is(err) || common.StatusCodeErrorStrict.Is(err) { @@ -712,31 +668,32 @@ func (cp *JrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, var replyMsg rpcInterfaceMessages.JsonrpcMessage // the error check here would only wrap errors not from the rpc + + if nodeErr != nil { + utils.LavaFormatDebug("got error from node", utils.LogAttr("GUID", ctx), utils.LogAttr("nodeErr", nodeErr)) + return nil, "", nil, nodeErr + } + + replyMessage, err = rpcInterfaceMessages.ConvertJsonRPCMsg(rpcMessage) if err != nil { - utils.LavaFormatDebug("received an error from SendNodeMsg", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "error", Value: err}) - return nil, "", nil, err - } else { - replyMessage, err = rpcInterfaceMessages.ConvertJsonRPCMsg(rpcMessage) - if err != nil { - return nil, "", nil, utils.LavaFormatError("jsonRPC error", err, utils.Attribute{Key: "GUID", Value: ctx}) - } - // validate result is valid - if replyMessage.Error == nil { - responseIsNilValidationError := ValidateNilResponse(string(replyMessage.Result)) - if responseIsNilValidationError != nil { - return nil, "", nil, responseIsNilValidationError - } + return nil, "", nil, utils.LavaFormatError("jsonRPC error", err, utils.Attribute{Key: "GUID", Value: ctx}) + } + // validate result is valid + if replyMessage.Error == nil { + responseIsNilValidationError := ValidateNilResponse(string(replyMessage.Result)) + if responseIsNilValidationError != nil { + return nil, "", nil, responseIsNilValidationError } + } - replyMsg = *replyMessage - err := cp.ValidateRequestAndResponseIds(nodeMessage.ID, replyMessage.ID) - if err != nil { - return nil, "", nil, utils.LavaFormatError("jsonRPC ID mismatch error", err, - utils.Attribute{Key: "GUID", Value: ctx}, - utils.Attribute{Key: "requestId", Value: nodeMessage.ID}, - utils.Attribute{Key: "responseId", Value: rpcMessage.ID}, - ) - } + replyMsg = *replyMessage + err = cp.ValidateRequestAndResponseIds(nodeMessage.ID, replyMessage.ID) + if err != nil { + return nil, "", nil, utils.LavaFormatError("jsonRPC ID mismatch error", err, + utils.Attribute{Key: "GUID", Value: ctx}, + utils.Attribute{Key: "requestId", Value: nodeMessage.ID}, + utils.Attribute{Key: "responseId", Value: rpcMessage.ID}, + ) } retData, err := json.Marshal(replyMsg) @@ -753,9 +710,17 @@ func (cp *JrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, } if ch != nil { - subscriptionID, err = strconv.Unquote(string(replyMsg.Result)) - if err != nil { - return nil, "", nil, utils.LavaFormatError("Subscription failed", err, utils.Attribute{Key: "GUID", Value: ctx}) + if replyMsg.Error != nil { + return reply, "", nil, nil + } + + if common.IsQuoted(string(replyMsg.Result)) { + subscriptionID, err = strconv.Unquote(string(replyMsg.Result)) + if err != nil 
{ + return nil, "", nil, utils.LavaFormatError("Subscription failed", err, utils.Attribute{Key: "GUID", Value: ctx}) + } + } else { + subscriptionID = string(replyMsg.Result) } } diff --git a/protocol/chainlib/jsonRPC_test.go b/protocol/chainlib/jsonRPC_test.go index 3404614941..a110b22bca 100644 --- a/protocol/chainlib/jsonRPC_test.go +++ b/protocol/chainlib/jsonRPC_test.go @@ -9,8 +9,10 @@ import ( "testing" "time" + "github.com/gorilla/websocket" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" + "github.com/lavanet/lava/v2/protocol/common" keepertest "github.com/lavanet/lava/v2/testutil/keeper" plantypes "github.com/lavanet/lava/v2/x/plans/types" spectypes "github.com/lavanet/lava/v2/x/spec/types" @@ -18,6 +20,32 @@ import ( "github.com/stretchr/testify/require" ) +func createWebSocketHandler(handler func(string) string) http.HandlerFunc { + upGrader := websocket.Upgrader{} + + // Create a simple websocket server that mocks the node + return func(w http.ResponseWriter, r *http.Request) { + conn, err := upGrader.Upgrade(w, r, nil) + if err != nil { + fmt.Println(err) + panic("got error in upgrader") + } + defer conn.Close() + + for { + // Read the request + messageType, message, err := conn.ReadMessage() + if err != nil { + panic("got error in ReadMessage") + } + fmt.Println("got ws message", string(message), messageType) + retMsg := handler(string(message)) + conn.WriteMessage(messageType, []byte(retMsg)) + fmt.Println("writing ws message", string(message), messageType) + } + } +} + func TestJSONChainParser_Spec(t *testing.T) { // create a new instance of RestChainParser apip, err := NewJrpcChainParser() @@ -134,26 +162,33 @@ func TestJSONParseMessage(t *testing.T) { func TestJsonRpcChainProxy(t *testing.T) { ctx := context.Background() - serverHandle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + serverHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Handle the incoming request and provide the desired response w.WriteHeader(http.StatusOK) fmt.Fprint(w, `{"jsonrpc":"2.0","id":1,"result":"0x10a7a08"}`) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "ETH1", spectypes.APIInterfaceJsonRPC, serverHandle, "../../", nil) + wsServerHandler := func(message string) string { + return `{"jsonrpc":"2.0","id":1,"result":"0x10a7a08"}` + } + + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "ETH1", spectypes.APIInterfaceJsonRPC, serverHandler, createWebSocketHandler(wsServerHandler), "../../", nil) + if closeServer != nil { + defer closeServer() + } + require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) require.NotNil(t, chainFetcher) + block, err := chainFetcher.FetchLatestBlockNum(ctx) require.Greater(t, block, int64(0)) require.NoError(t, err) + _, err = chainFetcher.FetchBlockHashByNum(ctx, block) errMsg := "GET_BLOCK_BY_NUM Failed ParseMessageResponse {error:invalid parser input format" require.True(t, err.Error()[:len(errMsg)] == errMsg, err.Error()) - if closeServer != nil { - closeServer() - } } func TestAddonAndVerifications(t *testing.T) { @@ -164,7 +199,15 @@ func TestAddonAndVerifications(t *testing.T) { fmt.Fprint(w, `{"jsonrpc":"2.0","id":1,"result":"0xf9ccdff90234a064"}`) }) - chainParser, chainRouter, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "ETH1", spectypes.APIInterfaceJsonRPC, serverHandle, "../../", 
[]string{"debug"}) + wsServerHandler := func(message string) string { + return `{"jsonrpc":"2.0","id":1,"result":"0xf9ccdff90234a064"}` + } + + chainParser, chainRouter, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "ETH1", spectypes.APIInterfaceJsonRPC, serverHandle, createWebSocketHandler(wsServerHandler), "../../", []string{"debug"}) + if closeServer != nil { + defer closeServer() + } + require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainRouter) @@ -183,9 +226,6 @@ func TestAddonAndVerifications(t *testing.T) { _, err = FormatResponseForParsing(reply.RelayReply, chainMessage) require.NoError(t, err) } - if closeServer != nil { - closeServer() - } } func TestExtensions(t *testing.T) { @@ -196,8 +236,16 @@ func TestExtensions(t *testing.T) { fmt.Fprint(w, `{"jsonrpc":"2.0","id":1,"result":"0xf9ccdff90234a064"}`) }) + wsServerHandler := func(message string) string { + return `{"jsonrpc":"2.0","id":1,"result":"0xf9ccdff90234a064"}` + } + specname := "ETH1" - chainParser, chainRouter, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, specname, spectypes.APIInterfaceJsonRPC, serverHandle, "../../", []string{"archive"}) + chainParser, chainRouter, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, specname, spectypes.APIInterfaceJsonRPC, serverHandle, createWebSocketHandler(wsServerHandler), "../../", []string{"archive"}) + if closeServer != nil { + defer closeServer() + } + require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainRouter) @@ -209,8 +257,9 @@ func TestExtensions(t *testing.T) { require.NoError(t, err) chainParser.SetPolicy(&plantypes.Policy{ChainPolicies: []plantypes.ChainPolicy{{ChainId: specname, Requirements: []plantypes.ChainRequirement{{Collection: spectypes.CollectionData{ApiInterface: "jsonrpc"}, Extensions: []string{"archive"}}}}}}, specname, "jsonrpc") - parsingForCrafting, collectionData, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCK_BY_NUM) + parsingForCrafting, apiCollection, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCK_BY_NUM) require.True(t, ok) + collectionData := apiCollection.CollectionData cuCost := uint64(0) for _, api := range spec.ApiCollections[0].Apis { if api.Name == parsingForCrafting.ApiName { @@ -258,9 +307,6 @@ func TestExtensions(t *testing.T) { require.Len(t, chainMessage.GetExtensions(), 1) require.Equal(t, "archive", chainMessage.GetExtensions()[0].Name) require.Equal(t, cuCostExt, chainMessage.GetApi().ComputeUnits) - if closeServer != nil { - closeServer() - } } func TestJsonRpcBatchCall(t *testing.T) { @@ -279,7 +325,16 @@ func TestJsonRpcBatchCall(t *testing.T) { fmt.Fprint(w, response) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "ETH1", spectypes.APIInterfaceJsonRPC, serverHandle, "../../", nil) + wsServerHandler := func(message string) string { + require.Equal(t, batchCallData, message) + return response + } + + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "ETH1", spectypes.APIInterfaceJsonRPC, serverHandle, createWebSocketHandler(wsServerHandler), "../../", nil) + if closeServer != nil { + defer closeServer() + } + require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) @@ -287,18 +342,15 @@ func TestJsonRpcBatchCall(t *testing.T) { chainMessage, err := chainParser.ParseMsg("", []byte(batchCallData), http.MethodPost, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) + 
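A few hunks up, SendNodeMsg stops assuming the subscription ID always arrives as a quoted JSON string and only unquotes when needed. A standalone sketch of that branch, where isQuoted is an assumed equivalent of common.IsQuoted:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// isQuoted mirrors what common.IsQuoted is assumed to check: a JSON string
// result arrives as `"0x1234"`, while some nodes return the ID bare.
func isQuoted(s string) bool {
	return len(s) >= 2 && strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)
}

func subscriptionID(rawResult string) (string, error) {
	if isQuoted(rawResult) {
		return strconv.Unquote(rawResult)
	}
	return rawResult, nil
}

func main() {
	for _, raw := range []string{`"0x1234567890"`, `0x1234567890`} {
		id, err := subscriptionID(raw)
		fmt.Println(id, err) // both print: 0x1234567890 <nil>
	}
}
```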
requestedBlock, _ := chainMessage.RequestedBlock() require.Equal(t, spectypes.LATEST_BLOCK, requestedBlock) + relayReply, _, _, _, _, err := chainProxy.SendNodeMsg(ctx, nil, chainMessage, nil) require.True(t, gotCalled) require.NoError(t, err) require.NotNil(t, relayReply) require.Equal(t, response, string(relayReply.RelayReply.Data)) - defer func() { - if closeServer != nil { - closeServer() - } - }() } func TestJsonRpcBatchCallSameID(t *testing.T) { @@ -320,7 +372,16 @@ func TestJsonRpcBatchCallSameID(t *testing.T) { fmt.Fprint(w, response) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "ETH1", spectypes.APIInterfaceJsonRPC, serverHandle, "../../", nil) + wsServerHandler := func(message string) string { + require.Equal(t, sentBatchCallData, message) + return response + } + + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "ETH1", spectypes.APIInterfaceJsonRPC, serverHandle, createWebSocketHandler(wsServerHandler), "../../", nil) + if closeServer != nil { + defer closeServer() + } + require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) @@ -335,21 +396,21 @@ func TestJsonRpcBatchCallSameID(t *testing.T) { require.NoError(t, err) require.NotNil(t, relayReply) require.Equal(t, responseExpected, string(relayReply.RelayReply.Data)) - defer func() { - if closeServer != nil { - closeServer() - } - }() } -func TestJsonRpcInternalPathsMultipleVersions(t *testing.T) { +func TestJsonRpcInternalPathsMultipleVersionsStarkNet(t *testing.T) { ctx := context.Background() serverHandle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Handle the incoming request and provide the desired response w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{"jsonrpc":"2.0","id":1,"result":"%s"}`, r.RequestURI) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "STRK", spectypes.APIInterfaceJsonRPC, serverHandle, "../../", nil) + + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "STRK", spectypes.APIInterfaceJsonRPC, serverHandle, nil, "../../", nil) + if closeServer != nil { + defer closeServer() + } + require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) @@ -377,7 +438,71 @@ func TestJsonRpcInternalPathsMultipleVersions(t *testing.T) { collection = chainMessage.GetApiCollection() require.Equal(t, "starknet_specVersion", api.Name) require.Equal(t, v6_path, collection.CollectionData.InternalPath) +} + +func TestJsonRpcInternalPathsMultipleVersionsAvalanche(t *testing.T) { + type reqWithApiName struct { + apiName string + reqData []byte + } + + // TODO: Add the empty path back in once the ETH spec will be fixed + // allPaths := []string{"", "/C/rpc", "/C/avax", "/P", "/X"} + allPaths := []string{"/C/rpc", "/C/avax", "/P", "/X"} + pathToReqData := map[string]reqWithApiName{ + "/C/rpc": { // Eth jsonrpc path + apiName: "eth_blockNumber", + reqData: []byte(`{"jsonrpc": "2.0", "id": 1, "method": "eth_blockNumber", "params": []}`), + }, + "/C/avax": { // Avalanche jsonrpc path + apiName: "avax.export", + reqData: []byte(`{"jsonrpc": "2.0", "id": 1, "method": "avax.export", "params": []}`), + }, + "/P": { // Platform jsonrpc path + apiName: "platform.addDelegator", + reqData: []byte(`{"jsonrpc": "2.0", "id": 1, "method": "platform.addDelegator", "params": []}`), + }, + "/X": { // Avm jsonrpc path + apiName: "avm.getAssetDescription", + reqData: []byte(`{"jsonrpc": "2.0", "id": 1, "method": 
"avm.getAssetDescription", "params": []}`), + }, + } + + ctx := context.Background() + serverHandle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle the incoming request and provide the desired response + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"jsonrpc":"2.0","id":1,"result":"%s"}`, r.RequestURI) + }) + + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "AVAX", spectypes.APIInterfaceJsonRPC, serverHandle, nil, "../../", nil) if closeServer != nil { - closeServer() + defer closeServer() + } + + require.NoError(t, err) + require.NotNil(t, chainParser) + require.NotNil(t, chainProxy) + require.NotNil(t, chainFetcher) + + for correctPath, reqDataWithApiName := range pathToReqData { + for _, path := range allPaths { + shouldErr := path != correctPath + t.Run(fmt.Sprintf("ApiName:%s,CorrectPath:%s,Path:%s,ShouldError:%v", reqDataWithApiName.apiName, correctPath, path, shouldErr), func(t *testing.T) { + chainMessage, err := chainParser.ParseMsg(path, reqDataWithApiName.reqData, http.MethodPost, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) + + if !shouldErr { + require.NoError(t, err) + api := chainMessage.GetApi() + collection := chainMessage.GetApiCollection() + require.Equal(t, reqDataWithApiName.apiName, api.Name) + require.Equal(t, correctPath, collection.CollectionData.InternalPath) + } else { + require.Error(t, err) + require.ErrorIs(t, err, common.APINotSupportedError) + require.Nil(t, chainMessage) + } + }) + } } } diff --git a/protocol/chainlib/provider_node_subscription_manager.go b/protocol/chainlib/provider_node_subscription_manager.go new file mode 100644 index 0000000000..3766a063cc --- /dev/null +++ b/protocol/chainlib/provider_node_subscription_manager.go @@ -0,0 +1,650 @@ +package chainlib + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + sdk "github.com/cosmos/cosmos-sdk/types" + gojson "github.com/goccy/go-json" + "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" + "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" + "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" + "github.com/lavanet/lava/v2/protocol/chaintracker" + "github.com/lavanet/lava/v2/protocol/common" + "github.com/lavanet/lava/v2/protocol/lavaprotocol" + "github.com/lavanet/lava/v2/utils" + "github.com/lavanet/lava/v2/utils/protocopy" + pairingtypes "github.com/lavanet/lava/v2/x/pairing/types" + spectypes "github.com/lavanet/lava/v2/x/spec/types" +) + +const SubscriptionTimeoutDuration = 15 * time.Minute + +type relayFinalizationBlocksHandler interface { + GetParametersForRelayDataReliability( + ctx context.Context, + request *pairingtypes.RelayRequest, + chainMsg ChainMessage, + relayTimeout time.Duration, + blockLagForQosSync int64, + averageBlockTime time.Duration, + blockDistanceToFinalization, + blocksInFinalizationData uint32, + ) (latestBlock int64, requestedBlockHash []byte, requestedHashes []*chaintracker.BlockStore, modifiedReqBlock int64, finalized, updatedChainMessage bool, err error) + + BuildRelayFinalizedBlockHashes( + ctx context.Context, + request *pairingtypes.RelayRequest, + reply *pairingtypes.RelayReply, + latestBlock int64, + requestedHashes []*chaintracker.BlockStore, + updatedChainMessage bool, + relayTimeout time.Duration, + averageBlockTime time.Duration, + blockDistanceToFinalization uint32, + blocksInFinalizationData uint32, + modifiedReqBlock int64, + ) (err error) +} + +type connectedConsumerContainer 
struct {
+ consumerChannel *common.SafeChannelSender[*pairingtypes.RelayReply]
+ firstSetupRequest *pairingtypes.RelayRequest
+ consumerSDKAddress sdk.AccAddress
+}
+
+type activeSubscription struct {
+ cancellableContext context.Context
+ cancellableContextCancelFunc context.CancelFunc
+ messagesChannel chan interface{}
+ nodeSubscription *rpcclient.ClientSubscription
+ subscriptionID string
+ firstSetupReply *pairingtypes.RelayReply
+ apiCollection *spectypes.ApiCollection
+ connectedConsumers map[string]map[string]*connectedConsumerContainer // first key is consumer address, 2nd key is consumer guid
+}
+
+type ProviderNodeSubscriptionManager struct {
+ chainRouter ChainRouter
+ chainParser ChainParser
+ relayFinalizationBlocksHandler relayFinalizationBlocksHandler
+ activeSubscriptions map[string]*activeSubscription // key is request params hash
+ currentlyPendingSubscriptions map[string]*pendingSubscriptionsBroadcastManager // pending subscriptions waiting for the node message to return.
+ privKey *btcec.PrivateKey
+ lock sync.RWMutex
+}
+
+func NewProviderNodeSubscriptionManager(chainRouter ChainRouter, chainParser ChainParser, relayFinalizationBlocksHandler relayFinalizationBlocksHandler, privKey *btcec.PrivateKey) *ProviderNodeSubscriptionManager {
+ return &ProviderNodeSubscriptionManager{
+ chainRouter: chainRouter,
+ chainParser: chainParser,
+ relayFinalizationBlocksHandler: relayFinalizationBlocksHandler,
+ activeSubscriptions: make(map[string]*activeSubscription),
+ currentlyPendingSubscriptions: make(map[string]*pendingSubscriptionsBroadcastManager),
+ privKey: privKey,
+ }
+}
+
+func (pnsm *ProviderNodeSubscriptionManager) failedPendingSubscription(hashedParams string) {
+ pnsm.lock.Lock()
+ defer pnsm.lock.Unlock()
+ pendingSubscriptionChannel, ok := pnsm.currentlyPendingSubscriptions[hashedParams]
+ if !ok {
+ utils.LavaFormatError("failed fetching hashed params in failedPendingSubscription", nil, utils.LogAttr("hash", hashedParams), utils.LogAttr("pnsm.currentlyPendingSubscriptions", pnsm.currentlyPendingSubscriptions))
+ } else {
+ pendingSubscriptionChannel.broadcastToChannelList(false)
+ delete(pnsm.currentlyPendingSubscriptions, hashedParams) // removed pending
+ }
+}
+
+// must be called under a lock.
+func (pnsm *ProviderNodeSubscriptionManager) successfulPendingSubscription(hashedParams string) {
+ pendingSubscriptionChannel, ok := pnsm.currentlyPendingSubscriptions[hashedParams]
+ if !ok {
+ utils.LavaFormatError("failed fetching hashed params in successfulPendingSubscription", nil, utils.LogAttr("hash", hashedParams), utils.LogAttr("pnsm.currentlyPendingSubscriptions", pnsm.currentlyPendingSubscriptions))
+ } else {
+ pendingSubscriptionChannel.broadcastToChannelList(true)
+ delete(pnsm.currentlyPendingSubscriptions, hashedParams) // removed pending
+ }
+}
+
+func (pnsm *ProviderNodeSubscriptionManager) checkAndAddPendingSubscriptionsWithLock(hashedParams string) (chan bool, bool) {
+ pnsm.lock.Lock()
+ defer pnsm.lock.Unlock()
+ pendingSubscriptionBroadcastManager, ok := pnsm.currentlyPendingSubscriptions[hashedParams]
+ if !ok {
+ // we didn't find hashed params for pending subscriptions, we can create a new subscription
+ utils.LavaFormatTrace("No pending subscription for incoming hashed params found", utils.LogAttr("params", hashedParams))
+ // create a pending subscription broadcast manager for other users to sync on the same relay.
+ pnsm.currentlyPendingSubscriptions[hashedParams] = &pendingSubscriptionsBroadcastManager{}
+ return nil, ok
+ }
+ utils.LavaFormatTrace("found subscription for incoming hashed params, registering our channel", utils.LogAttr("params", hashedParams))
+ // by creating a buffered channel we make sure that we won't miss the update between the time we register and the time we listen to the channel
+ listenChan := make(chan bool, 1)
+ pendingSubscriptionBroadcastManager.broadcastChannelList = append(pendingSubscriptionBroadcastManager.broadcastChannelList, listenChan)
+ return listenChan, ok
+}
+
+func (pnsm *ProviderNodeSubscriptionManager) checkForActiveSubscriptionsWithLock(ctx context.Context, hashedParams string, consumerAddr sdk.AccAddress, consumerProcessGuid string, params []byte, chainMessage ChainMessage, consumerChannel chan<- *pairingtypes.RelayReply, request *pairingtypes.RelayRequest) (subscriptionId string, err error) {
+ pnsm.lock.Lock()
+ defer pnsm.lock.Unlock()
+ paramsChannelToConnectedConsumers, foundSubscriptionHash := pnsm.activeSubscriptions[hashedParams]
+ if foundSubscriptionHash {
+ consumerAddrString := consumerAddr.String()
+ utils.LavaFormatTrace("[AddConsumer] found existing subscription",
+ utils.LogAttr("GUID", ctx),
+ utils.LogAttr("consumerAddr", consumerAddr),
+ utils.LogAttr("consumerProcessGuid", consumerProcessGuid),
+ utils.LogAttr("params", string(params)),
+ utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)),
+ )
+
+ if _, foundConsumer := paramsChannelToConnectedConsumers.connectedConsumers[consumerAddrString]; foundConsumer { // Consumer is already connected to this subscription, dismiss
+ // check consumer guid.
+ if consumerGuidContainer, foundGuid := paramsChannelToConnectedConsumers.connectedConsumers[consumerAddrString][consumerProcessGuid]; foundGuid {
+ // If the consumer exists and its channel is already active, and it tried to resubscribe, we assume the connection was interrupted; we disconnect the previous channel and attach the incoming one.
+ utils.LavaFormatWarning("consumer tried to subscribe twice to the same subscription hash, disconnecting the previous one and attaching incoming channel", nil,
+ utils.LogAttr("consumerAddr", consumerAddr),
+ utils.LogAttr("hashedParams", hashedParams),
+ utils.LogAttr("params_provided", chainMessage.GetRPCMessage().GetParams()),
+ )
+ // disconnecting the previous channel, attaching the new channel, and returning the subscription ID.
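checkAndAddPendingSubscriptionsWithLock above relies on a small broadcast pattern: the first caller becomes the creator, and late arrivals register a buffered channel of size 1 so a result broadcast between registration and the blocking receive is never lost. A minimal model under those assumptions (the real pendingSubscriptionsBroadcastManager is defined elsewhere in the package):

```go
package main

import (
	"fmt"
	"sync"
)

// pendingBroadcast is a toy version of the pending-subscriptions manager:
// waiters register buffered channels, the creator broadcasts success/failure.
type pendingBroadcast struct {
	mu       sync.Mutex
	channels []chan bool
}

func (p *pendingBroadcast) register() <-chan bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	ch := make(chan bool, 1) // buffered: the broadcast below never blocks, and an early send is kept
	p.channels = append(p.channels, ch)
	return ch
}

// broadcast is assumed to be called exactly once per pending subscription.
func (p *pendingBroadcast) broadcast(ok bool) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for _, ch := range p.channels {
		ch <- ok
	}
	p.channels = nil
}

func main() {
	p := &pendingBroadcast{}
	ch := p.register()
	p.broadcast(true) // completes even though nobody is receiving yet
	fmt.Println(<-ch) // true
}
```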
+ consumerGuidContainer.consumerChannel.ReplaceChannel(consumerChannel) + return paramsChannelToConnectedConsumers.subscriptionID, nil + } + // else we have this consumer but two different processes try to subscribe + utils.LavaFormatTrace("[AddConsumer] consumer address exists but consumer GUID does not exist in the subscription map, adding", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddr), + utils.LogAttr("params", string(params)), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("consumerProcessGuid", consumerProcessGuid), + ) + } + + // Create a new map for this consumer address if it doesn't exist + if paramsChannelToConnectedConsumers.connectedConsumers[consumerAddrString] == nil { + utils.LavaFormatError("missing map object from paramsChannelToConnectedConsumers.connectedConsumers[consumerAddrString], creating to avoid nil deref", nil) + paramsChannelToConnectedConsumers.connectedConsumers[consumerAddrString] = make(map[string]*connectedConsumerContainer) + } + + utils.LavaFormatTrace("[AddConsumer] consumer GUID does not exist in the subscription, adding", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddr), + utils.LogAttr("params", string(params)), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("consumerProcessGuid", consumerProcessGuid), + ) + + // Add the new entry for the consumer + paramsChannelToConnectedConsumers.connectedConsumers[consumerAddrString][consumerProcessGuid] = &connectedConsumerContainer{ + consumerChannel: common.NewSafeChannelSender(ctx, consumerChannel), + firstSetupRequest: &pairingtypes.RelayRequest{}, // Deep copy later firstSetupChainMessage: chainMessage, + consumerSDKAddress: consumerAddr, + } + + copyRequestErr := protocopy.DeepCopyProtoObject(request, paramsChannelToConnectedConsumers.connectedConsumers[consumerAddrString][consumerProcessGuid].firstSetupRequest) + if copyRequestErr != nil { + return "", utils.LavaFormatError("failed to copy subscription request", copyRequestErr) + } + + firstSetupReply := paramsChannelToConnectedConsumers.firstSetupReply + // Making sure to sign the reply before returning it to the consumer. This will replace the sig field with the correct value + // (and not the signature for another consumer) + signingError := pnsm.signReply(ctx, firstSetupReply, consumerAddr, chainMessage, request) + if signingError != nil { + return "", utils.LavaFormatError("AddConsumer failed signing reply", signingError) + } + + subscriptionId = paramsChannelToConnectedConsumers.subscriptionID + // Send the first reply to the consumer asynchronously, allowing the lock to be released while waiting for the consumer to receive the response. 
+ pnsm.activeSubscriptions[hashedParams].connectedConsumers[consumerAddrString][consumerProcessGuid].consumerChannel.LockAndSendAsynchronously(firstSetupReply) + return subscriptionId, nil + } + return "", NoActiveSubscriptionFound +} + +func (pnsm *ProviderNodeSubscriptionManager) AddConsumer(ctx context.Context, request *pairingtypes.RelayRequest, chainMessage ChainMessage, consumerAddr sdk.AccAddress, consumerChannel chan<- *pairingtypes.RelayReply, consumerProcessGuid string) (subscriptionId string, err error) { + utils.LavaFormatTrace("[AddConsumer] called", utils.LogAttr("consumerAddr", consumerAddr)) + + if pnsm == nil { + return "", fmt.Errorf("ProviderNodeSubscriptionManager is nil") + } + + hashedParams, params, err := pnsm.getHashedParams(chainMessage) + if err != nil { + return "", err + } + + utils.LavaFormatTrace("[AddConsumer] hashed params", + utils.LogAttr("params", string(params)), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + subscriptionId, err = pnsm.checkForActiveSubscriptionsWithLock(ctx, hashedParams, consumerAddr, consumerProcessGuid, params, chainMessage, consumerChannel, request) + if NoActiveSubscriptionFound.Is(err) { + // This for loop will break when there is a successful queue lock, allowing us to avoid racing new subscription creation when + // there is a failed subscription. the loop will break for the first routine the manages to lock and create the pendingSubscriptionsBroadcastManager + for { + pendingSubscriptionChannel, foundPendingSubscription := pnsm.checkAndAddPendingSubscriptionsWithLock(hashedParams) + if foundPendingSubscription { + utils.LavaFormatTrace("Found pending subscription, waiting for it to complete") + pendingResult := <-pendingSubscriptionChannel + utils.LavaFormatTrace("Finished pending for subscription, have results", utils.LogAttr("success", pendingResult)) + // Check result is valid, if not fall through logs and try again with a new message. + if pendingResult { + subscriptionId, err = pnsm.checkForActiveSubscriptionsWithLock(ctx, hashedParams, consumerAddr, consumerProcessGuid, params, chainMessage, consumerChannel, request) + if err == nil { + // found new the subscription after waiting for a pending subscription + return subscriptionId, err + } + // In case we expected a subscription to return as res != nil we should find an active subscription. + // If we fail to find it, it might have suddenly stopped. we will log a warning and try with a new client. + utils.LavaFormatWarning("failed getting a result when channel indicated we got a successful relay", nil) + } else { + utils.LavaFormatWarning("Failed the subscription attempt, retrying with the incoming message", nil, utils.LogAttr("hash", hashedParams)) + } + } else { + utils.LavaFormatDebug("No Pending subscriptions, creating a new one", utils.LogAttr("hash", hashedParams)) + break + } + } + + // did not find active or pending subscriptions, will try to create a new subscription. 
+ consumerAddrString := consumerAddr.String() + utils.LavaFormatTrace("[AddConsumer] did not find an existing subscription for hashed params, creating a new one", utils.LogAttr("hash", hashedParams)) + nodeChan := make(chan interface{}) + var replyWrapper *RelayReplyWrapper + var clientSubscription *rpcclient.ClientSubscription + replyWrapper, subscriptionId, clientSubscription, _, _, err = pnsm.chainRouter.SendNodeMsg(ctx, nodeChan, chainMessage, append(request.RelayData.Extensions, WebSocketExtension)) + utils.LavaFormatTrace("[AddConsumer] subscription reply received", + utils.LogAttr("replyWrapper", replyWrapper), + utils.LogAttr("subscriptionId", subscriptionId), + utils.LogAttr("clientSubscription", clientSubscription), + utils.LogAttr("err", err), + ) + + if err != nil { + pnsm.failedPendingSubscription(hashedParams) + return "", utils.LavaFormatError("ProviderNodeSubscriptionManager: Subscription failed", err, utils.LogAttr("GUID", ctx), utils.LogAttr("params", params)) + } + + if replyWrapper == nil || replyWrapper.RelayReply == nil { + pnsm.failedPendingSubscription(hashedParams) + return "", utils.LavaFormatError("ProviderNodeSubscriptionManager: Subscription failed, relayWrapper or RelayReply are nil", nil, utils.LogAttr("GUID", ctx)) + } + + reply := replyWrapper.RelayReply + + copiedRequest := &pairingtypes.RelayRequest{} + copyRequestErr := protocopy.DeepCopyProtoObject(request, copiedRequest) + if copyRequestErr != nil { + pnsm.failedPendingSubscription(hashedParams) + return "", utils.LavaFormatError("failed to copy subscription request", copyRequestErr) + } + + err = pnsm.signReply(ctx, reply, consumerAddr, chainMessage, request) + if err != nil { + pnsm.failedPendingSubscription(hashedParams) + return "", utils.LavaFormatError("failed signing subscription Reply", err) + } + + if clientSubscription == nil { + // failed subscription, but not an error (probably a node error) + safeChannelSender := common.NewSafeChannelSender(ctx, consumerChannel) + // Send the first message to the consumer, so it can handle the error in a routine. + go safeChannelSender.Send(reply) + pnsm.failedPendingSubscription(hashedParams) + return "", utils.LavaFormatWarning("ProviderNodeSubscriptionManager: Subscription failed, node error", nil, utils.LogAttr("GUID", ctx), utils.LogAttr("reply", reply)) + } + + utils.LavaFormatTrace("[AddConsumer] subscription successful", + utils.LogAttr("subscriptionId", subscriptionId), + utils.LogAttr("consumerAddr", consumerAddr), + utils.LogAttr("firstSetupReplyData", reply.Data), + ) + + // Initialize the map for connected consumers + connectedConsumers := map[string]map[string]*connectedConsumerContainer{ + consumerAddrString: { + consumerProcessGuid: { + consumerChannel: common.NewSafeChannelSender(ctx, consumerChannel), + firstSetupRequest: copiedRequest, + consumerSDKAddress: consumerAddr, + }, + }, + } + + // Create the activeSubscription instance + cancellableCtx, cancel := context.WithCancel(context.Background()) + channelToConnectedConsumers := &activeSubscription{ + cancellableContext: cancellableCtx, + cancellableContextCancelFunc: cancel, + messagesChannel: nodeChan, + nodeSubscription: clientSubscription, + subscriptionID: subscriptionId, + firstSetupReply: reply, + apiCollection: chainMessage.GetApiCollection(), + connectedConsumers: connectedConsumers, + } + + // Now we can lock, after we have a successful subscription.
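+ // Locking only now is intentional: the pending subscription entry added above already serializes creation per hashed params, so the manager lock does not have to be held for the entire node round trip.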
+ pnsm.lock.Lock() + defer pnsm.lock.Unlock() + + pnsm.activeSubscriptions[hashedParams] = channelToConnectedConsumers + firstSetupReply := reply + + // let other routines waiting for the new subscription know we have a channel ready + pnsm.successfulPendingSubscription(hashedParams) + + // send the first reply to the consumer; the reply needs to be signed + pnsm.activeSubscriptions[hashedParams].connectedConsumers[consumerAddrString][consumerProcessGuid].consumerChannel.LockAndSendAsynchronously(firstSetupReply) + // now that all channels are set, start listening to incoming data + go pnsm.listenForSubscriptionMessages(cancellableCtx, nodeChan, clientSubscription.Err(), hashedParams) + } + + return subscriptionId, err +} + +func (pnsm *ProviderNodeSubscriptionManager) listenForSubscriptionMessages(ctx context.Context, nodeChan chan interface{}, nodeErrChan <-chan error, hashedParams string) { + utils.LavaFormatTrace("Inside ProviderNodeSubscriptionManager:listenForSubscriptionMessages()", utils.LogAttr("hashedParams", utils.ToHexString(hashedParams))) + defer utils.LavaFormatTrace("Leaving ProviderNodeSubscriptionManager:listenForSubscriptionMessages()", utils.LogAttr("hashedParams", utils.ToHexString(hashedParams))) + + subscriptionTimeoutTicker := time.NewTicker(SubscriptionTimeoutDuration) // Set a time limit of 15 minutes for the subscription + defer subscriptionTimeoutTicker.Stop() + + closeNodeSubscriptionCallback := func() { + pnsm.lock.Lock() + defer pnsm.lock.Unlock() + pnsm.closeNodeSubscription(hashedParams) + } + + for { + select { + case <-ctx.Done(): + // If this context is done, it means that the subscription was already closed + utils.LavaFormatTrace("ProviderNodeSubscriptionManager:listenForSubscriptionMessages() subscription context is done, ending subscription", + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + return + case <-subscriptionTimeoutTicker.C: + utils.LavaFormatTrace("ProviderNodeSubscriptionManager:listenForSubscriptionMessages() timeout reached, ending subscription", + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + closeNodeSubscriptionCallback() + return + case nodeErr := <-nodeErrChan: + utils.LavaFormatWarning("ProviderNodeSubscriptionManager:listenForSubscriptionMessages() got error from node, ending subscription", nodeErr, + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + closeNodeSubscriptionCallback() + return + case nodeMsg := <-nodeChan: + utils.LavaFormatTrace("ProviderNodeSubscriptionManager:listenForSubscriptionMessages() got new message from node", + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("nodeMsg", nodeMsg), + ) + err := pnsm.handleNewNodeMessage(ctx, hashedParams, nodeMsg) + if err != nil { + closeNodeSubscriptionCallback() + } + } + } +} + +func (pnsm *ProviderNodeSubscriptionManager) getHashedParams(chainMessage ChainMessageForSend) (hashedParams string, params []byte, err error) { + rpcInputMessage := chainMessage.GetRPCMessage() + params, err = gojson.Marshal(rpcInputMessage.GetParams()) + if err != nil { + return "", nil, utils.LavaFormatError("could not marshal params", err) + } + + hashedParams = rpcclient.CreateHashFromParams(params) + return hashedParams, params, nil +} + +func (pnsm *ProviderNodeSubscriptionManager) convertNodeMsgToMarshalledJsonRpcResponse(data interface{}, apiCollection
*spectypes.ApiCollection) ([]byte, error) { + msg, ok := data.(*rpcclient.JsonrpcMessage) + if !ok { + return nil, fmt.Errorf("data is not a *rpcclient.JsonrpcMessage, but: %T", data) + } + + var convertedMsg any + var err error + + switch apiCollection.GetCollectionData().ApiInterface { + case spectypes.APIInterfaceJsonRPC: + convertedMsg, err = rpcInterfaceMessages.ConvertJsonRPCMsg(msg) + if err != nil { + return nil, err + } + case spectypes.APIInterfaceTendermintRPC: + convertedMsg, err = rpcInterfaceMessages.ConvertTendermintMsg(msg) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unsupported API interface: %s", apiCollection.GetCollectionData().ApiInterface) + } + + marshalledMsg, err := gojson.Marshal(convertedMsg) + if err != nil { + return nil, err + } + + return marshalledMsg, nil +} + +func (pnsm *ProviderNodeSubscriptionManager) signReply(ctx context.Context, reply *pairingtypes.RelayReply, consumerAddr sdk.AccAddress, chainMessage ChainMessage, request *pairingtypes.RelayRequest) error { + // Sign the reply for this specific consumer and request; when data reliability is enabled, the finalized block hashes are built into the reply before signing + dataReliabilityEnabled, _ := pnsm.chainParser.DataReliabilityParams() + blockLagForQosSync, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData := pnsm.chainParser.ChainBlockStats() + relayTimeout := GetRelayTimeout(chainMessage, averageBlockTime) + + if dataReliabilityEnabled { + var err error + latestBlock, _, requestedHashes, modifiedReqBlock, _, updatedChainMessage, err := pnsm.relayFinalizationBlocksHandler.GetParametersForRelayDataReliability(ctx, request, chainMessage, relayTimeout, blockLagForQosSync, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData) + if err != nil { + return err + } + + err = pnsm.relayFinalizationBlocksHandler.BuildRelayFinalizedBlockHashes(ctx, request, reply, latestBlock, requestedHashes, updatedChainMessage, relayTimeout, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData, modifiedReqBlock) + if err != nil { + return err + } + } + + var ignoredMetadata []pairingtypes.Metadata + reply.Metadata, _, ignoredMetadata = pnsm.chainParser.HandleHeaders(reply.Metadata, chainMessage.GetApiCollection(), spectypes.Header_pass_reply) + reply, err := lavaprotocol.SignRelayResponse(consumerAddr, *request, pnsm.privKey, reply, dataReliabilityEnabled) + if err != nil { + return err + } + reply.Metadata = append(reply.Metadata, ignoredMetadata...)
// appended here only after signing + return nil +} + +func (pnsm *ProviderNodeSubscriptionManager) handleNewNodeMessage(ctx context.Context, hashedParams string, nodeMsg interface{}) error { + pnsm.lock.RLock() + defer pnsm.lock.RUnlock() + activeSub, foundActiveSubscription := pnsm.activeSubscriptions[hashedParams] + if !foundActiveSubscription { + return utils.LavaFormatWarning("No hashed params in handleNewNodeMessage, connection might have been closed", FailedSendingSubscriptionToClients, utils.LogAttr("hash", hashedParams)) + } + // Sending message to all connected consumers + for consumerAddrString, connectedConsumerAddress := range activeSub.connectedConsumers { + for consumerProcessGuid, connectedConsumerContainer := range connectedConsumerAddress { + utils.LavaFormatTrace("ProviderNodeSubscriptionManager:handleNewNodeMessage() sending to consumer", + utils.LogAttr("consumerAddr", consumerAddrString), + utils.LogAttr("consumerProcessGuid", consumerProcessGuid), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + copiedRequest := &pairingtypes.RelayRequest{} + // TODO: Optimization: a better way is to avoid calling ParseMsg multiple times, by creating a deep copy of the chain message. + // This way, we can parse the message only once and use the copy method to save resources. + copyRequestErr := protocopy.DeepCopyProtoObject(connectedConsumerContainer.firstSetupRequest, copiedRequest) + if copyRequestErr != nil { + return utils.LavaFormatError("failed to copy subscription request", copyRequestErr) + } + + extensionInfo := extensionslib.ExtensionInfo{LatestBlock: 0, ExtensionOverride: copiedRequest.RelayData.Extensions} + if extensionInfo.ExtensionOverride == nil { // in case the consumer did not set an extension, we skip extension parsing and send the request to the regular url + extensionInfo.ExtensionOverride = []string{} + } + + chainMessage, err := pnsm.chainParser.ParseMsg(copiedRequest.RelayData.ApiUrl, copiedRequest.RelayData.Data, copiedRequest.RelayData.ConnectionType, copiedRequest.RelayData.GetMetadata(), extensionInfo) + if err != nil { + return utils.LavaFormatError("failed to parse message", err) + } + + apiCollection := pnsm.activeSubscriptions[hashedParams].apiCollection + + marshalledNodeMsg, err := pnsm.convertNodeMsgToMarshalledJsonRpcResponse(nodeMsg, apiCollection) + if err != nil { + return utils.LavaFormatError("error converting node message", err) + } + + relayMessageFromNode := &pairingtypes.RelayReply{ + Data: marshalledNodeMsg, + Metadata: []pairingtypes.Metadata{}, + } + + err = pnsm.signReply(ctx, relayMessageFromNode, connectedConsumerContainer.consumerSDKAddress, chainMessage, copiedRequest) + if err != nil { + return utils.LavaFormatError("error signing reply", err) + } + + utils.LavaFormatDebug("Sending relay to consumer", + utils.LogAttr("requestRelayData", copiedRequest.RelayData), + utils.LogAttr("reply", marshalledNodeMsg), + utils.LogAttr("replyLatestBlock", relayMessageFromNode.LatestBlock), + utils.LogAttr("consumerAddr", connectedConsumerContainer.consumerSDKAddress), + ) + + go connectedConsumerContainer.consumerChannel.Send(relayMessageFromNode) + } + } + return nil +} + +func (pnsm *ProviderNodeSubscriptionManager) RemoveConsumer(ctx context.Context, chainMessage ChainMessageForSend, consumerAddr sdk.AccAddress, closeConsumerChannel bool, consumerProcessGuid string) error { + if pnsm == nil { + return nil + } + + hashedParams, params, err := pnsm.getHashedParams(chainMessage) + if err != nil { + return err + } + +
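// The consumer is identified by its address and process GUID; the underlying node subscription itself is closed only after the last connected consumer for these params is removed. +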
utils.LavaFormatTrace("[RemoveConsumer] requested to remove consumer from subscription", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddr), + utils.LogAttr("params", params), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + consumerAddrString := consumerAddr.String() + + pnsm.lock.Lock() + defer pnsm.lock.Unlock() + + openSubscriptions, ok := pnsm.activeSubscriptions[hashedParams] + if !ok { + utils.LavaFormatTrace("[RemoveConsumer] no subscription found for params, subscription is already closed", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddr), + utils.LogAttr("params", params), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + return nil + } + + // Remove consumer from connected consumers + if _, ok := openSubscriptions.connectedConsumers[consumerAddrString]; ok { + if _, foundGuid := openSubscriptions.connectedConsumers[consumerAddrString][consumerProcessGuid]; foundGuid { + utils.LavaFormatTrace("[RemoveConsumer] found consumer connected consumers", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddrString), + utils.LogAttr("params", params), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("consumerProcessGuid", consumerProcessGuid), + utils.LogAttr("connectedConsumers", openSubscriptions.connectedConsumers), + ) + if closeConsumerChannel { + utils.LavaFormatTrace("[RemoveConsumer] closing consumer channel", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddrString), + utils.LogAttr("params", params), + utils.LogAttr("consumerProcessGuid", consumerProcessGuid), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + openSubscriptions.connectedConsumers[consumerAddrString][consumerProcessGuid].consumerChannel.Close() + } + + // delete guid + delete(pnsm.activeSubscriptions[hashedParams].connectedConsumers[consumerAddrString], consumerProcessGuid) + // check if this was our only subscription for this consumer. + if len(pnsm.activeSubscriptions[hashedParams].connectedConsumers[consumerAddrString]) == 0 { + // delete consumer as well. 
+ delete(pnsm.activeSubscriptions[hashedParams].connectedConsumers, consumerAddrString) + } + if len(pnsm.activeSubscriptions[hashedParams].connectedConsumers) == 0 { + utils.LavaFormatTrace("[RemoveConsumer] no more connected consumers", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddr), + utils.LogAttr("params", params), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + // Cancel the subscription's context and close the subscription + pnsm.activeSubscriptions[hashedParams].cancellableContextCancelFunc() + pnsm.closeNodeSubscription(hashedParams) + } + } + utils.LavaFormatTrace("[RemoveConsumer] removed consumer", utils.LogAttr("consumerAddr", consumerAddr), utils.LogAttr("params", params)) + } else { + utils.LavaFormatTrace("[RemoveConsumer] consumer not found in connected consumers", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddr), + utils.LogAttr("params", params), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("consumerProcessGuid", consumerProcessGuid), + utils.LogAttr("connectedConsumers", openSubscriptions.connectedConsumers), + ) + } + return nil +} + +func (pnsm *ProviderNodeSubscriptionManager) closeNodeSubscription(hashedParams string) error { + activeSub, foundActiveSubscription := pnsm.activeSubscriptions[hashedParams] + if !foundActiveSubscription { + return utils.LavaFormatError("closeNodeSubscription called with hashedParams that does not exist", nil, utils.LogAttr("hashedParams", utils.ToHexString(hashedParams))) + } + + // Disconnect all connected consumers + for consumerAddrString, consumerChannels := range activeSub.connectedConsumers { + for consumerGuid, consumerChannel := range consumerChannels { + utils.LavaFormatTrace("ProviderNodeSubscriptionManager:closeNodeSubscription() closing consumer channel", + utils.LogAttr("consumerAddr", consumerAddrString), + utils.LogAttr("consumerGuid", consumerGuid), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + consumerChannel.consumerChannel.Close() + } + } + + pnsm.activeSubscriptions[hashedParams].nodeSubscription.Unsubscribe() + close(pnsm.activeSubscriptions[hashedParams].messagesChannel) + delete(pnsm.activeSubscriptions, hashedParams) + return nil +} diff --git a/protocol/chainlib/provider_node_subscription_manager_test.go b/protocol/chainlib/provider_node_subscription_manager_test.go new file mode 100644 index 0000000000..2104adc994 --- /dev/null +++ b/protocol/chainlib/provider_node_subscription_manager_test.go @@ -0,0 +1,441 @@ +package chainlib + +import ( + "context" + "net/http" + "strconv" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" + "github.com/lavanet/lava/v2/protocol/chaintracker" + "github.com/lavanet/lava/v2/utils" + pairingtypes "github.com/lavanet/lava/v2/x/pairing/types" + spectypes "github.com/lavanet/lava/v2/x/spec/types" + "github.com/stretchr/testify/require" +) + +const testGuid = "testGuid" + +type RelayFinalizationBlocksHandlerMock struct{} + +func (rf *RelayFinalizationBlocksHandlerMock) GetParametersForRelayDataReliability( + ctx context.Context, + request *pairingtypes.RelayRequest, + chainMsg ChainMessage, + relayTimeout time.Duration, + blockLagForQosSync int64, + averageBlockTime time.Duration, + blockDistanceToFinalization, + blocksInFinalizationData uint32, +) (latestBlock int64, requestedBlockHash []byte, requestedHashes []*chaintracker.BlockStore, modifiedReqBlock int64, 
finalized, updatedChainMessage bool, err error) { + return 0, []byte{}, []*chaintracker.BlockStore{}, 0, true, true, nil +} + +func (rf *RelayFinalizationBlocksHandlerMock) BuildRelayFinalizedBlockHashes( + ctx context.Context, + request *pairingtypes.RelayRequest, + reply *pairingtypes.RelayReply, + latestBlock int64, + requestedHashes []*chaintracker.BlockStore, + updatedChainMessage bool, + relayTimeout time.Duration, + averageBlockTime time.Duration, + blockDistanceToFinalization uint32, + blocksInFinalizationData uint32, + modifiedReqBlock int64, +) (err error) { + return nil +} + +func TestSubscriptionManager_HappyFlow(t *testing.T) { + playbook := []struct { + name string + specId string + apiInterface string + connectionType string + subscriptionRequestData []byte + subscriptionFirstReply []byte + }{ + { + name: "TendermintRPC", + specId: "LAV1", + apiInterface: spectypes.APIInterfaceTendermintRPC, + connectionType: "", + subscriptionRequestData: []byte(`{"jsonrpc":"2.0","id":3,"method":"subscribe","params":{"query":"tm.event='NewBlock'"}}`), + subscriptionFirstReply: []byte(`{"jsonrpc":"2.0","id":3,"result":{}}`), + }, + { + name: "JsonRPC", + specId: "ETH1", + apiInterface: spectypes.APIInterfaceJsonRPC, + connectionType: "POST", + subscriptionRequestData: []byte(`{"jsonrpc":"2.0","id":5,"method":"eth_subscribe","params":["newHeads"]}`), + subscriptionFirstReply: []byte(`{"jsonrpc":"2.0","id":5,"result":"0x1234567890"}`), + }, + } + + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + ts := SetupForTests(t, 1, play.specId, "../../") + + wg := sync.WaitGroup{} + wg.Add(1) + // msgCount := 0 + upgrader := websocket.Upgrader{} + + // Create a simple websocket server that mocks the node + handleWebSocket := func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + require.NoError(t, err) + return + } + defer conn.Close() + + for { + // Read the request + messageType, message, err := conn.ReadMessage() + if err != nil { + require.NoError(t, err) + return + } + + wg.Done() + + require.Equal(t, string(play.subscriptionRequestData)+"\n", string(message)) + + // Write the first reply + err = conn.WriteMessage(messageType, play.subscriptionFirstReply) + if err != nil { + require.NoError(t, err) + return + } + } + } + + chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(context.Background(), play.specId, play.apiInterface, nil, handleWebSocket, "../../", nil) + require.NoError(t, err) + if closeServer != nil { + defer closeServer() + } + + // Create the relay request and chain message + relayRequest := &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: play.subscriptionRequestData, + }, + RelaySession: &pairingtypes.RelaySession{}, + } + + chainMessage, err := chainParser.ParseMsg("", play.subscriptionRequestData, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) + require.NoError(t, err) + + // Create the provider node subscription manager + mockRpcProvider := &RelayFinalizationBlocksHandlerMock{} + pnsm := NewProviderNodeSubscriptionManager(chainRouter, chainParser, mockRpcProvider, ts.Providers[0].SK) + + consumerChannel := make(chan *pairingtypes.RelayReply) + + // Read the consumer channel that simulates consumer + go func() { + reply := <-consumerChannel + require.NotNil(t, reply) + require.Equal(t, string(play.subscriptionFirstReply), string(reply.Data)) + }() + + // Subscribe to the chain + subscriptionId, err := pnsm.AddConsumer(ts.Ctx, 
relayRequest, chainMessage, ts.Consumer.Addr, consumerChannel, testGuid) + require.NoError(t, err) + require.NotEmpty(t, subscriptionId) + + wg.Wait() // Make sure the subscription manager sent a message to the node + + // Subscribe to the same subscription again, should return the same subscription id + subscriptionIdNew, err := pnsm.AddConsumer(ts.Ctx, relayRequest, chainMessage, ts.Consumer.Addr, consumerChannel, testGuid) + require.NoError(t, err) + require.NotEmpty(t, subscriptionIdNew) + require.Equal(t, subscriptionId, subscriptionIdNew) + + // Cut the subscription, and re-subscribe, should send another message to node + err = pnsm.RemoveConsumer(ts.Ctx, chainMessage, ts.Consumer.Addr, true, testGuid) + require.NoError(t, err) + + // Make sure the consumer channel is closed + _, ok := <-consumerChannel + require.False(t, ok) + + consumerChannel = make(chan *pairingtypes.RelayReply) + waitTestToEnd := make(chan bool) + // Read the consumer channel that simulates consumer + go func() { + defer func() { waitTestToEnd <- true }() + reply := <-consumerChannel + require.NotNil(t, reply) + require.Equal(t, string(play.subscriptionFirstReply), string(reply.Data)) + }() + + wg.Add(1) // Should send another message to the node + + subscriptionId, err = pnsm.AddConsumer(ts.Ctx, relayRequest, chainMessage, ts.Consumer.Addr, consumerChannel, testGuid) + require.NoError(t, err) + require.NotEmpty(t, subscriptionId) + + wg.Wait() // Make sure the subscription manager sent another message to the node + + // Making sure our routine ended; otherwise, the routine can read the wrong play.subscriptionFirstReply + <-waitTestToEnd + }) + } +} + +func TestSubscriptionManager_MultipleParallelSubscriptionsWithTheSameParams(t *testing.T) { + playbook := []struct { + name string + specId string + apiInterface string + connectionType string + subscriptionRequestData []byte + subscriptionFirstReply []byte + }{ + { + name: "TendermintRPC", + specId: "LAV1", + apiInterface: spectypes.APIInterfaceTendermintRPC, + connectionType: "", + subscriptionRequestData: []byte(`{"jsonrpc":"2.0","id":3,"method":"subscribe","params":{"query":"tm.event='NewBlock'"}}`), + subscriptionFirstReply: []byte(`{"jsonrpc":"2.0","id":3,"result":{}}`), + }, + { + name: "JsonRPC", + specId: "ETH1", + apiInterface: spectypes.APIInterfaceJsonRPC, + connectionType: "POST", + subscriptionRequestData: []byte(`{"jsonrpc":"2.0","id":5,"method":"eth_subscribe","params":["newHeads"]}`), + subscriptionFirstReply: []byte(`{"jsonrpc":"2.0","id":5,"result":"0x1234567890"}`), + }, + } + + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + ts := SetupForTests(t, 1, play.specId, "../../") + + wg := sync.WaitGroup{} + // msgCount := 0 + upgrader := websocket.Upgrader{} + + // Create a simple websocket server that mocks the node + handleWebSocket := func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + require.NoError(t, err) + return + } + defer conn.Close() + first := true + for { + // Read the request + messageType, message, err := conn.ReadMessage() + if err != nil { + require.NoError(t, err) + return + } + + require.Equal(t, string(play.subscriptionRequestData)+"\n", string(message)) + + if first { // on the first reply we want some delay, so we can make sure the pending mechanism is working properly + time.Sleep(time.Second * 2) + first = false + } + utils.LavaFormatDebug("write message") + wg.Done() + + // Write the first reply + err = conn.WriteMessage(messageType,
play.subscriptionFirstReply) + if err != nil { + require.NoError(t, err) + return + } + } + } + + chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(context.Background(), play.specId, play.apiInterface, nil, handleWebSocket, "../../", nil) + require.NoError(t, err) + if closeServer != nil { + defer closeServer() + } + + // Create the relay request and chain message + relayRequest := &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: play.subscriptionRequestData, + }, + RelaySession: &pairingtypes.RelaySession{}, + } + + chainMessage, err := chainParser.ParseMsg("", play.subscriptionRequestData, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) + require.NoError(t, err) + + // Create the provider node subscription manager + mockRpcProvider := &RelayFinalizationBlocksHandlerMock{} + pnsm := NewProviderNodeSubscriptionManager(chainRouter, chainParser, mockRpcProvider, ts.Providers[0].SK) + + wg.Add(1) + wgAllIds := sync.WaitGroup{} + wgAllIds.Add(10) + for i := 0; i < 10; i++ { + consumerChannel := make(chan *pairingtypes.RelayReply) + // Read the consumer channel that simulates consumer + go func() { + reply := <-consumerChannel + require.NotNil(t, reply) + require.Equal(t, string(play.subscriptionFirstReply), string(reply.Data)) + wgAllIds.Done() + }() + // Subscribe to the chain + go func(index int) { + subscriptionId, err := pnsm.AddConsumer(ts.Ctx, relayRequest, chainMessage, ts.Consumer.Addr, consumerChannel, testGuid+strconv.Itoa(index)) + require.NoError(t, err) + require.NotEmpty(t, subscriptionId) + }(i) + } + + utils.LavaFormatDebug("Waiting wait group") + wgAllIds.Wait() + wg.Wait() // Make sure the subscription manager sent a message to the node + + // Cut the subscription, and re-subscribe, should send another message to node + err = pnsm.RemoveConsumer(ts.Ctx, chainMessage, ts.Consumer.Addr, true, testGuid) + require.NoError(t, err) + }) + } +} + +func TestSubscriptionManager_MultipleParallelSubscriptionsWithTheSameParamsAndNodeMessageFailure(t *testing.T) { + playbook := []struct { + name string + specId string + apiInterface string + connectionType string + subscriptionRequestData []byte + subscriptionFirstReply []byte + }{ + { + name: "TendermintRPC", + specId: "LAV1", + apiInterface: spectypes.APIInterfaceTendermintRPC, + connectionType: "", + subscriptionRequestData: []byte(`{"jsonrpc":"2.0","id":3,"method":"subscribe","params":{"query":"tm.event='NewBlock'"}}`), + subscriptionFirstReply: []byte(`{"jsonrpc":"2.0","id":3,"result":{}}`), + }, + { + name: "JsonRPC", + specId: "ETH1", + apiInterface: spectypes.APIInterfaceJsonRPC, + connectionType: "POST", + subscriptionRequestData: []byte(`{"jsonrpc":"2.0","id":5,"method":"eth_subscribe","params":["newHeads"]}`), + subscriptionFirstReply: []byte(`{"jsonrpc":"2.0","id":5,"result":"0x1234567890"}`), + }, + } + + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + ts := SetupForTests(t, 1, play.specId, "../../") + + wg := sync.WaitGroup{} + // msgCount := 0 + upgrader := websocket.Upgrader{} + first := true + // Create a simple websocket server that mocks the node + handleWebSocket := func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + require.NoError(t, err) + return + } + defer conn.Close() + for { + // Read the request + messageType, message, err := conn.ReadMessage() + if err != nil { + require.NoError(t, err) + return + } + + require.Equal(t, 
string(play.subscriptionRequestData)+"\n", string(message)) + + if first { // on the first reply we want some delay, so we can make sure the pending mechanism is working properly + time.Sleep(time.Second * 2) + first = false + conn.Close() // first connection should fail. + return + } + utils.LavaFormatDebug("write message") + wg.Done() + + // Write the first reply + err = conn.WriteMessage(messageType, play.subscriptionFirstReply) + if err != nil { + require.NoError(t, err) + return + } + } + } + + chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(context.Background(), play.specId, play.apiInterface, nil, handleWebSocket, "../../", nil) + require.NoError(t, err) + if closeServer != nil { + defer closeServer() + } + + // Create the relay request and chain message + relayRequest := &pairingtypes.RelayRequest{ + RelayData: &pairingtypes.RelayPrivateData{ + Data: play.subscriptionRequestData, + }, + RelaySession: &pairingtypes.RelaySession{}, + } + + chainMessage, err := chainParser.ParseMsg("", play.subscriptionRequestData, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) + require.NoError(t, err) + + // Create the provider node subscription manager + mockRpcProvider := &RelayFinalizationBlocksHandlerMock{} + pnsm := NewProviderNodeSubscriptionManager(chainRouter, chainParser, mockRpcProvider, ts.Providers[0].SK) + + wg.Add(1) + wgAllIds := sync.WaitGroup{} + wgAllIds.Add(9) + errorsLock := sync.Mutex{} + errors := []error{} + for i := 0; i < 10; i++ { + consumerChannel := make(chan *pairingtypes.RelayReply) + // Read the consumer channel that simulates consumer + go func() { + reply := <-consumerChannel + require.NotNil(t, reply) + require.Equal(t, string(play.subscriptionFirstReply), string(reply.Data)) + wgAllIds.Done() + }() + // Subscribe to the chain + go func(index int) { + _, err := pnsm.AddConsumer(ts.Ctx, relayRequest, chainMessage, ts.Consumer.Addr, consumerChannel, testGuid+strconv.Itoa(index)) + if err != nil { + errorsLock.Lock() // the subscription attempts run in parallel, so guard the shared errors slice + errors = append(errors, err) + errorsLock.Unlock() + } + }(i) + } + + utils.LavaFormatDebug("Waiting wait group") + wgAllIds.Wait() + wg.Wait() // Make sure the subscription manager sent a message to the node + // make sure we had only one error, on the first subscription attempt + require.Len(t, errors, 1) + + // Cut the subscription, and re-subscribe, should send another message to node + err = pnsm.RemoveConsumer(ts.Ctx, chainMessage, ts.Consumer.Addr, true, testGuid) + require.NoError(t, err) + }) + } +} diff --git a/protocol/chainlib/rest.go b/protocol/chainlib/rest.go index 1f9c1ebcde..c3cd6485a3 100644 --- a/protocol/chainlib/rest.go +++ b/protocol/chainlib/rest.go @@ -146,6 +146,7 @@ func (*RestChainParser) newChainMessage(serviceApi *spectypes.Api, requestBlock msg: restMessage, latestRequestedBlock: requestBlock, resultErrorParsingMethod: restMessage.CheckResponseError, + parseDirective: GetParseDirective(serviceApi, apiCollection), } return nodeMsg } @@ -453,6 +454,9 @@ func NewRestChainProxy(ctx context.Context, nConns uint, rpcProviderEndpoint lav if len(rpcProviderEndpoint.NodeUrls) == 0 { return nil, utils.LavaFormatError("rpcProviderEndpoint.NodeUrl list is empty missing node url", nil, utils.Attribute{Key: "chainID", Value: rpcProviderEndpoint.ChainID}, utils.Attribute{Key: "ApiInterface", Value: rpcProviderEndpoint.ApiInterface}) } + + validateEndpoints(rpcProviderEndpoint.NodeUrls, spectypes.APIInterfaceRest) + _, averageBlockTime, _, _ := chainParser.ChainBlockStats() nodeUrl := rpcProviderEndpoint.NodeUrls[0] nodeUrl.Url =
strings.TrimSuffix(rpcProviderEndpoint.NodeUrls[0].Url, "/") diff --git a/protocol/chainlib/rest_test.go b/protocol/chainlib/rest_test.go index d115003489..5b5102ae67 100644 --- a/protocol/chainlib/rest_test.go +++ b/protocol/chainlib/rest_test.go @@ -137,7 +137,7 @@ func TestRestChainProxy(t *testing.T) { w.WriteHeader(http.StatusOK) fmt.Fprint(w, `{"block": { "header": {"height": "244591"}}}`) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) @@ -168,15 +168,16 @@ func TestParsingRequestedBlocksHeadersRest(t *testing.T) { fmt.Fprint(w, `{"block": { "header": {"height": "244591"}}}`) } }) - chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) require.NoError(t, err) defer func() { if closeServer != nil { closeServer() } }() - parsingForCrafting, collectionData, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) + parsingForCrafting, apiCollection, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) require.True(t, ok) + collectionData := apiCollection.CollectionData headerParsingDirective, _, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_SET_LATEST_IN_METADATA) callbackHeaderNameToCheck = headerParsingDirective.GetApiName() // this causes the callback to modify the response to simulate a real behavior require.True(t, ok) @@ -238,15 +239,16 @@ func TestSettingRequestedBlocksHeadersRest(t *testing.T) { } fmt.Fprint(w, `{"block": { "header": {"height": "244591"}}}`) }) - chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, chainRouter, _, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) require.NoError(t, err) defer func() { if closeServer != nil { closeServer() } }() - parsingForCrafting, collectionData, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) + parsingForCrafting, apiCollection, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) require.True(t, ok) + collectionData := apiCollection.CollectionData headerParsingDirective, _, ok := chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_SET_LATEST_IN_METADATA) callbackHeaderNameToCheck = headerParsingDirective.GetApiName() // this causes the callback to modify the response to simulate a real behavior require.True(t, ok) diff --git a/protocol/chainlib/tendermintRPC.go b/protocol/chainlib/tendermintRPC.go index 92ca30893e..50bd4e1312 100644 --- a/protocol/chainlib/tendermintRPC.go +++ b/protocol/chainlib/tendermintRPC.go @@ -139,7 +139,11 @@ func (apip *TendermintChainParser) ParseMsg(urlPath string, data []byte, connect // Check api is supported and save it in nodeMsg apiCont, err := apip.getSupportedApi(msg.Method, connectionType) if err != nil { - utils.LavaFormatDebug("getSupportedApi jsonrpc failed", utils.LogAttr("method", msg.Method), utils.LogAttr("error", err)) + 
utils.LavaFormatDebug("getSupportedApi tendermintrpc failed", + utils.LogAttr("method", msg.Method), + utils.LogAttr("connectionType", connectionType), + utils.LogAttr("error", err), + ) return nil, err } @@ -245,6 +249,7 @@ func (*TendermintChainParser) newBatchChainMessage(serviceApi *spectypes.Api, re msg: &batchMessage, earliestRequestedBlock: earliestRequestedBlock, resultErrorParsingMethod: rpcInterfaceMessages.CheckResponseErrorForJsonRpcBatch, + parseDirective: GetParseDirective(serviceApi, apiCollection), } return nodeMsg, err } @@ -256,6 +261,7 @@ func (*TendermintChainParser) newChainMessage(serviceApi *spectypes.Api, request latestRequestedBlock: requestedBlock, msg: msg, resultErrorParsingMethod: msg.CheckResponseError, + parseDirective: GetParseDirective(serviceApi, apiCollection), } return nodeMsg } @@ -311,11 +317,12 @@ func (apip *TendermintChainParser) ChainBlockStats() (allowedBlockLagForQosSync } type TendermintRpcChainListener struct { - endpoint *lavasession.RPCEndpoint - relaySender RelaySender - healthReporter HealthReporter - logger *metrics.RPCConsumerLogs - refererData *RefererData + endpoint *lavasession.RPCEndpoint + relaySender RelaySender + healthReporter HealthReporter + logger *metrics.RPCConsumerLogs + refererData *RefererData + consumerWsSubscriptionManager *ConsumerWSSubscriptionManager } // NewTendermintRpcChainListener creates a new instance of TendermintRpcChainListener @@ -323,6 +330,7 @@ func NewTendermintRpcChainListener(ctx context.Context, listenEndpoint *lavasess relaySender RelaySender, healthReporter HealthReporter, rpcConsumerLogs *metrics.RPCConsumerLogs, refererData *RefererData, + consumerWsSubscriptionManager *ConsumerWSSubscriptionManager, ) (chainListener *TendermintRpcChainListener) { // Create a new instance of JsonRPCChainListener chainListener = &TendermintRpcChainListener{ @@ -331,6 +339,7 @@ func NewTendermintRpcChainListener(ctx context.Context, listenEndpoint *lavasess healthReporter, rpcConsumerLogs, refererData, + consumerWsSubscriptionManager, } return chainListener @@ -358,90 +367,25 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm return fiber.ErrUpgradeRequired }) webSocketCallback := websocket.New(func(websocketConn *websocket.Conn) { - var ( - mt int - msg []byte - err error - ) - msgSeed := apil.logger.GetMessageSeed() - startTime := time.Now() - for { - if mt, msg, err = websocketConn.ReadMessage(); err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websocketConn, mt, err, msgSeed, msg, "tendermint", time.Since(startTime)) - break - } - dappID, ok := websocketConn.Locals("dappId").(string) - if !ok { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websocketConn, mt, nil, msgSeed, []byte("Unable to extract dappID"), spectypes.APIInterfaceJsonRPC, time.Since(startTime)) - } - - ctx, cancel := context.WithCancel(context.Background()) - guid := utils.GenerateUniqueIdentifier() - ctx = utils.WithUniqueIdentifier(ctx, guid) - defer cancel() // incase there's a problem make sure to cancel the connection - - logFormattedMsg := string(msg) - if !cmdFlags.DebugRelays { - logFormattedMsg = utils.FormatLongString(logFormattedMsg, relayMsgLogMaxChars) - } - - utils.LavaFormatDebug("ws in <<<", - utils.LogAttr("GUID", ctx), - utils.LogAttr("seed", msgSeed), - utils.LogAttr("msg", logFormattedMsg), - utils.LogAttr("dappID", dappID), - ) - msgSeed = strconv.FormatUint(guid, 10) - refererMatch, ok := websocketConn.Locals(refererMatchString).(string) - metricsData := 
metrics.NewRelayAnalytics(dappID, chainID, apiInterface) - relayResult, err := apil.relaySender.SendRelay(ctx, "", string(msg), "", dappID, websocketConn.RemoteAddr().String(), metricsData, nil) - if ok && refererMatch != "" && apil.refererData != nil && err == nil { - go apil.refererData.SendReferer(refererMatch, chainID, string(msg), websocketConn.RemoteAddr().String(), nil, websocketConn) - } - reply := relayResult.GetReply() - replyServer := relayResult.GetReplyServer() - go apil.logger.AddMetricForWebSocket(metricsData, err, websocketConn) - if err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websocketConn, mt, err, msgSeed, msg, "tendermint", time.Since(startTime)) - continue - } - // If subscribe the first reply would contain the RPC ID that can be used for disconnect. - if replyServer != nil { - var reply pairingtypes.RelayReply - err = (*replyServer).RecvMsg(&reply) // this reply contains the RPC ID - if err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websocketConn, mt, err, msgSeed, msg, "tendermint", time.Since(startTime)) - continue - } + utils.LavaFormatDebug("tendermintrpc websocket opened", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String())) + defer utils.LavaFormatDebug("tendermintrpc websocket closed", utils.LogAttr("consumerIp", websocketConn.LocalAddr().String())) + + consumerWebsocketManager := NewConsumerWebsocketManager(ConsumerWebsocketManagerOptions{ + WebsocketConn: websocketConn, + RpcConsumerLogs: apil.logger, + RefererMatchString: refererMatchString, + CmdFlags: cmdFlags, + RelayMsgLogMaxChars: relayMsgLogMaxChars, + ChainID: chainID, + ApiInterface: apiInterface, + ConnectionType: "", // We use it for the ParseMsg method, which needs to know the connection type to find the method in the spec + RefererData: apil.refererData, + RelaySender: apil.relaySender, + ConsumerWsSubscriptionManager: apil.consumerWsSubscriptionManager, + WebsocketConnectionUID: strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10), + }) - if err = websocketConn.WriteMessage(mt, reply.Data); err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websocketConn, mt, err, msgSeed, msg, "tendermint", time.Since(startTime)) - continue - } - apil.logger.LogRequestAndResponse("tendermint ws", false, "ws", websocketConn.LocalAddr().String(), string(msg), string(reply.Data), msgSeed, time.Since(startTime), nil) - for { - err = (*replyServer).RecvMsg(&reply) - if err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websocketConn, mt, err, msgSeed, msg, "tendermint", time.Since(startTime)) - break - } - - // If portal cant write to the client - if err = websocketConn.WriteMessage(mt, reply.Data); err != nil { - cancel() - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websocketConn, mt, err, msgSeed, msg, "tendermint", time.Since(startTime)) - // break - } - apil.logger.LogRequestAndResponse("tendermint ws", false, "ws", websocketConn.LocalAddr().String(), string(msg), string(reply.Data), msgSeed, time.Since(startTime), nil) - } - } else { - if err = websocketConn.WriteMessage(mt, reply.Data); err != nil { - apil.logger.AnalyzeWebSocketErrorAndWriteMessage(websocketConn, mt, err, msgSeed, msg, "tendermint", time.Since(startTime)) - continue - } - apil.logger.LogRequestAndResponse("tendermint ws", false, "ws", websocketConn.LocalAddr().String(), string(msg), string(reply.Data), msgSeed, time.Since(startTime), nil) - } - } + consumerWebsocketManager.ListenToMessages() }) websocketCallbackWithDappID := 
constructFiberCallbackWithHeaderAndParameterExtraction(webSocketCallback, apil.logger.StoreMetricData) app.Get("/ws", websocketCallbackWithDappID) @@ -615,9 +559,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context, cmdFlags comm type tendermintRpcChainProxy struct { // embedding the jrpc chain proxy because the only diff is on parse message JrpcChainProxy - httpNodeUrl common.NodeUrl - httpConnector *chainproxy.Connector - httpClient *http.Client + httpClient *http.Client } func NewtendermintRpcChainProxy(ctx context.Context, nConns uint, rpcProviderEndpoint lavasession.RPCProviderEndpoint, chainParser ChainParser) (ChainProxy, error) { @@ -625,26 +567,23 @@ func NewtendermintRpcChainProxy(ctx context.Context, nConns uint, rpcProviderEnd return nil, utils.LavaFormatError("rpcProviderEndpoint.NodeUrl list is empty missing node url", nil, utils.Attribute{Key: "chainID", Value: rpcProviderEndpoint.ChainID}, utils.Attribute{Key: "ApiInterface", Value: rpcProviderEndpoint.ApiInterface}) } _, averageBlockTime, _, _ := chainParser.ChainBlockStats() - websocketUrl, httpUrl := verifyTendermintEndpoint(rpcProviderEndpoint.NodeUrls) + + validateEndpoints(rpcProviderEndpoint.NodeUrls, spectypes.APIInterfaceTendermintRPC) + + nodeUrl := rpcProviderEndpoint.NodeUrls[0] cp := &tendermintRpcChainProxy{ - JrpcChainProxy: JrpcChainProxy{BaseChainProxy: BaseChainProxy{averageBlockTime: averageBlockTime, NodeUrl: websocketUrl, ErrorHandler: &TendermintRPCErrorHandler{}, ChainID: rpcProviderEndpoint.ChainID}, conn: map[string]*chainproxy.Connector{}}, - httpNodeUrl: httpUrl, - httpConnector: nil, + JrpcChainProxy: JrpcChainProxy{ + BaseChainProxy: BaseChainProxy{ + averageBlockTime: averageBlockTime, + NodeUrl: nodeUrl, + ErrorHandler: &TendermintRPCErrorHandler{}, + ChainID: rpcProviderEndpoint.ChainID, + }, + conn: map[string]*chainproxy.Connector{}, + }, } - cp.addHttpConnector(ctx, nConns, httpUrl) - return cp, cp.start(ctx, nConns, websocketUrl, nil) -} -func (cp *tendermintRpcChainProxy) addHttpConnector(ctx context.Context, nConns uint, nodeUrl common.NodeUrl) error { - conn, err := chainproxy.NewConnector(ctx, nConns, nodeUrl) - if err != nil { - return err - } - cp.httpConnector = conn - if cp.httpConnector == nil { - return errors.New("g_conn == nil") - } - return nil + return cp, cp.start(ctx, nConns, nodeUrl, nil) } func (cp *tendermintRpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, chainMessage ChainMessageForSend) (relayReply *RelayReplyWrapper, subscriptionID string, relayReplyServer *rpcclient.ClientSubscription, err error) { @@ -675,25 +614,28 @@ func (cp *tendermintRpcChainProxy) SendURI(ctx context.Context, nodeMessage *rpc // return an error if the channel is not nil return nil, "", nil, utils.LavaFormatError("Subscribe is not allowed on Tendermint URI", nil) } + if cp.httpClient == nil { cp.httpClient = &http.Client{ Timeout: 5 * time.Minute, // we are doing a timeout by request } } + httpClient := cp.httpClient // appending hashed url - grpc.SetTrailer(ctx, metadata.Pairs(RPCProviderNodeAddressHash, cp.httpConnector.GetUrlHash())) + internalPath := chainMessage.GetApiCollection().GetCollectionData().InternalPath + grpc.SetTrailer(ctx, metadata.Pairs(RPCProviderNodeAddressHash, cp.conn[internalPath].GetUrlHash())) // construct the url by concatenating the node url with the path variable - url := cp.httpNodeUrl.Url + "/" + nodeMessage.Path + url := cp.NodeUrl.Url + "/" + nodeMessage.Path // set context with timeout connectCtx, cancel := 
cp.CapTimeoutForSend(ctx, chainMessage) defer cancel() // create a new http request - req, err := http.NewRequestWithContext(connectCtx, http.MethodGet, cp.httpNodeUrl.AuthConfig.AddAuthPath(url), nil) + req, err := http.NewRequestWithContext(connectCtx, http.MethodGet, cp.NodeUrl.AuthConfig.AddAuthPath(url), nil) if err != nil { // Validate if the error is related to the provider connection to the node or it is a valid error // in case the error is valid (e.g. bad input parameters) the error will return in the form of a valid error reply @@ -710,9 +652,9 @@ func (cp *tendermintRpcChainProxy) SendURI(ctx context.Context, nodeMessage *rpc } } - cp.httpNodeUrl.SetAuthHeaders(ctx, req.Header.Set) + cp.NodeUrl.SetAuthHeaders(ctx, req.Header.Set) - cp.httpNodeUrl.SetIpForwardingIfNecessary(ctx, req.Header.Set) + cp.NodeUrl.SetIpForwardingIfNecessary(ctx, req.Header.Set) // send the http request and get the response res, err := httpClient.Do(req) if res != nil { @@ -759,26 +701,19 @@ func (cp *tendermintRpcChainProxy) SendURI(ctx context.Context, nodeMessage *rpc func (cp *tendermintRpcChainProxy) SendRPC(ctx context.Context, nodeMessage *rpcInterfaceMessages.TendermintrpcMessage, ch chan interface{}, chainMessage ChainMessageForSend) (relayReply *RelayReplyWrapper, subscriptionID string, relayReplyServer *rpcclient.ClientSubscription, err error) { // Get rpc connection from the connection pool var rpc *rpcclient.Client - if ch != nil { - internalPath := chainMessage.GetApiCollection().CollectionData.InternalPath - rpc, err = cp.conn[internalPath].GetRpc(ctx, true) - if err != nil { - return nil, "", nil, err - } - // return the rpc connection to the websocket pool after the function completes - defer cp.conn[internalPath].ReturnRpc(rpc) - // appending hashed url - grpc.SetTrailer(ctx, metadata.Pairs(RPCProviderNodeAddressHash, cp.conn[internalPath].GetUrlHash())) - } else { - rpc, err = cp.httpConnector.GetRpc(ctx, true) - if err != nil { - return nil, "", nil, err - } - // return the rpc connection to the http pool after the function completes - defer cp.httpConnector.ReturnRpc(rpc) - // appending hashed url - grpc.SetTrailer(ctx, metadata.Pairs(RPCProviderNodeAddressHash, cp.httpConnector.GetUrlHash())) + internalPath := chainMessage.GetApiCollection().CollectionData.InternalPath + + connector := cp.conn[internalPath] + + rpc, err = connector.GetRpc(ctx, true) + if err != nil { + return nil, "", nil, err } + // return the rpc connection to the websocket pool after the function completes + defer connector.ReturnRpc(rpc) + + // appending hashed url + grpc.SetTrailer(ctx, metadata.Pairs(RPCProviderNodeAddressHash, connector.GetUrlHash())) // create variables for the rpc message and reply message var rpcMessage *rpcclient.JsonrpcMessage @@ -792,9 +727,17 @@ func (cp *tendermintRpcChainProxy) SendRPC(ctx context.Context, nodeMessage *rpc } } // If ch is not nil do subscription + var nodeErr error if ch != nil { // subscribe to the rpc call if the channel is not nil - sub, rpcMessage, err = rpc.Subscribe(context.Background(), nodeMessage.ID, nodeMessage.Method, ch, nodeMessage.Params) + utils.LavaFormatTrace("Sending subscription", + utils.LogAttr("chainID", cp.BaseChainProxy.ChainID), + utils.LogAttr("apiName", chainMessage.GetApi().Name), + utils.LogAttr("nodeMessage.ID", nodeMessage.ID), + utils.LogAttr("nodeMessage.Method", nodeMessage.Method), + utils.LogAttr("nodeMessage.Params", nodeMessage.Params), + ) + sub, rpcMessage, nodeErr = rpc.Subscribe(context.Background(), nodeMessage.ID, 
nodeMessage.Method, ch, nodeMessage.Params) } else { // set context with timeout connectCtx, cancel := cp.CapTimeoutForSend(ctx, chainMessage) @@ -802,7 +745,7 @@ func (cp *tendermintRpcChainProxy) SendRPC(ctx context.Context, nodeMessage *rpc cp.NodeUrl.SetIpForwardingIfNecessary(ctx, rpc.SetHeader) // perform the rpc call - rpcMessage, err = rpc.CallContext(connectCtx, nodeMessage.ID, nodeMessage.Method, nodeMessage.Params, false, nodeMessage.GetDisableErrorHandling()) + rpcMessage, nodeErr = rpc.CallContext(connectCtx, nodeMessage.ID, nodeMessage.Method, nodeMessage.Params, false, nodeMessage.GetDisableErrorHandling()) if err != nil { if common.StatusCodeError504.Is(err) || common.StatusCodeError429.Is(err) || common.StatusCodeErrorStrict.Is(err) { return nil, "", nil, utils.LavaFormatWarning("Received invalid status code", err, utils.Attribute{Key: "chainID", Value: cp.BaseChainProxy.ChainID}, utils.Attribute{Key: "apiName", Value: chainMessage.GetApi().Name}) @@ -817,33 +760,35 @@ func (cp *tendermintRpcChainProxy) SendRPC(ctx context.Context, nodeMessage *rpc var replyMsg *rpcInterfaceMessages.RPCResponse // the error check here would only wrap errors not from the rpc + + if nodeErr != nil { + utils.LavaFormatDebug("got error from node", utils.LogAttr("GUID", ctx), utils.LogAttr("nodeErr", nodeErr)) + return nil, "", nil, nodeErr + } + + replyMessage, err = rpcInterfaceMessages.ConvertTendermintMsg(rpcMessage) if err != nil { - utils.LavaFormatDebug("received an error from SendNodeMsg", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "error", Value: err}) - return nil, "", nil, err - } else { - replyMessage, err = rpcInterfaceMessages.ConvertTendermintMsg(rpcMessage) - if err != nil { - return nil, "", nil, utils.LavaFormatError("tendermingRPC error", err) - } - // if we didn't get a node error. - if replyMessage.Error == nil { - // validate result is valid - responseIsNilValidationError := ValidateNilResponse(string(replyMessage.Result)) - if responseIsNilValidationError != nil { - return nil, "", nil, responseIsNilValidationError - } - } - replyMsg = replyMessage + return nil, "", nil, utils.LavaFormatError("tendermintRPC error", err) + } - err := cp.ValidateRequestAndResponseIds(nodeMessage.ID, rpcMessage.ID) - if err != nil { - return nil, "", nil, utils.LavaFormatError("tendermintRPC ID mismatch error", err, - utils.Attribute{Key: "GUID", Value: ctx}, - utils.Attribute{Key: "requestId", Value: nodeMessage.ID}, - utils.Attribute{Key: "responseId", Value: rpcMessage.ID}, - ) + // if we didn't get a node error. 
+ if replyMessage.Error == nil { + // validate result is valid + responseIsNilValidationError := ValidateNilResponse(string(replyMessage.Result)) + if responseIsNilValidationError != nil { + return nil, "", nil, responseIsNilValidationError } } + replyMsg = replyMessage + + err = cp.ValidateRequestAndResponseIds(nodeMessage.ID, rpcMessage.ID) + if err != nil { + return nil, "", nil, utils.LavaFormatError("tendermintRPC ID mismatch error", err, + utils.Attribute{Key: "GUID", Value: ctx}, + utils.Attribute{Key: "requestId", Value: nodeMessage.ID}, + utils.Attribute{Key: "responseId", Value: rpcMessage.ID}, + ) + } // marshal the jsonrpc message to json data, err := json.Marshal(replyMsg) @@ -869,7 +814,9 @@ func (cp *tendermintRpcChainProxy) SendRPC(ctx context.Context, nodeMessage *rpc } subscriptionID, ok = paramsMap["query"].(string) if !ok { - return nil, "", nil, utils.LavaFormatError("unknown subscriptionID type on tendermint subscribe", nil) + utils.LavaFormatTrace("could not get subscriptionID from query params", utils.LogAttr("params", params)) + // This is probably because of a misuse, therefore the provider will return a node error to the user as the subscription failed + subscriptionID = "" } } diff --git a/protocol/chainlib/tendermintRPC_test.go b/protocol/chainlib/tendermintRPC_test.go index 0dc1f60744..a40de882a8 100644 --- a/protocol/chainlib/tendermintRPC_test.go +++ b/protocol/chainlib/tendermintRPC_test.go @@ -149,7 +149,7 @@ func TestTendermintRpcChainProxy(t *testing.T) { }`) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceTendermintRPC, serverHandle, "../../", nil) + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceTendermintRPC, serverHandle, nil, "../../", nil) require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) @@ -180,7 +180,7 @@ func TestTendermintRpcBatchCall(t *testing.T) { fmt.Fprint(w, response) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceTendermintRPC, serverHandle, "../../", nil) + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceTendermintRPC, serverHandle, nil, "../../", nil) require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) @@ -222,7 +222,7 @@ func TestTendermintRpcBatchCallWithSameID(t *testing.T) { fmt.Fprint(w, nodeResponse) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceTendermintRPC, serverHandle, "../../", nil) + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceTendermintRPC, serverHandle, nil, "../../", nil) require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) @@ -256,7 +256,7 @@ func TestTendermintURIRPC(t *testing.T) { }`) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceTendermintRPC, serverHandle, "../../", nil) + chainParser, chainProxy, chainFetcher, closeServer, _, err := CreateChainLibMocks(ctx, "LAV1", spectypes.APIInterfaceTendermintRPC, serverHandle, nil, "../../", nil) require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainProxy) diff --git a/protocol/chaintracker/chain_tracker.go b/protocol/chaintracker/chain_tracker.go 
index d8a4d74cf7..5ce9c0ff1d 100644 --- a/protocol/chaintracker/chain_tracker.go +++ b/protocol/chaintracker/chain_tracker.go @@ -3,19 +3,16 @@ package chaintracker import ( "context" "errors" - fmt "fmt" "net" "net/http" "os" "os/signal" - "strconv" "sync" "sync/atomic" "time" rand "github.com/lavanet/lava/v2/utils/rand" - sdkerrors "cosmossdk.io/errors" "github.com/improbable-eng/grpc-web/go/grpcweb" "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/protocol/lavasession" @@ -90,10 +87,14 @@ func (cs *ChainTracker) GetLatestBlockData(fromBlock, toBlock, specificBlock int wantedBlocksData := WantedBlocksData{} err = wantedBlocksData.New(fromBlock, toBlock, specificBlock, latestBlock, earliestBlockSaved) if err != nil { - return latestBlock, nil, time.Time{}, sdkerrors.Wrap(err, fmt.Sprintf("invalid input for GetLatestBlockData %v", &map[string]string{ - "fromBlock": strconv.FormatInt(fromBlock, 10), "toBlock": strconv.FormatInt(toBlock, 10), "specificBlock": strconv.FormatInt(specificBlock, 10), - "latestBlock": strconv.FormatInt(latestBlock, 10), "earliestBlockSaved": strconv.FormatInt(earliestBlockSaved, 10), - })) + return latestBlock, nil, time.Time{}, utils.LavaFormatDebug("invalid input for GetLatestBlockData", + utils.LogAttr("err", err), + utils.LogAttr("fromBlock", fromBlock), + utils.LogAttr("toBlock", toBlock), + utils.LogAttr("specificBlock", specificBlock), + utils.LogAttr("latestBlock", latestBlock), + utils.LogAttr("earliestBlockSaved", earliestBlockSaved), + ) } for _, blocksQueueIdx := range wantedBlocksData.IterationIndexes() { @@ -104,8 +105,8 @@ func (cs *ChainTracker) GetLatestBlockData(fromBlock, toBlock, specificBlock int } requestedHashes = append(requestedHashes, &blockStore) } - changeTime = cs.latestChangeTime - return + + return latestBlock, requestedHashes, cs.latestChangeTime, nil } func (cs *ChainTracker) RegisterForBlockTimeUpdates(updatable blockTimeUpdatable) { diff --git a/protocol/chaintracker/errors.go b/protocol/chaintracker/errors.go index 034d7cc30c..2ece9ec7b9 100644 --- a/protocol/chaintracker/errors.go +++ b/protocol/chaintracker/errors.go @@ -10,7 +10,7 @@ var ( // Consumer Side Errors InvalidLatestBlockNumValue = sdkerrors.New("Invalid value for latestBlockNum", 10703, "returned latest block num should be greater than 0, but it's not") InvalidReturnedHashes = sdkerrors.New("Invalid value for requestedHashes length", 10704, "returned requestedHashes key count should be greater than 0, but it's not") ErrorFailedToFetchLatestBlock = sdkerrors.New("Error FailedToFetchLatestBlock", 10705, "Failed to fetch latest block from node") - InvalidRequestedBlocks = sdkerrors.New("Error InvalidRequestedBlocks", 10706, "provided requested blocks for function do not compse a valid request") + InvalidRequestedBlocks = sdkerrors.New("Error InvalidRequestedBlocks", 10706, "provided requested blocks for function do not compose a valid request") RequestedBlocksOutOfRange = sdkerrors.New("RequestedBlocksOutOfRange", 10707, "requested blocks are outside the supported range by the state tracker") ErrorFailedToFetchTooEarlyBlock = sdkerrors.New("Error ErrorFailedToFetchTooEarlyBlock", 10708, "server memory protection triggered, requested block is too early") InvalidRequestedSpecificBlock = sdkerrors.New("Error InvalidRequestedSpecificBlock", 10709, "provided requested specific blocks for function do not compose a stored entry") diff --git a/protocol/common/endpoints.go b/protocol/common/endpoints.go index 8c01e5a65a..8ab6fb03ff 100644 --- 
a/protocol/common/endpoints.go +++ b/protocol/common/endpoints.go @@ -26,6 +26,7 @@ const ( GUID_HEADER_NAME = "Lava-Guid" ERRORED_PROVIDERS_HEADER_NAME = "Lava-Errored-Providers" REPORTED_PROVIDERS_HEADER_NAME = "Lava-Reported-Providers" + LAVA_CONSUMER_PROCESS_GUID = "lava-consumer-process-guid" // these headers need to be lowercase BLOCK_PROVIDERS_ADDRESSES_HEADER_NAME = "lava-providers-block" RELAY_TIMEOUT_HEADER_NAME = "lava-relay-timeout" @@ -63,7 +64,7 @@ func (nurl NodeUrl) String() string { urlStr := nurl.UrlStr() if len(nurl.Addons) > 0 { - return urlStr + "(" + strings.Join(nurl.Addons, ",") + ")" + return urlStr + ", addons: (" + strings.Join(nurl.Addons, ",") + ")" } return urlStr } @@ -165,18 +166,43 @@ func (ac *AuthConfig) AddAuthPath(url string) string { func ValidateEndpoint(endpoint, apiInterface string) error { switch apiInterface { - case spectypes.APIInterfaceJsonRPC, spectypes.APIInterfaceTendermintRPC, spectypes.APIInterfaceRest: + case spectypes.APIInterfaceRest: parsedUrl, err := url.Parse(endpoint) if err != nil { - return utils.LavaFormatError("could not parse node url", err, utils.Attribute{Key: "url", Value: endpoint}, utils.Attribute{Key: "apiInterface", Value: apiInterface}) + return utils.LavaFormatError("could not parse node url", err, + utils.LogAttr("url", endpoint), + utils.LogAttr("apiInterface", apiInterface), + ) } + + switch parsedUrl.Scheme { + case "http", "https": + return nil + default: + return utils.LavaFormatError("URL scheme should be (http/https), got: "+parsedUrl.Scheme, nil, + utils.LogAttr("url", endpoint), + utils.LogAttr("apiInterface", apiInterface), + ) + } + case spectypes.APIInterfaceJsonRPC, spectypes.APIInterfaceTendermintRPC: + parsedUrl, err := url.Parse(endpoint) + if err != nil { + return utils.LavaFormatError("could not parse node url", err, + utils.LogAttr("url", endpoint), + utils.LogAttr("apiInterface", apiInterface), + ) + } + switch parsedUrl.Scheme { case "http", "https": return nil case "ws", "wss": return nil default: - return utils.LavaFormatError("URL scheme should be websocket (ws/wss) or (http/https), got: "+parsedUrl.Scheme, nil, utils.Attribute{Key: "apiInterface", Value: apiInterface}) + return utils.LavaFormatError("URL scheme should be websocket (ws/wss) or (http/https), got: "+parsedUrl.Scheme, nil, + utils.LogAttr("url", endpoint), + utils.LogAttr("apiInterface", apiInterface), + ) } case spectypes.APIInterfaceGrpc: if endpoint == "" { @@ -217,7 +243,7 @@ type RelayResult struct { Request *pairingtypes.RelayRequest Reply *pairingtypes.RelayReply ProviderInfo ProviderInfo - ReplyServer *pairingtypes.Relayer_RelaySubscribeClient + ReplyServer pairingtypes.Relayer_RelaySubscribeClient Finalized bool ConflictHandler ConflictHandlerInterface StatusCode int @@ -225,7 +251,7 @@ type RelayResult struct { ProviderTrailer metadata.MD // the provider trailer attached to the request. used to transfer useful information (which is not signed so shouldn't be trusted completely). 
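The ValidateEndpoint rework above splits REST (http/https only) from jsonrpc/tendermintrpc (http/https or ws/wss). Illustrative calls, a sketch rather than tests from this PR; the hosts and ports are example values and only the function and spectypes constants shown in these hunks are assumed:

func endpointSchemeExamples() {
	// REST endpoints now accept only http/https; websocket URLs are rejected.
	_ = ValidateEndpoint("ws://127.0.0.1:1317", spectypes.APIInterfaceRest) // returns a scheme error
	// jsonrpc / tendermintrpc endpoints still accept both transports.
	_ = ValidateEndpoint("wss://node.example:26657/websocket", spectypes.APIInterfaceTendermintRPC) // returns nil
	_ = ValidateEndpoint("ftp://node.example:8545", spectypes.APIInterfaceJsonRPC) // returns a scheme error
}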
} -func (rr *RelayResult) GetReplyServer() *pairingtypes.Relayer_RelaySubscribeClient { +func (rr *RelayResult) GetReplyServer() pairingtypes.Relayer_RelaySubscribeClient { if rr == nil { return nil } diff --git a/protocol/common/errors.go b/protocol/common/errors.go index 77a88cb6bd..e67811e2d0 100644 --- a/protocol/common/errors.go +++ b/protocol/common/errors.go @@ -8,4 +8,5 @@ var ( StatusCodeError429 = sdkerrors.New("Disallowed StatusCode Error", 429, "Disallowed status code error") StatusCodeErrorStrict = sdkerrors.New("Disallowed StatusCode Error", 800, "Disallowed status code error") APINotSupportedError = sdkerrors.New("APINotSupported Error", 900, "api not supported") + SubscriptionNotFoundError = sdkerrors.New("SubscriptionNotFoundError Error", 901, "subscription not found") ) diff --git a/protocol/common/return_errors.go b/protocol/common/return_errors.go index f0c124ad73..5394ba1f3d 100644 --- a/protocol/common/return_errors.go +++ b/protocol/common/return_errors.go @@ -9,6 +9,7 @@ import "github.com/gofiber/fiber/v2" type JsonRPCError struct { Code int `json:"code"` Message string `json:"message"` + Data string `json:"data"` } type JsonRPCErrorMessage struct { @@ -26,6 +27,16 @@ var JsonRpcMethodNotFoundError = JsonRPCErrorMessage{ }, } +var JsonRpcSubscriptionNotFoundError = JsonRPCErrorMessage{ + JsonRPC: "2.0", + Id: 1, + Error: JsonRPCError{ + Code: -32603, + Message: "Internal error", + Data: "subscription not found", + }, +} + // ####### // Rest // ####### diff --git a/protocol/common/safe_channel_sender.go b/protocol/common/safe_channel_sender.go new file mode 100644 index 0000000000..c8dcb3ea49 --- /dev/null +++ b/protocol/common/safe_channel_sender.go @@ -0,0 +1,99 @@ +package common + +import ( + "context" + "sync" + "time" + + "github.com/lavanet/lava/v2/utils" +) + +const retryAttemptsForChannelWrite = 10 + +type SafeChannelSender[T any] struct { + ctx context.Context + cancelCtx context.CancelFunc + ch chan<- T + closed bool + lock sync.Mutex +} + +func NewSafeChannelSender[T any](ctx context.Context, ch chan<- T) *SafeChannelSender[T] { + ctx, cancel := context.WithCancel(ctx) + return &SafeChannelSender[T]{ + ctx: ctx, + cancelCtx: cancel, + ch: ch, + closed: false, + lock: sync.Mutex{}, + } +} + +func (scs *SafeChannelSender[T]) sendInner(msg T) { + if scs.closed { + utils.LavaFormatTrace("Attempted to send message to closed channel") + return + } + + shouldBreak := false + for retry := 0; retry < retryAttemptsForChannelWrite; retry++ { + select { + case <-scs.ctx.Done(): + // trying to write to the channel, if the channel is not ready this will fail and retry again up to retryAttemptsForChannelWrite times + case scs.ch <- msg: + shouldBreak = true + default: + utils.LavaFormatTrace("Failed to send message to channel", utils.LogAttr("attempt", retry)) + } + if shouldBreak { + break + } + time.Sleep(time.Millisecond) // wait 1 millisecond between each attempt to write to the channel + } +} + +// Used when there is a need to validate locked, but you don't want to wait for the channel +// to return. +func (scs *SafeChannelSender[T]) LockAndSendAsynchronously(msg T) { + scs.lock.Lock() + go func() { + defer scs.lock.Unlock() + scs.sendInner(msg) + }() +} + +// Used when you need to wait for the other side to receive the message. 
+func (scs *SafeChannelSender[T]) Send(msg T) { + scs.lock.Lock() + defer scs.lock.Unlock() + scs.sendInner(msg) +} + +func (scs *SafeChannelSender[T]) ReplaceChannel(ch chan<- T) { + scs.lock.Lock() + defer scs.lock.Unlock() + + if scs.closed { + return + } + + // check whether the incoming channel is different from the one we currently have. + // this helps us avoid closing our channel and holding a closed channel, which would cause Close to panic. + if scs.ch != ch { + close(scs.ch) + scs.ch = ch + } +} + +func (scs *SafeChannelSender[T]) Close() { + scs.lock.Lock() + defer scs.lock.Unlock() + + if scs.closed { + return + } + + scs.cancelCtx() + close(scs.ch) + scs.closed = true +} diff --git a/protocol/common/safe_channel_sender_test.go b/protocol/common/safe_channel_sender_test.go new file mode 100644 index 0000000000..eb05f5316a --- /dev/null +++ b/protocol/common/safe_channel_sender_test.go @@ -0,0 +1,41 @@ +package common + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSafeChannelSender(t *testing.T) { + t.Run("Send message", func(t *testing.T) { + ctx := context.Background() + ch := make(chan int) + sender := NewSafeChannelSender(ctx, ch) + + msg := 42 + chEnd := make(chan bool) + go func() { + defer func() { chEnd <- true }() + select { + case received, ok := <-ch: + fmt.Println("got message from channel", ok, received) + require.True(t, ok) + require.Equal(t, msg, received) + return + case <-time.After(time.Second * 20): + require.Fail(t, "Expected message to be sent, but channel is empty") + } + }() + + // wait for the routine to listen to the channel + <-time.After(time.Second * 1) + sender.Send(msg) + sender.Close() + require.True(t, sender.closed) + // wait for the test to end + <-chEnd + }) +} diff --git a/protocol/common/strings.go b/protocol/common/strings.go new file mode 100644 index 0000000000..831e47943a --- /dev/null +++ b/protocol/common/strings.go @@ -0,0 +1,18 @@ +package common + +import "strings" + +func IsQuoted(s string) bool { + return strings.HasPrefix(s, "\"") && strings.HasSuffix(s, "\"") +} + +func IsSquareBracketed(s string) bool { + return strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") +} + +func UnSquareBracket(s string) string { + if IsSquareBracketed(s) { + return s[1 : len(s)-1] + } + return s +} diff --git a/protocol/common/timeout.go b/protocol/common/timeout.go index 464937deca..af1137ffa7 100644 --- a/protocol/common/timeout.go +++ b/protocol/common/timeout.go @@ -20,6 +20,11 @@ const ( DefaultTimeoutLongIsh = 1 * time.Minute DefaultTimeoutLong = 3 * time.Minute CacheTimeout = 50 * time.Millisecond + // On subscriptions we must use context.Background(), + // we can't have a context.WithTimeout() context, meaning we can hang forever. + // To avoid that, we introduced a first reply timeout using a routine.
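A short usage sketch for the SafeChannelSender above (hypothetical snippet; upstreamReplies is a stand-in for a provider subscription stream, and only the API defined in this file is assumed):

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
upstreamReplies := make(chan string) // stand-in for the provider's subscription stream
repliesChan := make(chan string)
sender := NewSafeChannelSender(ctx, repliesChan)
go func() {
	for reply := range upstreamReplies {
		sender.Send(reply) // retries up to retryAttemptsForChannelWrite times, 1ms apart, then drops the message
	}
	sender.Close() // cancels the inner context and closes repliesChan exactly once
}()

LockAndSendAsynchronously takes the same write path but returns immediately, holding the lock inside a goroutine until the write lands, while ReplaceChannel lets a reconnecting consumer swap in a fresh channel without racing Close.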
+ // if the first reply doesn't return after the specified timeout a timeout error will occur + SubscriptionFirstReplyTimeout = 10 * time.Second ) func LocalNodeTimePerCu(cu uint64) time.Duration { diff --git a/protocol/integration/mocks.go b/protocol/integration/mocks.go index 9f150f707d..e5bd8d22fe 100644 --- a/protocol/integration/mocks.go +++ b/protocol/integration/mocks.go @@ -11,7 +11,7 @@ import ( "github.com/lavanet/lava/v2/protocol/chaintracker" "github.com/lavanet/lava/v2/protocol/common" - "github.com/lavanet/lava/v2/protocol/lavaprotocol" + "github.com/lavanet/lava/v2/protocol/lavaprotocol/finalizationconsensus" "github.com/lavanet/lava/v2/protocol/lavasession" "github.com/lavanet/lava/v2/protocol/rpcprovider/reliabilitymanager" "github.com/lavanet/lava/v2/protocol/statetracker/updaters" @@ -34,7 +34,7 @@ func (m *mockConsumerStateTracker) RegisterForSpecUpdates(ctx context.Context, s return nil } -func (m *mockConsumerStateTracker) RegisterFinalizationConsensusForUpdates(context.Context, *lavaprotocol.FinalizationConsensus) { +func (m *mockConsumerStateTracker) RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus) { } func (m *mockConsumerStateTracker) RegisterForDowntimeParamsUpdates(ctx context.Context, downtimeParamsUpdatable updaters.DowntimeParamsUpdatable) error { diff --git a/protocol/integration/protocol_test.go b/protocol/integration/protocol_test.go index bcecc67b94..1fc35efda4 100644 --- a/protocol/integration/protocol_test.go +++ b/protocol/integration/protocol_test.go @@ -7,15 +7,17 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "testing" "time" + "github.com/gorilla/websocket" "github.com/lavanet/lava/v2/protocol/chainlib" "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" "github.com/lavanet/lava/v2/protocol/chaintracker" "github.com/lavanet/lava/v2/protocol/common" - "github.com/lavanet/lava/v2/protocol/lavaprotocol" + "github.com/lavanet/lava/v2/protocol/lavaprotocol/finalizationconsensus" "github.com/lavanet/lava/v2/protocol/lavasession" "github.com/lavanet/lava/v2/protocol/metrics" "github.com/lavanet/lava/v2/protocol/provideroptimizer" @@ -92,12 +94,28 @@ func checkGrpcServerStatusWithTimeout(url string, totalTimeout time.Duration) bo return false } -func isServerUp(url string) bool { +func isServerUp(urlPath string) bool { + u, err := url.Parse(urlPath) + if err != nil { + panic(err) + } + + switch { + case u.Scheme == "http": + return isHttpServerUp(urlPath) + case u.Scheme == "ws": + return isWebsocketServerUp(urlPath) + default: + panic("unsupported scheme") + } +} + +func isHttpServerUp(urlPath string) bool { client := http.Client{ Timeout: 20 * time.Millisecond, } - resp, err := client.Get(url) + resp, err := client.Get(urlPath) if err != nil { return false } @@ -107,6 +125,15 @@ func isServerUp(url string) bool { return resp.ContentLength > 0 } +func isWebsocketServerUp(urlPath string) bool { + client, _, err := websocket.DefaultDialer.Dial(urlPath, nil) + if err != nil { + return false + } + client.Close() + return true +} + func checkServerStatusWithTimeout(url string, totalTimeout time.Duration) bool { startTime := time.Now() @@ -137,7 +164,7 @@ func createRpcConsumer(t *testing.T, ctx context.Context, specId string, apiInte // Handle the incoming request and provide the desired response w.WriteHeader(http.StatusOK) }) - chainParser, _, chainFetcher, _, _, err := chainlib.CreateChainLibMocks(ctx, specId, apiInterface, serverHandler, "../../", nil) + chainParser, _, 
chainFetcher, _, _, err := chainlib.CreateChainLibMocks(ctx, specId, apiInterface, serverHandler, nil, "../../", nil) require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainFetcher) @@ -152,22 +179,26 @@ func createRpcConsumer(t *testing.T, ctx context.Context, specId string, apiInte Geolocation: 1, } consumerStateTracker := &mockConsumerStateTracker{} - finalizationConsensus := lavaprotocol.NewFinalizationConsensus(rpcEndpoint.ChainID) + finalizationConsensus := finalizationconsensus.NewFinalizationConsensus(rpcEndpoint.ChainID) _, averageBlockTime, _, _ := chainParser.ChainBlockStats() baseLatency := common.AverageWorldLatency / 2 optimizer := provideroptimizer.NewProviderOptimizer(provideroptimizer.STRATEGY_BALANCED, averageBlockTime, baseLatency, 2) - consumerSessionManager := lavasession.NewConsumerSessionManager(rpcEndpoint, optimizer, nil, nil, "test") + consumerSessionManager := lavasession.NewConsumerSessionManager(rpcEndpoint, optimizer, nil, nil, "test", lavasession.NewActiveSubscriptionProvidersStorage()) consumerSessionManager.UpdateAllProviders(epoch, pairingList) consumerConsistency := rpcconsumer.NewConsumerConsistency(specId) consumerCmdFlags := common.ConsumerCmdFlags{} rpcsonumerLogs, err := metrics.NewRPCConsumerLogs(nil, nil) require.NoError(t, err) - err = rpcConsumerServer.ServeRPCRequests(ctx, rpcEndpoint, consumerStateTracker, chainParser, finalizationConsensus, consumerSessionManager, requiredResponses, account.SK, lavaChainID, nil, rpcsonumerLogs, account.Addr, consumerConsistency, nil, consumerCmdFlags, false, nil, nil) + err = rpcConsumerServer.ServeRPCRequests(ctx, rpcEndpoint, consumerStateTracker, chainParser, finalizationConsensus, consumerSessionManager, requiredResponses, account.SK, lavaChainID, nil, rpcsonumerLogs, account.Addr, consumerConsistency, nil, consumerCmdFlags, false, nil, nil, nil) require.NoError(t, err) // wait for consumer server to be up consumerUp := checkServerStatusWithTimeout("http://"+consumerListenAddress, time.Millisecond*61) require.True(t, consumerUp) + if rpcEndpoint.ApiInterface == "tendermintrpc" || rpcEndpoint.ApiInterface == "jsonrpc" { + consumerUp = checkServerStatusWithTimeout("ws://"+consumerListenAddress+"/ws", time.Millisecond*61) + require.True(t, consumerUp) + } return rpcConsumerServer } @@ -192,7 +223,7 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string fmt.Fprint(w, string(data)) }) - chainParser, chainRouter, chainFetcher, _, endpoint, err := chainlib.CreateChainLibMocks(ctx, specId, apiInterface, serverHandler, "../../", addons) + chainParser, chainRouter, chainFetcher, _, endpoint, err := chainlib.CreateChainLibMocks(ctx, specId, apiInterface, serverHandler, nil, "../../", addons) require.NoError(t, err) require.NotNil(t, chainParser) require.NotNil(t, chainFetcher) @@ -248,7 +279,7 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string chainTracker, err := chaintracker.NewChainTracker(ctx, mockChainFetcher, chainTrackerConfig) require.NoError(t, err) reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker, &mockProviderStateTracker, account.Addr.String(), chainRouter, chainParser) - rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, reliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil) + rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, 
chainParser, rws, providerSessionManager, reliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil) listener := rpcprovider.NewProviderListener(ctx, rpcProviderEndpoint.NetworkAddress, "/health") err = listener.RegisterReceiver(rpcProviderServer, rpcProviderEndpoint) require.NoError(t, err) @@ -665,3 +696,114 @@ func TestConsumerProviderJsonRpcWithNullID(t *testing.T) { }) } } + +func TestConsumerProviderSubscriptionsHappyFlow(t *testing.T) { + playbook := []struct { + name string + specId string + method string + expected string + apiInterface string + }{ + { + name: "jsonrpc", + specId: "ETH1", + method: "eth_blockNumber", + expected: `{"jsonrpc":"2.0","id":null,"result":{}}`, + apiInterface: spectypes.APIInterfaceJsonRPC, + }, + { + name: "tendermintrpc", + specId: "LAV1", + method: "status", + expected: `{"jsonrpc":"2.0","result":{}}`, + apiInterface: spectypes.APIInterfaceTendermintRPC, + }, + } + for _, play := range playbook { + t.Run(play.name, func(t *testing.T) { + ctx := context.Background() + // can be any spec and api interface + specId := play.specId + apiInterface := play.apiInterface + epoch := uint64(100) + requiredResponses := 1 + lavaChainID := "lava" + numProviders := 5 + + consumerListenAddress := addressGen.GetAddress() + pairingList := map[uint64]*lavasession.ConsumerSessionsWithProvider{} + type providerData struct { + account sigs.Account + endpoint *lavasession.RPCProviderEndpoint + server *rpcprovider.RPCProviderServer + replySetter *ReplySetter + mockChainFetcher *MockChainFetcher + } + providers := []providerData{} + + for i := 0; i < numProviders; i++ { + // providerListenAddress := "localhost:111" + strconv.Itoa(i) + account := sigs.GenerateDeterministicFloatingKey(randomizer) + providerDataI := providerData{account: account} + providers = append(providers, providerDataI) + } + consumerAccount := sigs.GenerateDeterministicFloatingKey(randomizer) + for i := 0; i < numProviders; i++ { + ctx := context.Background() + providerDataI := providers[i] + listenAddress := addressGen.GetAddress() + providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil)) + providers[i].replySetter.replyDataBuf = []byte(fmt.Sprintf(`{"result": %d}`, i+1)) + } + for i := 0; i < numProviders; i++ { + pairingList[uint64(i)] = &lavasession.ConsumerSessionsWithProvider{ + PublicLavaAddress: providers[i].account.Addr.String(), + Endpoints: []*lavasession.Endpoint{ + { + NetworkAddress: providers[i].endpoint.NetworkAddress.Address, + Enabled: true, + Geolocation: 1, + }, + }, + Sessions: map[int64]*lavasession.SingleConsumerSession{}, + MaxComputeUnits: 10000, + UsedComputeUnits: 0, + PairingEpoch: epoch, + } + } + rpcconsumerServer := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID) + require.NotNil(t, rpcconsumerServer) + + for i := 0; i < numProviders; i++ { + handler := func(req []byte, header http.Header) (data []byte, status int) { + var jsonRpcMessage rpcInterfaceMessages.JsonrpcMessage + err := json.Unmarshal(req, &jsonRpcMessage) + require.NoError(t, err) + + response := fmt.Sprintf(`{"jsonrpc":"2.0","result": {}, "id": %v}`, string(jsonRpcMessage.ID)) + return []byte(response), http.StatusOK + } + 
providers[i].replySetter.handler = handler + } + + client := http.Client{Timeout: 500 * time.Millisecond} + jsonMsg := fmt.Sprintf(`{"jsonrpc":"2.0","method":"%v","params": [], "id":null}`, play.method) + msgBuffer := bytes.NewBuffer([]byte(jsonMsg)) + req, err := http.NewRequest(http.MethodPost, "http://"+consumerListenAddress, msgBuffer) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + require.NoError(t, err) + + bodyBytes, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode, string(bodyBytes)) + + resp.Body.Close() + + require.Equal(t, play.expected, string(bodyBytes)) + }) + } +} diff --git a/protocol/lavaprotocol/errors.go b/protocol/lavaprotocol/errors.go index aedb32ec0e..70c662f9bf 100644 --- a/protocol/lavaprotocol/errors.go +++ b/protocol/lavaprotocol/errors.go @@ -5,10 +5,10 @@ import ( ) var ( - ProviderFinzalizationDataError = sdkerrors.New("ProviderFinzalizationData Error", 3365, "provider did not sign finalization data correctly") - ProviderFinzalizationDataAccountabilityError = sdkerrors.New("ProviderFinzalizationDataAccountability Error", 3366, "provider returned invalid finalization data, with accountability") - HashesConsunsusError = sdkerrors.New("HashesConsunsus Error", 3367, "identified finalized responses with conflicting hashes, from two providers") - ConsistencyError = sdkerrors.New("Consistency Error", 3368, "does not meet consistency requirements") - UnhandledRelayReceiverError = sdkerrors.New("UnhandledRelayReceiver Error", 3369, "provider does not handle requested api interface and spec") - DisabledRelayReceiverError = sdkerrors.New("DisabledRelayReceiverError Error", 3370, "provider does not pass verification and disabled this interface and spec") + ProviderFinalizationDataError = sdkerrors.New("ProviderFinalizationData Error", 3365, "provider did not sign finalization data correctly") + ProviderFinalizationDataAccountabilityError = sdkerrors.New("ProviderFinalizationDataAccountability Error", 3366, "provider returned invalid finalization data, with accountability") + HashesConsensusError = sdkerrors.New("HashesConsensus Error", 3367, "identified finalized responses with conflicting hashes, from two providers") + ConsistencyError = sdkerrors.New("Consistency Error", 3368, "does not meet consistency requirements") + UnhandledRelayReceiverError = sdkerrors.New("UnhandledRelayReceiver Error", 3369, "provider does not handle requested api interface and spec") + DisabledRelayReceiverError = sdkerrors.New("DisabledRelayReceiverError Error", 3370, "provider does not pass verification and disabled this interface and spec") ) diff --git a/protocol/lavaprotocol/finalization_consensus.go b/protocol/lavaprotocol/finalizationconsensus/finalization_consensus.go similarity index 96% rename from protocol/lavaprotocol/finalization_consensus.go rename to protocol/lavaprotocol/finalizationconsensus/finalization_consensus.go index 396168ef75..eb9e07e0ed 100644 --- a/protocol/lavaprotocol/finalization_consensus.go +++ b/protocol/lavaprotocol/finalizationconsensus/finalization_consensus.go @@ -1,4 +1,4 @@ -package lavaprotocol +package finalizationconsensus import ( "fmt" @@ -8,12 +8,17 @@ import ( "time" "github.com/lavanet/lava/v2/protocol/chainlib" + "github.com/lavanet/lava/v2/protocol/lavaprotocol" "github.com/lavanet/lava/v2/utils" "github.com/lavanet/lava/v2/utils/lavaslices" conflicttypes "github.com/lavanet/lava/v2/x/conflict/types" pairingtypes 
"github.com/lavanet/lava/v2/x/pairing/types" ) +const ( + debug = false +) + type FinalizationConsensus struct { currentProviderHashesConsensus []ProviderHashesConsensus prevEpochProviderHashesConsensus []ProviderHashesConsensus @@ -171,7 +176,7 @@ func (fc *FinalizationConsensus) discrepancyChecker(finalizedBlocksA map[int64]s if otherHash, ok := otherBlocks[blockNum]; ok { if blockHash != otherHash { // TODO: gather discrepancy data - return utils.LavaFormatError("Simulation: reliability discrepancy, different hashes detected for block", HashesConsunsusError, utils.Attribute{Key: "blockNum", Value: blockNum}, utils.Attribute{Key: "Hashes", Value: fmt.Sprintf("%s vs %s", blockHash, otherHash)}, utils.Attribute{Key: "toIterate", Value: toIterate}, utils.Attribute{Key: "otherBlocks", Value: otherBlocks}) + return utils.LavaFormatError("Simulation: reliability discrepancy, different hashes detected for block", lavaprotocol.HashesConsensusError, utils.Attribute{Key: "blockNum", Value: blockNum}, utils.Attribute{Key: "Hashes", Value: fmt.Sprintf("%s vs %s", blockHash, otherHash)}, utils.Attribute{Key: "toIterate", Value: toIterate}, utils.Attribute{Key: "otherBlocks", Value: otherBlocks}) } } } diff --git a/protocol/lavaprotocol/finalization_consensus_test.go b/protocol/lavaprotocol/finalizationconsensus/finalization_consensus_test.go similarity index 99% rename from protocol/lavaprotocol/finalization_consensus_test.go rename to protocol/lavaprotocol/finalizationconsensus/finalization_consensus_test.go index 38429a18e7..824c4ba41d 100644 --- a/protocol/lavaprotocol/finalization_consensus_test.go +++ b/protocol/lavaprotocol/finalizationconsensus/finalization_consensus_test.go @@ -1,4 +1,4 @@ -package lavaprotocol +package finalizationconsensus import ( "context" @@ -74,7 +74,7 @@ func TestConsensusHashesInsertion(t *testing.T) { chainsToTest := []string{"APT1", "LAV1", "ETH1"} for _, chainID := range chainsToTest { ctx := context.Background() - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, chainID, "0", func(http.ResponseWriter, *http.Request) {}, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, chainID, "0", func(http.ResponseWriter, *http.Request) {}, nil, "../../../", nil) if closeServer != nil { defer closeServer() } @@ -163,7 +163,7 @@ func TestQoS(t *testing.T) { for _, chainID := range chainsToTest { t.Run(chainID, func(t *testing.T) { ctx := context.Background() - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, chainID, "0", func(http.ResponseWriter, *http.Request) {}, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, chainID, "0", func(http.ResponseWriter, *http.Request) {}, nil, "../../../", nil) if closeServer != nil { defer closeServer() } diff --git a/protocol/lavaprotocol/response_builder.go b/protocol/lavaprotocol/response_builder.go index 1feffd19be..187d942623 100644 --- a/protocol/lavaprotocol/response_builder.go +++ b/protocol/lavaprotocol/response_builder.go @@ -8,6 +8,7 @@ import ( btcSecp256k1 "github.com/btcsuite/btcd/btcec/v2" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" "github.com/lavanet/lava/v2/utils" "github.com/lavanet/lava/v2/utils/lavaslices" "github.com/lavanet/lava/v2/utils/sigs" @@ -16,10 +17,45 @@ import ( spectypes "github.com/lavanet/lava/v2/x/spec/types" ) +func CraftEmptyRPCResponseFromGenericMessage(message 
rpcInterfaceMessages.GenericMessage) (*rpcInterfaceMessages.RPCResponse, error) { + createRPCResponse := func(rawId json.RawMessage) (*rpcInterfaceMessages.RPCResponse, error) { + jsonRpcId, err := rpcInterfaceMessages.IdFromRawMessage(rawId) + if err != nil { + return nil, utils.LavaFormatError("failed creating jsonrpc id", err) + } + + jsonResponse := &rpcInterfaceMessages.RPCResponse{ + JSONRPC: "2.0", + ID: jsonRpcId, + Result: nil, + Error: nil, + } + + return jsonResponse, nil + } + + var err error + var rpcResponse *rpcInterfaceMessages.RPCResponse + if hasID, ok := message.(interface{ GetID() json.RawMessage }); ok { + rpcResponse, err = createRPCResponse(hasID.GetID()) + if err != nil { + return nil, utils.LavaFormatError("failed creating jsonrpc id", err) + } + } else { + rpcResponse, err = createRPCResponse([]byte("1")) + if err != nil { + return nil, utils.LavaFormatError("failed creating jsonrpc id", err) + } + } + + return rpcResponse, nil +} + func SignRelayResponse(consumerAddress sdk.AccAddress, request pairingtypes.RelayRequest, pkey *btcSecp256k1.PrivateKey, reply *pairingtypes.RelayReply, signDataReliability bool) (*pairingtypes.RelayReply, error) { // request is a copy of the original request, but won't modify it // update relay request requestedBlock to the provided one in case it was arbitrary UpdateRequestedBlock(request.RelayData, reply) + // Update signature, relayExchange := pairingtypes.NewRelayExchange(request, *reply) sig, err := sigs.Sign(pkey, relayExchange) @@ -46,14 +82,20 @@ func VerifyRelayReply(ctx context.Context, reply *pairingtypes.RelayReply, relay relayExchange := pairingtypes.NewRelayExchange(*relayRequest, *reply) serverKey, err := sigs.RecoverPubKey(relayExchange) if err != nil { - return err + return utils.LavaFormatWarning("Relay reply verification failed, RecoverPubKey returned error", err, utils.LogAttr("GUID", ctx)) } serverAddr, err := sdk.AccAddressFromHexUnsafe(serverKey.Address().String()) if err != nil { - return err + return utils.LavaFormatWarning("Relay reply verification failed, AccAddressFromHexUnsafe returned error", err, utils.LogAttr("GUID", ctx)) } if serverAddr.String() != addr { - return utils.LavaFormatError("reply server address mismatch ", ProviderFinzalizationDataError, utils.LogAttr("GUID", ctx), utils.Attribute{Key: "parsed Address", Value: serverAddr.String()}, utils.Attribute{Key: "expected address", Value: addr}, utils.Attribute{Key: "requestedBlock", Value: relayRequest.RelayData.RequestBlock}, utils.Attribute{Key: "latestBlock", Value: reply.GetLatestBlock()}) + return utils.LavaFormatError("reply server address mismatch", ProviderFinalizationDataError, + utils.LogAttr("GUID", ctx), + utils.LogAttr("parsedAddress", serverAddr.String()), + utils.LogAttr("expectedAddress", addr), + utils.LogAttr("requestedBlock", relayRequest.RelayData.RequestBlock), + utils.LogAttr("latestBlock", reply.GetLatestBlock()), + ) } return nil @@ -63,22 +105,22 @@ func VerifyFinalizationData(reply *pairingtypes.RelayReply, relayRequest *pairin relayFinalization := pairingtypes.NewRelayFinalization(pairingtypes.NewRelayExchange(*relayRequest, *reply), consumerAcc) serverKey, err := sigs.RecoverPubKey(relayFinalization) if err != nil { - return nil, nil, err + return nil, nil, utils.LavaFormatWarning("Finalization data verification failed, RecoverPubKey returned error", err) } serverAddr, err := sdk.AccAddressFromHexUnsafe(serverKey.Address().String()) if err != nil { - return nil, nil, err + return nil, nil, 
utils.LavaFormatWarning("Finalization data verification failed, AccAddressFromHexUnsafe returned error", err) } if serverAddr.String() != providerAddr { - return nil, nil, utils.LavaFormatError("reply server address mismatch in finalization data ", ProviderFinzalizationDataError, utils.Attribute{Key: "parsed Address", Value: serverAddr.String()}, utils.Attribute{Key: "expected address", Value: providerAddr}) + return nil, nil, utils.LavaFormatError("reply server address mismatch in finalization data ", ProviderFinalizationDataError, utils.Attribute{Key: "parsed Address", Value: serverAddr.String()}, utils.Attribute{Key: "expected address", Value: providerAddr}) } finalizedBlocks = map[int64]string{} // TODO:: define struct in relay response err = json.Unmarshal(reply.FinalizedBlocksHashes, &finalizedBlocks) if err != nil { - return nil, nil, utils.LavaFormatError("failed in unmarshalling finalized blocks data", ProviderFinzalizationDataError, utils.Attribute{Key: "FinalizedBlocksHashes", Value: string(reply.FinalizedBlocksHashes)}, utils.Attribute{Key: "errMsg", Value: err.Error()}) + return nil, nil, utils.LavaFormatError("failed in unmarshalling finalized blocks data", ProviderFinalizationDataError, utils.Attribute{Key: "FinalizedBlocksHashes", Value: string(reply.FinalizedBlocksHashes)}, utils.Attribute{Key: "errMsg", Value: err.Error()}) } finalizationConflict, err = verifyFinalizationDataIntegrity(reply, latestSessionBlock, finalizedBlocks, blockDistanceForfinalization, providerAddr) @@ -89,7 +131,7 @@ func VerifyFinalizationData(reply *pairingtypes.RelayReply, relayRequest *pairin seenBlock := relayRequest.RelayData.SeenBlock requestBlock := relayRequest.RelayData.RequestBlock if providerLatestBlock < lavaslices.Min([]int64{seenBlock, requestBlock}) { - return nil, nil, utils.LavaFormatError("provider response does not meet consistency requirements", ProviderFinzalizationDataError, utils.LogAttr("ProviderAddress", relayRequest.RelaySession.Provider), utils.LogAttr("providerLatestBlock", providerLatestBlock), utils.LogAttr("seenBlock", seenBlock), utils.LogAttr("requestBlock", requestBlock), utils.Attribute{Key: "provider address", Value: providerAddr}) + return nil, nil, utils.LavaFormatError("provider response does not meet consistency requirements", ProviderFinalizationDataError, utils.LogAttr("ProviderAddress", relayRequest.RelaySession.Provider), utils.LogAttr("providerLatestBlock", providerLatestBlock), utils.LogAttr("seenBlock", seenBlock), utils.LogAttr("requestBlock", requestBlock), utils.Attribute{Key: "provider address", Value: providerAddr}) } return finalizedBlocks, finalizationConflict, errRet } @@ -104,7 +146,7 @@ func verifyFinalizationDataIntegrity(reply *pairingtypes.RelayReply, latestSessi for blockNum := range finalizedBlocks { if !spectypes.IsFinalizedBlock(blockNum, latestBlock, blockDistanceForfinalization) { finalizationConflict = &conflicttypes.FinalizationConflict{RelayReply0: reply} - return finalizationConflict, utils.LavaFormatError("Simulation: provider returned non finalized block reply for reliability", ProviderFinzalizationDataAccountabilityError, utils.Attribute{Key: "blockNum", Value: blockNum}, utils.Attribute{Key: "latestBlock", Value: latestBlock}, utils.Attribute{Key: "Provider", Value: providerAddr}, utils.Attribute{Key: "finalizedBlocks", Value: finalizedBlocks}) + return finalizationConflict, utils.LavaFormatError("Simulation: provider returned non finalized block reply for reliability", ProviderFinalizationDataAccountabilityError, 
utils.Attribute{Key: "blockNum", Value: blockNum}, utils.Attribute{Key: "latestBlock", Value: latestBlock}, utils.Attribute{Key: "Provider", Value: providerAddr}, utils.Attribute{Key: "finalizedBlocks", Value: finalizedBlocks}) } sorted[idx] = blockNum @@ -122,14 +164,14 @@ func verifyFinalizationDataIntegrity(reply *pairingtypes.RelayReply, latestSessi if index != 0 && sorted[index]-1 != sorted[index-1] { // log.Println("provider returned non consecutive finalized blocks reply.\n Provider: %s", providerAcc) finalizationConflict = &conflicttypes.FinalizationConflict{RelayReply0: reply} - return finalizationConflict, utils.LavaFormatError("Simulation: provider returned non consecutive finalized blocks reply", ProviderFinzalizationDataAccountabilityError, utils.Attribute{Key: "curr block", Value: sorted[index]}, utils.Attribute{Key: "prev block", Value: sorted[index-1]}, utils.Attribute{Key: "Provider", Value: providerAddr}, utils.Attribute{Key: "finalizedBlocks", Value: finalizedBlocks}) + return finalizationConflict, utils.LavaFormatError("Simulation: provider returned non consecutive finalized blocks reply", ProviderFinalizationDataAccountabilityError, utils.Attribute{Key: "curr block", Value: sorted[index]}, utils.Attribute{Key: "prev block", Value: sorted[index-1]}, utils.Attribute{Key: "Provider", Value: providerAddr}, utils.Attribute{Key: "finalizedBlocks", Value: finalizedBlocks}) } } // check that latest finalized block address + 1 points to a non finalized block if spectypes.IsFinalizedBlock(maxBlockNum+1, latestBlock, blockDistanceForfinalization) { finalizationConflict = &conflicttypes.FinalizationConflict{RelayReply0: reply} - return finalizationConflict, utils.LavaFormatError("Simulation: provider returned finalized hashes for an older latest block", ProviderFinzalizationDataAccountabilityError, + return finalizationConflict, utils.LavaFormatError("Simulation: provider returned finalized hashes for an older latest block", ProviderFinalizationDataAccountabilityError, utils.Attribute{Key: "maxBlockNum", Value: maxBlockNum}, utils.Attribute{Key: "latestBlock", Value: latestBlock}, utils.Attribute{Key: "Provider", Value: providerAddr}, utils.Attribute{Key: "finalizedBlocks", Value: finalizedBlocks}) } @@ -137,7 +179,7 @@ func verifyFinalizationDataIntegrity(reply *pairingtypes.RelayReply, latestSessi // New reply should have blocknum >= from block same provider if latestSessionBlock > latestBlock { finalizationConflict = &conflicttypes.FinalizationConflict{RelayReply0: reply} - return finalizationConflict, utils.LavaFormatError("Simulation: Provider supplied an older latest block than it has previously", ProviderFinzalizationDataAccountabilityError, + return finalizationConflict, utils.LavaFormatError("Simulation: Provider supplied an older latest block than it has previously", ProviderFinalizationDataAccountabilityError, utils.Attribute{Key: "session.LatestBlock", Value: latestSessionBlock}, utils.Attribute{Key: "latestBlock", Value: latestBlock}, utils.Attribute{Key: "Provider", Value: providerAddr}) } diff --git a/protocol/lavasession/active_subscription_provider_storage.go b/protocol/lavasession/active_subscription_provider_storage.go new file mode 100644 index 0000000000..65f0520cff --- /dev/null +++ b/protocol/lavasession/active_subscription_provider_storage.go @@ -0,0 +1,71 @@ +package lavasession + +import ( + "sync" + + "github.com/lavanet/lava/v2/utils" +) + +// stores all providers that are currently used to stream subscriptions. 
+ +type NumberOfActiveSubscriptions int64 + +type ActiveSubscriptionProvidersStorage struct { + lock sync.RWMutex + providers map[string]NumberOfActiveSubscriptions + purgeWhenDone map[string]func() +} + +func NewActiveSubscriptionProvidersStorage() *ActiveSubscriptionProvidersStorage { + return &ActiveSubscriptionProvidersStorage{ + providers: map[string]NumberOfActiveSubscriptions{}, + purgeWhenDone: map[string]func(){}, + } +} + +func (asps *ActiveSubscriptionProvidersStorage) AddProvider(providerAddress string) { + asps.lock.Lock() + defer asps.lock.Unlock() + numberOfSubscriptionsActive := asps.providers[providerAddress] + // Increase numberOfSubscriptionsActive by 1 (even if it didn't exist it will be 0) + asps.providers[providerAddress] = numberOfSubscriptionsActive + 1 +} + +func (asps *ActiveSubscriptionProvidersStorage) RemoveProvider(providerAddress string) { + asps.lock.Lock() + defer asps.lock.Unlock() + // Fetch number of currently active subscriptions for this provider address. + activeSubscriptions, foundProviderAddress := asps.providers[providerAddress] + if foundProviderAddress { + // Check that there are no other active subscriptions + if activeSubscriptions <= 1 { + delete(asps.providers, providerAddress) + purgeCallBack, foundPurgerCb := asps.purgeWhenDone[providerAddress] + if foundPurgerCb { + utils.LavaFormatTrace("RemoveProvider, Purging provider on callback", utils.LogAttr("address", providerAddress)) + if purgeCallBack != nil { + purgeCallBack() + } + delete(asps.purgeWhenDone, providerAddress) + } + } else { + // Reduce number of active subscriptions on this provider address + utils.LavaFormatTrace("RemoveProvider, Reducing number of active provider subscriptions", utils.LogAttr("address", providerAddress)) + asps.providers[providerAddress] = activeSubscriptions - 1 + } + } +} + +func (asps *ActiveSubscriptionProvidersStorage) IsProviderCurrentlyUsed(providerAddress string) bool { + asps.lock.RLock() + defer asps.lock.RUnlock() + + _, ok := asps.providers[providerAddress] + return ok +} + +func (asps *ActiveSubscriptionProvidersStorage) addToPurgeWhenDone(providerAddress string, purgeCallback func()) { + asps.lock.Lock() + defer asps.lock.Unlock() + asps.purgeWhenDone[providerAddress] = purgeCallback +} diff --git a/protocol/lavasession/active_subscription_provider_storage_test.go b/protocol/lavasession/active_subscription_provider_storage_test.go new file mode 100644 index 0000000000..6d7cf733e0 --- /dev/null +++ b/protocol/lavasession/active_subscription_provider_storage_test.go @@ -0,0 +1,65 @@ +package lavasession + +import ( + "strconv" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsProviderInActiveSubscription(t *testing.T) { + acps := NewActiveSubscriptionProvidersStorage() + + // Add a provider + providerAddress := "provider1" + acps.AddProvider(providerAddress) + + // Check that the provider is marked as having an active subscription + isActiveSubscription := acps.IsProviderCurrentlyUsed(providerAddress) + require.True(t, isActiveSubscription) + + // Remove the provider + acps.RemoveProvider(providerAddress) + + // Check that the provider is no longer marked as having an active subscription + isActiveSubscription = acps.IsProviderCurrentlyUsed(providerAddress) + require.False(t, isActiveSubscription) +} + +func TestConcurrentAccess(t *testing.T) { + acps := NewActiveSubscriptionProvidersStorage() + + // Add and remove providers concurrently + numProviders := 100 + var wg sync.WaitGroup + wg.Add(numProviders * 2) + for i := 0; i < numProviders; i++ { + providerSpecificWg :=
sync.WaitGroup{} + providerSpecificWg.Add(1) + go func(providerIndex int) { + defer func() { + wg.Done() + providerSpecificWg.Done() + }() + + providerAddress := "provider" + strconv.Itoa(providerIndex) + acps.AddProvider(providerAddress) + }(i) + + go func(providerIndex int) { + providerSpecificWg.Wait() + defer wg.Done() + providerAddress := "provider" + strconv.Itoa(providerIndex) + acps.RemoveProvider(providerAddress) + }(i) + } + wg.Wait() + + // Check if all providers were added and removed + for i := 0; i < numProviders; i++ { + providerAddress := "provider" + strconv.Itoa(i) + isCurrentlyActive := acps.IsProviderCurrentlyUsed(providerAddress) + require.False(t, isCurrentlyActive) + } +} diff --git a/protocol/lavasession/consumer_session_manager.go b/protocol/lavasession/consumer_session_manager.go index bc767ae732..dd4fc10b35 100644 --- a/protocol/lavasession/consumer_session_manager.go +++ b/protocol/lavasession/consumer_session_manager.go @@ -57,10 +57,11 @@ type ConsumerSessionManager struct { reportedProviders *ReportedProviders // pairingPurge - contains all pairings that are unwanted this epoch, keeps them in memory in order to avoid release. // (if a consumer session still uses one of them or we want to report it.) - pairingPurge map[string]*ConsumerSessionsWithProvider - providerOptimizer ProviderOptimizer - consumerMetricsManager *metrics.ConsumerMetricsManager - consumerPublicAddress string + pairingPurge map[string]*ConsumerSessionsWithProvider + providerOptimizer ProviderOptimizer + consumerMetricsManager *metrics.ConsumerMetricsManager + consumerPublicAddress string + activeSubscriptionProvidersStorage *ActiveSubscriptionProvidersStorage } // this is being read in multiple locations and but never changes so no need to lock. @@ -155,14 +156,29 @@ func (csm *ConsumerSessionManager) getValidAddresses(addon string, extensions [] // otherwise golang garbage collector is not closing network connections and they // will remain open forever. func (csm *ConsumerSessionManager) closePurgedUnusedPairingsConnections() { - for _, purgedPairing := range csm.pairingPurge { - for _, endpoint := range purgedPairing.Endpoints { - for _, endpointConnection := range endpoint.Connections { - if endpointConnection.connection != nil { - endpointConnection.connection.Close() + for providerAddr, purgedPairing := range csm.pairingPurge { + callbackPurge := func() { + for _, endpoint := range purgedPairing.Endpoints { + for _, endpointConnection := range endpoint.Connections { + if endpointConnection.connection != nil { + utils.LavaFormatTrace("purging connection", + utils.LogAttr("providerAddr", providerAddr), + utils.LogAttr("endpoint", endpoint.NetworkAddress), + ) + endpointConnection.connection.Close() + } } } } + // In cases where there is still an active subscription over the epoch handover, we purge the connection when the subscription ends.
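The deferred callback registered in this hunk fires from the other side of the storage: once the subscription manager releases the provider, RemoveProvider sees the count reach zero and runs the purge. A condensed sketch of that release path (illustrative; the exact call site in the subscription teardown is not shown in this hunk):

// somewhere in the consumer subscription teardown path (hypothetical call site):
csm.activeSubscriptionProvidersStorage.RemoveProvider(providerAddr)
// -> the count reaches zero, the entry is deleted, and the callbackPurge registered
//    by closePurgedUnusedPairingsConnections finally closes the grpc connections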
+ if csm.activeSubscriptionProvidersStorage.IsProviderCurrentlyUsed(providerAddr) { + utils.LavaFormatTrace("skipping purge for provider, as its currently used in a subscription", + utils.LogAttr("providerAddr", providerAddr), + ) + csm.activeSubscriptionProvidersStorage.addToPurgeWhenDone(providerAddr, callbackPurge) + continue + } + callbackPurge() } } @@ -1055,12 +1071,13 @@ func (csm *ConsumerSessionManager) GetAtomicPairingAddressesLength() uint64 { } // On a successful Subscribe relay -func (csm *ConsumerSessionManager) OnSessionDoneIncreaseCUOnly(consumerSession *SingleConsumerSession) error { +func (csm *ConsumerSessionManager) OnSessionDoneIncreaseCUOnly(consumerSession *SingleConsumerSession, latestServicedBlock int64) error { if err := consumerSession.VerifyLock(); err != nil { return sdkerrors.Wrapf(err, "OnSessionDoneIncreaseRelayAndCu consumerSession.lock must be locked before accessing this method") } - defer consumerSession.Free(nil) // we need to be locked here, if we didn't get it locked we try lock anyway + defer consumerSession.Free(nil) // we need to be locked here, if we didn't get it locked we try lock anyway + consumerSession.LatestBlock = latestServicedBlock consumerSession.CuSum += consumerSession.LatestRelayCu // add CuSum to current cu usage. consumerSession.LatestRelayCu = 0 // reset cu just in case consumerSession.ConsecutiveErrors = []error{} @@ -1079,7 +1096,7 @@ func (csm *ConsumerSessionManager) GenerateReconnectCallback(consumerSessionsWit } } -func NewConsumerSessionManager(rpcEndpoint *RPCEndpoint, providerOptimizer ProviderOptimizer, consumerMetricsManager *metrics.ConsumerMetricsManager, reporter metrics.Reporter, consumerPublicAddress string) *ConsumerSessionManager { +func NewConsumerSessionManager(rpcEndpoint *RPCEndpoint, providerOptimizer ProviderOptimizer, consumerMetricsManager *metrics.ConsumerMetricsManager, reporter metrics.Reporter, consumerPublicAddress string, activeSubscriptionProvidersStorage *ActiveSubscriptionProvidersStorage) *ConsumerSessionManager { csm := &ConsumerSessionManager{ reportedProviders: NewReportedProviders(reporter, rpcEndpoint.ChainID), consumerMetricsManager: consumerMetricsManager, @@ -1087,5 +1104,6 @@ func NewConsumerSessionManager(rpcEndpoint *RPCEndpoint, providerOptimizer Provi } csm.rpcEndpoint = rpcEndpoint csm.providerOptimizer = providerOptimizer + csm.activeSubscriptionProvidersStorage = activeSubscriptionProvidersStorage return csm } diff --git a/protocol/lavasession/consumer_session_manager_test.go b/protocol/lavasession/consumer_session_manager_test.go index 6719a1d41c..9bdf1b7fcf 100644 --- a/protocol/lavasession/consumer_session_manager_test.go +++ b/protocol/lavasession/consumer_session_manager_test.go @@ -161,7 +161,7 @@ func TestEndpointSortingFlow(t *testing.T) { func CreateConsumerSessionManager() *ConsumerSessionManager { rand.InitRandomSeed() baseLatency := common.AverageWorldLatency / 2 // we want performance to be half our timeout or better - return NewConsumerSessionManager(&RPCEndpoint{"stub", "stub", "stub", false, "/", 0}, provideroptimizer.NewProviderOptimizer(provideroptimizer.STRATEGY_BALANCED, 0, baseLatency, 1), nil, nil, "lava@test") + return NewConsumerSessionManager(&RPCEndpoint{"stub", "stub", "stub", false, "/", 0}, provideroptimizer.NewProviderOptimizer(provideroptimizer.STRATEGY_BALANCED, 0, baseLatency, 1), nil, nil, "lava@test", NewActiveSubscriptionProvidersStorage()) } func TestMain(m *testing.M) { diff --git a/protocol/lavasession/provider_session_manager.go 
b/protocol/lavasession/provider_session_manager.go index 1ccdebc7f5..cc805168b4 100644 --- a/protocol/lavasession/provider_session_manager.go +++ b/protocol/lavasession/provider_session_manager.go @@ -272,7 +272,6 @@ func filterOldEpochEntries[T dataHandler](blockedEpochHeight uint64, allEpochsMa if !IsEpochValidForUse(epochStored, blockedEpochHeight) { // epoch is not valid so we don't keep its key in the new map - // in the case of subscribe, we need to unsubscribe before deleting the key from storage. value.onDeleteEvent() continue @@ -283,110 +282,6 @@ func filterOldEpochEntries[T dataHandler](blockedEpochHeight uint64, allEpochsMa return } -func (psm *ProviderSessionManager) ProcessUnsubscribe(apiName, subscriptionID, consumerAddress string, epoch uint64) error { - providerSessionWithConsumer, activeError := psm.getActiveProjectFromConsumerAddress(consumerAddress, epoch) - if activeError != nil { - return utils.LavaFormatError("[ProcessUnsubscribe] Couldn't find providerSessionWithConsumer", activeError, utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "address", Value: consumerAddress}) - } - - psm.lock.Lock() - defer psm.lock.Unlock() - var err error - if apiName == TendermintUnsubscribeAll { - // unsubscribe all subscriptions - for _, v := range providerSessionWithConsumer.ongoingSubscriptions { - if v.Sub == nil { - err = utils.LavaFormatError("[ProcessUnsubscribe] TendermintUnsubscribeAll providerSessionWithConsumer.ongoingSubscriptions Error", SubscriptionPointerIsNilError, utils.Attribute{Key: "subscripionId", Value: subscriptionID}) - } else { - v.Sub.Unsubscribe() - } - } - providerSessionWithConsumer.ongoingSubscriptions = make(map[string]*RPCSubscription) // delete the entire map. - return err - } - - subscription, foundSubscription := providerSessionWithConsumer.ongoingSubscriptions[subscriptionID] - if !foundSubscription { - return utils.LavaFormatError("Couldn't find subscription Id in psm.subscriptionSessionsWithAllConsumers", nil, utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "address", Value: consumerAddress}, utils.Attribute{Key: "subscriptionId", Value: subscriptionID}) - } - - if subscription.Sub == nil { - err = utils.LavaFormatError("ProcessUnsubscribe Error", SubscriptionPointerIsNilError, utils.Attribute{Key: "subscripionId", Value: subscriptionID}) - } else { - subscription.Sub.Unsubscribe() - } - delete(providerSessionWithConsumer.ongoingSubscriptions, subscriptionID) // delete subscription after finished with it - return err -} - -// use this method when unlocked. 
-func (psm *ProviderSessionManager) getActiveProjectFromConsumerAddress(consumerAddress string, epoch uint64) (*ProviderSessionsWithConsumerProject, error) { - projectId, found := psm.readConsumerToPairedWithProjectMap(consumerAddress, epoch) - if !found { - return nil, utils.LavaFormatError("getActiveProjectFromConsumerAddress Couldn't find consumerAddress readConsumerToPairedWithProjectMap", nil, - utils.Attribute{Key: "epoch", Value: epoch}, - utils.Attribute{Key: "address", Value: consumerAddress}, - ) - } - providerSessionWithConsumer, activeError := psm.IsActiveProject(epoch, projectId) - if activeError != nil { - return nil, utils.LavaFormatError("getActiveProjectFromConsumerAddress Couldn't find projectId in psm.subscriptionSessionsWithAllConsumers", activeError, utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "address", Value: projectId}) - } - return providerSessionWithConsumer, nil -} - -func (psm *ProviderSessionManager) addSubscriptionToStorage(subscription *RPCSubscription, consumerAddress string, epoch uint64) error { - // we already validated the epoch is valid in the GetSessions no need to verify again. - providerSessionWithConsumer, activeError := psm.getActiveProjectFromConsumerAddress(consumerAddress, epoch) - if activeError != nil { - return utils.LavaFormatError("[addSubscriptionToStorage] Couldn't find providerSessionWithConsumer", activeError, utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "address", Value: consumerAddress}) - } - - psm.lock.Lock() - defer psm.lock.Unlock() - _, foundSubscription := providerSessionWithConsumer.ongoingSubscriptions[subscription.Id] - if !foundSubscription { - // we shouldnt find a subscription already in the storage. - providerSessionWithConsumer.ongoingSubscriptions[subscription.Id] = subscription - return nil // successfully added subscription to storage - } - - // if we get here we found a subscription already in the storage and we need to return an error as we can't add two subscriptions with the same id - return utils.LavaFormatError("addSubscription", SubscriptionAlreadyExistsError, utils.Attribute{Key: "SubscriptionId", Value: subscription.Id}, utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "address", Value: consumerAddress}) -} - -func (psm *ProviderSessionManager) ReleaseSessionAndCreateSubscription(session *SingleProviderSession, subscription *RPCSubscription, consumerAddress string, epoch, relayNumber uint64) error { - err := psm.OnSessionDone(session, relayNumber) - if err != nil { - return utils.LavaFormatError("Failed ReleaseSessionAndCreateSubscription", err) - } - return psm.addSubscriptionToStorage(subscription, consumerAddress, epoch) -} - -// try to disconnect the subscription incase we got an error. 
-// if fails to find assumes it was unsubscribed normally -func (psm *ProviderSessionManager) SubscriptionEnded(consumerAddress string, epoch uint64, subscriptionID string) { - providerSessionWithConsumer, activeError := psm.getActiveProjectFromConsumerAddress(consumerAddress, epoch) - if activeError != nil { - utils.LavaFormatError("[SubscriptionEnded] Couldn't find providerSessionWithConsumer, this must be found when getting here", activeError, utils.Attribute{Key: "epoch", Value: epoch}, utils.Attribute{Key: "address", Value: consumerAddress}) - return - } - - psm.lock.Lock() - defer psm.lock.Unlock() - subscription, foundSubscription := providerSessionWithConsumer.ongoingSubscriptions[subscriptionID] - if !foundSubscription { - return - } - - if subscription.Sub == nil { // validate subscription not nil - utils.LavaFormatError("SubscriptionEnded Error", SubscriptionPointerIsNilError, utils.Attribute{Key: "subscripionId", Value: subscription.Id}) - } else { - subscription.Sub.Unsubscribe() - } - delete(providerSessionWithConsumer.ongoingSubscriptions, subscriptionID) // delete subscription after finished with it -} - // Called when the reward server has information on a higher cu proof and usage and this providerSessionsManager needs to sync up on it func (psm *ProviderSessionManager) UpdateSessionCU(consumerAddress string, epoch, sessionID, newCU uint64) error { // load the session and update the CU inside diff --git a/protocol/lavasession/provider_session_manager_test.go b/protocol/lavasession/provider_session_manager_test.go index 2749590d8e..16d4efb9d6 100644 --- a/protocol/lavasession/provider_session_manager_test.go +++ b/protocol/lavasession/provider_session_manager_test.go @@ -20,8 +20,6 @@ const ( dataReliabilityRelayCu = uint64(0) epoch1 = uint64(10) sessionId = uint64(123) - subscriptionID = "124" - subscriptionID2 = "125" dataReliabilitySessionId = uint64(0) relayNumber = uint64(1) relayNumberBeforeUse = uint64(0) @@ -556,192 +554,6 @@ func TestPSMCUMisMatch(t *testing.T) { require.True(t, ProviderConsumerCuMisMatch.Is(err)) } -func TestPSMSubscribeHappyFlowProcessUnsubscribe(t *testing.T) { - // init test - psm, sps := prepareSession(t, context.Background()) - - // validate subscription map is empty - require.Empty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) - - // subscribe - var channel chan interface{} - subscription := &RPCSubscription{ - Id: subscriptionID, - Sub: nil, - SubscribeRepliesChan: channel, - } - psm.ReleaseSessionAndCreateSubscription(sps, subscription, consumerOneAddress, epoch1, relayNumber) - - // verify state after subscription creation - require.True(t, LockMisUseDetectedError.Is(sps.VerifyLock())) // validating session was unlocked. 
- require.NotEmpty(t, psm.sessionsWithAllConsumers) - _, foundSubscription := psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions[subscriptionID] - require.True(t, foundSubscription) - - err := psm.ProcessUnsubscribe("unsubscribe", subscriptionID, consumerOneAddress, epoch1) - require.True(t, SubscriptionPointerIsNilError.Is(err)) - require.Empty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) -} - -func TestPSMSubscribeHappyFlowProcessUnsubscribeUnsubscribeAll(t *testing.T) { - // init test - psm, sps := prepareSession(t, context.Background()) - - // validate subscription map is empty - require.Empty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) - - // subscribe - var channel chan interface{} - subscription := &RPCSubscription{ - Id: subscriptionID, - Sub: nil, - SubscribeRepliesChan: channel, - } - subscription2 := &RPCSubscription{ - Id: subscriptionID2, - Sub: nil, - SubscribeRepliesChan: channel, - } - psm.ReleaseSessionAndCreateSubscription(sps, subscription, consumerOneAddress, epoch1, relayNumber) - - sps, err := psm.GetSession(context.Background(), consumerOneAddress, epoch1, sessionId, relayNumber+1, nil) - require.NoError(t, err) - require.NotNil(t, sps) - - // create 2nd subscription - psm.ReleaseSessionAndCreateSubscription(sps, subscription2, consumerOneAddress, epoch1, relayNumber+1) - - // verify state after subscription creation - require.True(t, LockMisUseDetectedError.Is(sps.VerifyLock())) // validating session was unlocked. - require.NotEmpty(t, psm.sessionsWithAllConsumers) - _, foundSubscription := psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions[subscriptionID] - require.True(t, foundSubscription) - _, foundSubscription2 := psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions[subscriptionID2] - require.True(t, foundSubscription2) - - err = psm.ProcessUnsubscribe(TendermintUnsubscribeAll, subscriptionID, consumerOneAddress, epoch1) - require.True(t, SubscriptionPointerIsNilError.Is(err)) - require.Empty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) -} - -func TestPSMSubscribeHappyFlowProcessUnsubscribeUnsubscribeOneOutOfTwo(t *testing.T) { - // init test - psm, sps := prepareSession(t, context.Background()) - - // validate subscription map is empty - require.Empty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) - - // subscribe - var channel chan interface{} - subscription := &RPCSubscription{ - Id: subscriptionID, - Sub: nil, - SubscribeRepliesChan: channel, - } - subscription2 := &RPCSubscription{ - Id: subscriptionID2, - Sub: nil, - SubscribeRepliesChan: channel, - } - psm.ReleaseSessionAndCreateSubscription(sps, subscription, consumerOneAddress, epoch1, relayNumber) - // create 2nd subscription as we release the session we can just ask for it again with relayNumber + 1 - sps, err := psm.GetSession(context.Background(), consumerOneAddress, epoch1, sessionId, relayNumber+1, nil) - require.NoError(t, err) - psm.ReleaseSessionAndCreateSubscription(sps, subscription2, consumerOneAddress, epoch1, relayNumber+1) - - err = psm.ProcessUnsubscribe("unsubscribeOne", subscriptionID, consumerOneAddress, epoch1) - require.True(t, SubscriptionPointerIsNilError.Is(err)) - require.NotEmpty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) - _, foundId2 := 
psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions[subscriptionID2] - require.True(t, foundId2) -} - -func TestPSMSubscribeHappyFlowSubscriptionEnded(t *testing.T) { - // init test - psm, sps := prepareSession(t, context.Background()) - - // validate subscription map is empty - require.Empty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) - - // subscribe - var channel chan interface{} - subscription := &RPCSubscription{ - Id: subscriptionID, - Sub: nil, - SubscribeRepliesChan: channel, - } - psm.ReleaseSessionAndCreateSubscription(sps, subscription, consumerOneAddress, epoch1, relayNumber) - - // verify state after subscription creation - require.True(t, LockMisUseDetectedError.Is(sps.VerifyLock())) // validating session was unlocked. - require.NotEmpty(t, psm.sessionsWithAllConsumers) - _, foundSubscription := psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions[subscriptionID] - require.True(t, foundSubscription) - - psm.SubscriptionEnded(consumerOneAddress, epoch1, subscriptionID) - require.Empty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) -} - -func TestPSMSubscribeHappyFlowSubscriptionEndedOneOutOfTwo(t *testing.T) { - // init test - psm, sps := prepareSession(t, context.Background()) - - // validate subscription map is empty - require.Empty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) - - // subscribe - var channel chan interface{} - subscription := &RPCSubscription{ - Id: subscriptionID, - Sub: nil, - SubscribeRepliesChan: channel, - } - subscription2 := &RPCSubscription{ - Id: subscriptionID2, - Sub: nil, - SubscribeRepliesChan: channel, - } - psm.ReleaseSessionAndCreateSubscription(sps, subscription, consumerOneAddress, epoch1, relayNumber) - // create 2nd subscription as we release the session we can just ask for it again with relayNumber + 1 - sps, err := psm.GetSession(context.Background(), consumerOneAddress, epoch1, sessionId, relayNumber+1, nil) - require.NoError(t, err) - psm.ReleaseSessionAndCreateSubscription(sps, subscription2, consumerOneAddress, epoch1, relayNumber) - - psm.SubscriptionEnded(consumerOneAddress, epoch1, subscriptionID) - require.NotEmpty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) - _, foundId2 := psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions[subscriptionID2] - require.True(t, foundId2) -} - -func TestPSMSubscribeEpochChange(t *testing.T) { - // init test - psm, sps := prepareSession(t, context.Background()) - - // validate subscription map is empty - require.Empty(t, psm.sessionsWithAllConsumers[epoch1].sessionMap[projectId].ongoingSubscriptions) - - // subscribe - var channel chan interface{} - subscription := &RPCSubscription{ - Id: subscriptionID, - Sub: nil, - SubscribeRepliesChan: channel, - } - subscription2 := &RPCSubscription{ - Id: subscriptionID2, - Sub: nil, - SubscribeRepliesChan: channel, - } - psm.ReleaseSessionAndCreateSubscription(sps, subscription, consumerOneAddress, epoch1, relayNumber) - // create 2nd subscription as we release the session we can just ask for it again with relayNumber + 1 - sps, err := psm.GetSession(context.Background(), consumerOneAddress, epoch1, sessionId, relayNumber+1, nil) - require.NoError(t, err) - psm.ReleaseSessionAndCreateSubscription(sps, subscription2, consumerOneAddress, epoch1, relayNumber+1) - - psm.UpdateEpoch(epoch2) - require.Empty(t, 
psm.sessionsWithAllConsumers[epoch2]) -} - type testSessionData struct { currentCU uint64 inUse bool diff --git a/protocol/lavasession/used_providers.go b/protocol/lavasession/used_providers.go index 7d50610fb1..8f8a6b7fab 100644 --- a/protocol/lavasession/used_providers.go +++ b/protocol/lavasession/used_providers.go @@ -162,6 +162,7 @@ func (up *UsedProviders) TryLockSelection(ctx context.Context) error { for counter := 0; counter < MaximumNumberOfSelectionLockAttempts; counter++ { select { case <-ctx.Done(): + utils.LavaFormatTrace("Failed locking selection, context is done") return ContextDoneNoNeedToLockSelectionError default: canSelect := up.tryLockSelection() diff --git a/protocol/metrics/rpcconsumerlogs.go b/protocol/metrics/rpcconsumerlogs.go index 6d6fa749a1..74a066dffa 100644 --- a/protocol/metrics/rpcconsumerlogs.go +++ b/protocol/metrics/rpcconsumerlogs.go @@ -129,22 +129,24 @@ func (rpccl *RPCConsumerLogs) GetUniqueGuidResponseForError(responseError error, } // Websocket healthy disconnections throw "websocket: close 1005 (no status)" error, -// We dont want to alert error monitoring for that purpses. -func (rpccl *RPCConsumerLogs) AnalyzeWebSocketErrorAndWriteMessage(c *websocket.Conn, mt int, err error, msgSeed string, msg []byte, rpcType string, timeTaken time.Duration) { +// We don't want to alert error monitoring for that purpose. +func (rpccl *RPCConsumerLogs) AnalyzeWebSocketErrorAndGetFormattedMessage(webSocketAddr string, err error, msgSeed string, msg []byte, rpcType string, timeTaken time.Duration) []byte { if err != nil { errMessage := err.Error() if strings.Contains(errMessage, webSocketCloseMessage) { utils.LavaFormatDebug("Websocket connection closed by the user, " + errMessage) - return + return nil } - rpccl.LogRequestAndResponse(rpcType+" ws msg", true, "ws", c.LocalAddr().String(), string(msg), "", msgSeed, timeTaken, err) + rpccl.LogRequestAndResponse(rpcType+" ws msg", true, "ws", webSocketAddr, string(msg), "", msgSeed, timeTaken, err) jsonResponse, _ := json.Marshal(fiber.Map{ "Error_Received": rpccl.GetUniqueGuidResponseForError(err, msgSeed), }) - c.WriteMessage(mt, jsonResponse) + return jsonResponse } + + return nil } func (rpccl *RPCConsumerLogs) LogRequestAndResponse(module string, hasError bool, method, path, req, resp, msgSeed string, timeTaken time.Duration, err error) { diff --git a/protocol/metrics/rpcconsumerlogs_test.go b/protocol/metrics/rpcconsumerlogs_test.go index f0f0c62e62..a8482eb93c 100644 --- a/protocol/metrics/rpcconsumerlogs_test.go +++ b/protocol/metrics/rpcconsumerlogs_test.go @@ -60,7 +60,9 @@ func TestAnalyzeWebSocketErrorAndWriteMessage(t *testing.T) { mt, _, _ := c.ReadMessage() plog, _ := NewRPCConsumerLogs(nil, nil) responseError := errors.New("response error") - plog.AnalyzeWebSocketErrorAndWriteMessage(c, mt, responseError, "seed", []byte{}, "rpcType", 1*time.Millisecond) + formatterMsg := plog.AnalyzeWebSocketErrorAndGetFormattedMessage(c.LocalAddr().String(), responseError, "seed", []byte{}, "rpcType", 1*time.Millisecond) + assert.NotNil(t, formatterMsg) + c.WriteMessage(mt, formatterMsg) })) listenFunc := func() { diff --git a/protocol/parser/parser.go b/protocol/parser/parser.go index ae491f622d..ea46008dd9 100644 --- a/protocol/parser/parser.go +++ b/protocol/parser/parser.go @@ -29,6 +29,7 @@ type RPCInput interface { ParseBlock(block string) (int64, error) GetHeaders() []pairingtypes.Metadata GetMethod() string + GetID() json.RawMessage } func ParseDefaultBlockParameter(block string) (int64, error) { diff --git
a/protocol/parser/parser_test.go b/protocol/parser/parser_test.go index 82fcd2c36c..fb66257877 100644 --- a/protocol/parser/parser_test.go +++ b/protocol/parser/parser_test.go @@ -31,6 +31,10 @@ func (rpcInputTest *RPCInputTest) GetResult() json.RawMessage { return rpcInputTest.Result } +func (rpcInputTest *RPCInputTest) GetID() json.RawMessage { + return nil +} + func (rpcInputTest *RPCInputTest) ParseBlock(block string) (int64, error) { if rpcInputTest.ParseBlockFunc == nil { return ParseDefaultBlockParameter(block) diff --git a/protocol/performance/cache.go b/protocol/performance/cache.go index 0402b78bc5..0150e8b5f1 100644 --- a/protocol/performance/cache.go +++ b/protocol/performance/cache.go @@ -39,7 +39,7 @@ func InitCache(ctx context.Context, addr string) (*Cache, error) { func (cache *Cache) GetEntry(ctx context.Context, relayCacheGet *pairingtypes.RelayCacheGet) (reply *pairingtypes.CacheRelayReply, err error) { if cache == nil { // TODO: try to connect again once in a while - return nil, NotInitialisedError + return nil, NotInitializedError } if cache.client == nil { return nil, NotConnectedError.Wrapf("No client connected to address: %s", cache.address) @@ -55,7 +55,7 @@ func (cache *Cache) CacheActive() bool { func (cache *Cache) SetEntry(ctx context.Context, cacheSet *pairingtypes.RelayCacheSet) error { if cache == nil { // TODO: try to connect again once in a while - return NotInitialisedError + return NotInitializedError } if cache.client == nil { return NotConnectedError.Wrapf("No client connected to address: %s", cache.address) diff --git a/protocol/performance/errors.go b/protocol/performance/errors.go index d587ac91af..bb69cd3f51 100644 --- a/protocol/performance/errors.go +++ b/protocol/performance/errors.go @@ -6,5 +6,5 @@ import ( var ( NotConnectedError = sdkerrors.New("Not Connected Error", 700, "No Connection To grpc server") - NotInitialisedError = sdkerrors.New("Not Initialised Error", 701, "to use cache run initCache") + NotInitializedError = sdkerrors.New("Not Initialised Error", 701, "to use cache run initCache") ) diff --git a/protocol/rpcconsumer/consumer_consistency.go b/protocol/rpcconsumer/consumer_consistency.go index 9bafbc78fb..dc54667e44 100644 --- a/protocol/rpcconsumer/consumer_consistency.go +++ b/protocol/rpcconsumer/consumer_consistency.go @@ -44,17 +44,29 @@ func (cc *ConsumerConsistency) Key(dappId string, ip string) string { return dappId + "__" + ip } -func (cc *ConsumerConsistency) SetSeenBlock(blockSeen int64, dappId string, ip string) { +// used on subscription, where we already have the dapp key stored, but we don't keep the dappId and ip separately +func (cc *ConsumerConsistency) SetSeenBlockFromKey(blockSeen int64, key string) { if cc == nil { return } - block, _ := cc.getLatestBlock(cc.Key(dappId, ip)) + block, _ := cc.getLatestBlock(key) if block < blockSeen { - cc.setLatestBlock(cc.Key(dappId, ip), blockSeen) + cc.setLatestBlock(key, blockSeen) } } +func (cc *ConsumerConsistency) SetSeenBlock(blockSeen int64, dappId string, ip string) { + if cc == nil { + return + } + key := cc.Key(dappId, ip) + cc.SetSeenBlockFromKey(blockSeen, key) +} + func (cc *ConsumerConsistency) GetSeenBlock(dappId string, ip string) (int64, bool) { + if cc == nil { + return 0, false + } return cc.getLatestBlock(cc.Key(dappId, ip)) } diff --git a/protocol/rpcconsumer/relay_processor.go b/protocol/rpcconsumer/relay_processor.go index 5efbb211b9..4c83ac2c72 100644 --- a/protocol/rpcconsumer/relay_processor.go +++ b/protocol/rpcconsumer/relay_processor.go @@ 
-221,6 +221,7 @@ func (rp *RelayProcessor) setValidResponse(response *relayResponse) { response.relayResult.Finalized = false // shut down data reliability // } } + if response.relayResult.Reply == nil { utils.LavaFormatError("got to setValidResponse with nil Reply", response.err, @@ -235,6 +236,12 @@ func (rp *RelayProcessor) setValidResponse(response *relayResponse) { blockSeen := response.relayResult.Reply.LatestBlock // nil safe rp.consumerConsistency.SetSeenBlock(blockSeen, rp.dappID, rp.consumerIp) + // on subscribe results, we just append to successful results instead of parsing results because we already have a validation. + if chainlib.IsFunctionTagOfType(rp.chainMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { + rp.successResults = append(rp.successResults, response.relayResult) + return + } + // check response error foundError, errorMessage := rp.chainMessage.CheckResponseError(response.relayResult.Reply.Data, response.relayResult.StatusCode) if foundError { diff --git a/protocol/rpcconsumer/relay_processor_test.go b/protocol/rpcconsumer/relay_processor_test.go index 598c9499a8..bb21a8eda3 100644 --- a/protocol/rpcconsumer/relay_processor_test.go +++ b/protocol/rpcconsumer/relay_processor_test.go @@ -97,7 +97,7 @@ func TestRelayProcessorHappyFlow(t *testing.T) { w.WriteHeader(http.StatusOK) }) specId := "LAV1" - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } @@ -139,7 +139,7 @@ func TestRelayProcessorNodeErrorRetryFlow(t *testing.T) { w.WriteHeader(http.StatusOK) }) specId := "LAV1" - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } @@ -278,7 +278,7 @@ func TestRelayProcessorNodeErrorRetryFlow(t *testing.T) { w.WriteHeader(http.StatusOK) }) specId := "LAV1" - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } @@ -316,7 +316,7 @@ func TestRelayProcessorTimeout(t *testing.T) { w.WriteHeader(http.StatusOK) }) specId := "LAV1" - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } @@ -368,7 +368,7 @@ func TestRelayProcessorRetry(t *testing.T) { w.WriteHeader(http.StatusOK) }) specId := "LAV1" - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } @@ -412,7 +412,7 @@ func 
TestRelayProcessorRetryNodeError(t *testing.T) { w.WriteHeader(http.StatusOK) }) specId := "LAV1" - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } @@ -457,7 +457,7 @@ func TestRelayProcessorStatefulApi(t *testing.T) { w.WriteHeader(http.StatusOK) }) specId := "LAV1" - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } @@ -502,7 +502,7 @@ func TestRelayProcessorStatefulApiErr(t *testing.T) { w.WriteHeader(http.StatusOK) }) specId := "LAV1" - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } @@ -546,7 +546,7 @@ func TestRelayProcessorLatest(t *testing.T) { w.WriteHeader(http.StatusOK) }) specId := "LAV1" - chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, _, _, closeServer, _, err := chainlib.CreateChainLibMocks(ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index 747e4855e7..3e8a476e00 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -3,6 +3,7 @@ package rpcconsumer import ( "context" "fmt" + "net/http" "os" "os/signal" "strconv" @@ -18,7 +19,7 @@ import ( "github.com/lavanet/lava/v2/app" "github.com/lavanet/lava/v2/protocol/chainlib" "github.com/lavanet/lava/v2/protocol/common" - "github.com/lavanet/lava/v2/protocol/lavaprotocol" + "github.com/lavanet/lava/v2/protocol/lavaprotocol/finalizationconsensus" "github.com/lavanet/lava/v2/protocol/lavasession" "github.com/lavanet/lava/v2/protocol/metrics" "github.com/lavanet/lava/v2/protocol/performance" @@ -32,6 +33,7 @@ import ( conflicttypes "github.com/lavanet/lava/v2/x/conflict/types" plantypes "github.com/lavanet/lava/v2/x/plans/types" protocoltypes "github.com/lavanet/lava/v2/x/protocol/types" + spectypes "github.com/lavanet/lava/v2/x/spec/types" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -89,7 +91,7 @@ type ConsumerStateTrackerInf interface { RegisterForVersionUpdates(ctx context.Context, version *protocoltypes.Version, versionValidator updaters.VersionValidationInf) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) RegisterForSpecUpdates(ctx context.Context, specUpdatable updaters.SpecUpdatable, endpoint lavasession.RPCEndpoint) error - RegisterFinalizationConsensusForUpdates(context.Context, *lavaprotocol.FinalizationConsensus) + RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus) RegisterForDowntimeParamsUpdates(ctx context.Context, 
downtimeParamsUpdatable updaters.DowntimeParamsUpdatable) error TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict, conflictHandler common.ConflictHandlerInterface) error GetConsumerPolicy(ctx context.Context, consumerAddress, chainID string) (*plantypes.Policy, error) @@ -221,7 +223,7 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt _, averageBlockTime, _, _ := chainParser.ChainBlockStats() var optimizer *provideroptimizer.ProviderOptimizer var consumerConsistency *ConsumerConsistency - var finalizationConsensus *lavaprotocol.FinalizationConsensus + var finalizationConsensus *finalizationconsensus.FinalizationConsensus getOrCreateChainAssets := func() error { // this is locked so we don't race optimizers creation chainMutexes[chainID].Lock() @@ -256,12 +258,12 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt value, exists = finalizationConsensuses.Load(chainID) if !exists { // doesn't exist for this chain create a new one - finalizationConsensus = lavaprotocol.NewFinalizationConsensus(rpcEndpoint.ChainID) + finalizationConsensus = finalizationconsensus.NewFinalizationConsensus(rpcEndpoint.ChainID) consumerStateTracker.RegisterFinalizationConsensusForUpdates(ctx, finalizationConsensus) finalizationConsensuses.Store(chainID, finalizationConsensus) } else { var ok bool - finalizationConsensus, ok = value.(*lavaprotocol.FinalizationConsensus) + finalizationConsensus, ok = value.(*finalizationconsensus.FinalizationConsensus) if !ok { err = utils.LavaFormatError("failed loading finalization consensus, value is of the wrong type", nil, utils.Attribute{Key: "endpoint", Value: rpcEndpoint.Key()}) return err @@ -281,8 +283,10 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt return err } + // Create active subscription provider storage for each unique chain + activeSubscriptionProvidersStorage := lavasession.NewActiveSubscriptionProvidersStorage() + consumerSessionManager := lavasession.NewConsumerSessionManager(rpcEndpoint, optimizer, consumerMetricsManager, consumerReportsManager, consumerAddr.String(), activeSubscriptionProvidersStorage) // Register For Updates - consumerSessionManager := lavasession.NewConsumerSessionManager(rpcEndpoint, optimizer, consumerMetricsManager, consumerReportsManager, consumerAddr.String()) rpcc.consumerStateTracker.RegisterConsumerSessionManagerForPairingUpdates(ctx, consumerSessionManager) var relaysMonitor *metrics.RelaysMonitor @@ -290,9 +294,18 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt relaysMonitor = metrics.NewRelaysMonitor(options.cmdFlags.RelaysHealthIntervalFlag, rpcEndpoint.ChainID, rpcEndpoint.ApiInterface) relaysMonitorAggregator.RegisterRelaysMonitor(rpcEndpoint.String(), relaysMonitor) } + rpcConsumerServer := &RPCConsumerServer{} + + var consumerWsSubscriptionManager *chainlib.ConsumerWSSubscriptionManager + var specMethodType string + if rpcEndpoint.ApiInterface == spectypes.APIInterfaceJsonRPC { + specMethodType = http.MethodPost + } + consumerWsSubscriptionManager = chainlib.NewConsumerWSSubscriptionManager(consumerSessionManager, rpcConsumerServer, options.refererData, specMethodType, chainParser, activeSubscriptionProvidersStorage) + utils.LavaFormatInfo("RPCConsumer Listening", utils.Attribute{Key: "endpoints", Value: rpcEndpoint.String()}) - err = 
rpcConsumerServer.ServeRPCRequests(ctx, rpcEndpoint, rpcc.consumerStateTracker, chainParser, finalizationConsensus, consumerSessionManager, options.requiredResponses, privKey, lavaChainID, options.cache, rpcConsumerMetrics, consumerAddr, consumerConsistency, relaysMonitor, options.cmdFlags, options.stateShare, options.refererData, consumerReportsManager) + err = rpcConsumerServer.ServeRPCRequests(ctx, rpcEndpoint, rpcc.consumerStateTracker, chainParser, finalizationConsensus, consumerSessionManager, options.requiredResponses, privKey, lavaChainID, options.cache, rpcConsumerMetrics, consumerAddr, consumerConsistency, relaysMonitor, options.cmdFlags, options.stateShare, options.refererData, consumerReportsManager, consumerWsSubscriptionManager) if err != nil { err = utils.LavaFormatError("failed serving rpc requests", err, utils.Attribute{Key: "endpoint", Value: rpcEndpoint}) errCh <- err diff --git a/protocol/rpcconsumer/rpcconsumer_server.go b/protocol/rpcconsumer/rpcconsumer_server.go index 1ebfdf379a..87c7c4b3df 100644 --- a/protocol/rpcconsumer/rpcconsumer_server.go +++ b/protocol/rpcconsumer/rpcconsumer_server.go @@ -6,8 +6,11 @@ import ( "fmt" "strconv" "strings" + "sync" "time" + "github.com/goccy/go-json" + sdkerrors "cosmossdk.io/errors" "github.com/btcsuite/btcd/btcec/v2" sdk "github.com/cosmos/cosmos-sdk/types" @@ -16,6 +19,7 @@ import ( "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/protocol/lavaprotocol" + "github.com/lavanet/lava/v2/protocol/lavaprotocol/finalizationconsensus" "github.com/lavanet/lava/v2/protocol/lavasession" "github.com/lavanet/lava/v2/protocol/metrics" "github.com/lavanet/lava/v2/protocol/performance" @@ -40,26 +44,34 @@ const ( var NoResponseTimeout = sdkerrors.New("NoResponseTimeout Error", 685, "timeout occurred while waiting for providers responses") +type CancelableContextHolder struct { + Ctx context.Context + CancelFunc context.CancelFunc +} + // implements Relay Sender interfaced and uses an ChainListener to get it called type RPCConsumerServer struct { - chainParser chainlib.ChainParser - consumerSessionManager *lavasession.ConsumerSessionManager - listenEndpoint *lavasession.RPCEndpoint - rpcConsumerLogs *metrics.RPCConsumerLogs - cache *performance.Cache - privKey *btcec.PrivateKey - consumerTxSender ConsumerTxSender - requiredResponses int - finalizationConsensus *lavaprotocol.FinalizationConsensus - lavaChainID string - ConsumerAddress sdk.AccAddress - consumerConsistency *ConsumerConsistency - sharedState bool // using the cache backend to sync the latest seen block with other consumers - relaysMonitor *metrics.RelaysMonitor - reporter metrics.Reporter - debugRelays bool - disableNodeErrorRetry bool - relayRetriesManager *RelayRetriesManager + consumerProcessGuid string + chainParser chainlib.ChainParser + consumerSessionManager *lavasession.ConsumerSessionManager + listenEndpoint *lavasession.RPCEndpoint + rpcConsumerLogs *metrics.RPCConsumerLogs + cache *performance.Cache + privKey *btcec.PrivateKey + consumerTxSender ConsumerTxSender + requiredResponses int + finalizationConsensus *finalizationconsensus.FinalizationConsensus + lavaChainID string + ConsumerAddress sdk.AccAddress + consumerConsistency *ConsumerConsistency + sharedState bool // using the cache backend to sync the latest seen block with other consumers + relaysMonitor *metrics.RelaysMonitor + reporter metrics.Reporter + debugRelays bool + connectedSubscriptionsContexts 
map[string]*CancelableContextHolder + connectedSubscriptionsLock sync.RWMutex + disableNodeErrorRetry bool + relayRetriesManager *RelayRetriesManager } type relayResponse struct { @@ -76,7 +88,7 @@ type ConsumerTxSender interface { func (rpccs *RPCConsumerServer) ServeRPCRequests(ctx context.Context, listenEndpoint *lavasession.RPCEndpoint, consumerStateTracker ConsumerStateTrackerInf, chainParser chainlib.ChainParser, - finalizationConsensus *lavaprotocol.FinalizationConsensus, + finalizationConsensus *finalizationconsensus.FinalizationConsensus, consumerSessionManager *lavasession.ConsumerSessionManager, requiredResponses int, privKey *btcec.PrivateKey, @@ -90,6 +102,7 @@ func (rpccs *RPCConsumerServer) ServeRPCRequests(ctx context.Context, listenEndp sharedState bool, refererData *chainlib.RefererData, reporter metrics.Reporter, + consumerWsSubscriptionManager *chainlib.ConsumerWSSubscriptionManager, ) (err error) { rpccs.consumerSessionManager = consumerSessionManager rpccs.listenEndpoint = listenEndpoint @@ -106,9 +119,11 @@ func (rpccs *RPCConsumerServer) ServeRPCRequests(ctx context.Context, listenEndp rpccs.sharedState = sharedState rpccs.reporter = reporter rpccs.debugRelays = cmdFlags.DebugRelays + rpccs.connectedSubscriptionsContexts = make(map[string]*CancelableContextHolder) + rpccs.consumerProcessGuid = strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10) rpccs.disableNodeErrorRetry = cmdFlags.DisableRetryOnNodeErrors rpccs.relayRetriesManager = NewRelayRetriesManager() - chainListener, err := chainlib.NewChainListener(ctx, listenEndpoint, rpccs, rpccs, rpcConsumerLogs, chainParser, refererData) + chainListener, err := chainlib.NewChainListener(ctx, listenEndpoint, rpccs, rpccs, rpcConsumerLogs, chainParser, refererData, consumerWsSubscriptionManager) if err != nil { return err } @@ -134,6 +149,10 @@ func (rpccs *RPCConsumerServer) ServeRPCRequests(ctx context.Context, listenEndp return nil } +func (rpccs *RPCConsumerServer) SetConsistencySeenBlock(blockSeen int64, key string) { + rpccs.consumerConsistency.SetSeenBlockFromKey(blockSeen, key) +} + func (rpccs *RPCConsumerServer) sendCraftedRelaysWrapper(initialRelays bool) (bool, error) { if initialRelays { // Only start after everything is initialized - check consumer session manager @@ -173,13 +192,14 @@ func (rpccs *RPCConsumerServer) waitForPairing() { } func (rpccs *RPCConsumerServer) craftRelay(ctx context.Context) (ok bool, relay *pairingtypes.RelayPrivateData, chainMessage chainlib.ChainMessage, err error) { - parsing, collectionData, ok := rpccs.chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) + parsing, apiCollection, ok := rpccs.chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) if !ok { return false, nil, nil, utils.LavaFormatWarning("did not send initial relays because the spec does not contain "+spectypes.FUNCTION_TAG_GET_BLOCKNUM.String(), nil, utils.LogAttr("chainID", rpccs.listenEndpoint.ChainID), utils.LogAttr("APIInterface", rpccs.listenEndpoint.ApiInterface), ) } + collectionData := apiCollection.CollectionData path := parsing.ApiName data := []byte(parsing.FunctionTemplate) @@ -275,42 +295,70 @@ func (rpccs *RPCConsumerServer) SendRelay( analytics *metrics.RelayMetrics, metadata []pairingtypes.Metadata, ) (relayResult *common.RelayResult, errRet error) { + chainMessage, directiveHeaders, relayRequestData, err := rpccs.ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) + if err != nil { + return nil, err + } + + return 
rpccs.SendParsedRelay(ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData) +} + +func (rpccs *RPCConsumerServer) ParseRelay( + ctx context.Context, + url string, + req string, + connectionType string, + dappID string, + consumerIp string, + analytics *metrics.RelayMetrics, + metadata []pairingtypes.Metadata, +) (chainMessage chainlib.ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData, err error) { // gets the relay request data from the ChainListener // parses the request into an APIMessage, and validating it corresponds to the spec currently in use // construct the common data for a relay message, common data is identical across multiple sends and data reliability - // sends a relay message to a provider - // compares the result with other providers if defined so - // compares the response with other consumer wallets if defined so - // asynchronously sends data reliability if necessary // remove lava directive headers - metadata, directiveHeaders := rpccs.LavaDirectiveHeaders(metadata) - relaySentTime := time.Now() - chainMessage, err := rpccs.chainParser.ParseMsg(url, []byte(req), connectionType, metadata, rpccs.getExtensionsFromDirectiveHeaders(directiveHeaders)) + metadata, directiveHeaders = rpccs.LavaDirectiveHeaders(metadata) + chainMessage, err = rpccs.chainParser.ParseMsg(url, []byte(req), connectionType, metadata, rpccs.getExtensionsFromDirectiveHeaders(directiveHeaders)) if err != nil { - return nil, err - } - // temporarily disable subscriptions - isSubscription := chainlib.IsSubscription(chainMessage) - if isSubscription { - return &common.RelayResult{ProviderInfo: common.ProviderInfo{ProviderAddress: ""}}, utils.LavaFormatError("Subscriptions are not supported at the moment", nil) + return nil, nil, nil, err } rpccs.HandleDirectiveHeadersForMessage(chainMessage, directiveHeaders) + // do this in a loop with retry attempts, configurable via a flag, limited by the number of providers in CSM reqBlock, _ := chainMessage.RequestedBlock() seenBlock, _ := rpccs.consumerConsistency.GetSeenBlock(dappID, consumerIp) if seenBlock < 0 { seenBlock = 0 } - relayRequestData := lavaprotocol.NewRelayData(ctx, connectionType, url, []byte(req), seenBlock, reqBlock, rpccs.listenEndpoint.ApiInterface, chainMessage.GetRPCMessage().GetHeaders(), chainlib.GetAddon(chainMessage), common.GetExtensionNames(chainMessage.GetExtensions())) + relayRequestData = lavaprotocol.NewRelayData(ctx, connectionType, url, []byte(req), seenBlock, reqBlock, rpccs.listenEndpoint.ApiInterface, chainMessage.GetRPCMessage().GetHeaders(), chainlib.GetAddon(chainMessage), common.GetExtensionNames(chainMessage.GetExtensions())) + return chainMessage, directiveHeaders, relayRequestData, nil +} + +func (rpccs *RPCConsumerServer) SendParsedRelay( + ctx context.Context, + dappID string, + consumerIp string, + analytics *metrics.RelayMetrics, + chainMessage chainlib.ChainMessage, + directiveHeaders map[string]string, + relayRequestData *pairingtypes.RelayPrivateData, +) (relayResult *common.RelayResult, errRet error) { + // sends a relay message to a provider + // compares the result with other providers if defined so + // compares the response with other consumer wallets if defined so + // asynchronously sends data reliability if necessary + + relaySentTime := time.Now() relayProcessor, err := rpccs.ProcessRelaySend(ctx, directiveHeaders, chainMessage, relayRequestData, dappID, consumerIp, analytics) if err != nil && !relayProcessor.HasResults() { // 
we can't send anymore, and we don't have any responses utils.LavaFormatError("failed getting responses from providers", err, utils.Attribute{Key: "GUID", Value: ctx}, utils.LogAttr("endpoint", rpccs.listenEndpoint.Key()), utils.LogAttr("userIp", consumerIp), utils.LogAttr("relayProcessor", relayProcessor)) return nil, err } + // Handle Data Reliability enabled, dataReliabilityThreshold := rpccs.chainParser.DataReliabilityParams() // check if data reliability is enabled and relay processor allows us to perform data reliability @@ -330,6 +378,7 @@ func (rpccs *RPCConsumerServer) SendRelay( if err != nil { return returnedResult, utils.LavaFormatError("failed processing responses from providers", err, utils.Attribute{Key: "GUID", Value: ctx}, utils.LogAttr("endpoint", rpccs.listenEndpoint.Key())) } + if analytics != nil { currentLatency := time.Since(relaySentTime) analytics.Latency = currentLatency.Milliseconds() @@ -471,6 +520,24 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH } } +func (rpccs *RPCConsumerServer) CreateDappKey(dappID, consumerIp string) string { + return rpccs.consumerConsistency.Key(dappID, consumerIp) +} + +func (rpccs *RPCConsumerServer) CancelSubscriptionContext(subscriptionKey string) { + rpccs.connectedSubscriptionsLock.Lock() + defer rpccs.connectedSubscriptionsLock.Unlock() + + ctxHolder, ok := rpccs.connectedSubscriptionsContexts[subscriptionKey] + if ok { + utils.LavaFormatTrace("cancelling subscription context", utils.LogAttr("subscriptionID", subscriptionKey)) + ctxHolder.CancelFunc() + delete(rpccs.connectedSubscriptionsContexts, subscriptionKey) + } else { + utils.LavaFormatWarning("tried to cancel context for subscription ID that does not exist", nil, utils.LogAttr("subscriptionID", subscriptionKey)) + } +} + func (rpccs *RPCConsumerServer) sendRelayToProvider( ctx context.Context, chainMessage chainlib.ChainMessage, @@ -491,7 +558,6 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( // if necessary send detection tx for hashes consensus mismatch // handle QoS updates // in case connection totally fails, update unresponsive providers in ConsumerSessionManager - isSubscription := chainlib.IsSubscription(chainMessage) var sharedStateId string // defaults to "", if shared state is disabled then no shared state will be used. 
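A note on the key used here: sharedStateId reuses the same dappId-and-ip key that ConsumerConsistency builds, and the new SetSeenBlockFromKey path relies on exactly that precomputed key during subscriptions, where dappId and ip are no longer kept separately. A minimal sketch of this keyed seen-block bookkeeping, using a plain in-memory map as a stand-in for the real consistency backend:

package example

import "sync"

// consistency tracks the highest block each dapp+ip pair has seen, keyed by
// the same combined "dappId__ip" string that is also used for shared state.
type consistency struct {
	mu     sync.Mutex
	latest map[string]int64
}

func newConsistency() *consistency {
	return &consistency{latest: make(map[string]int64)}
}

func key(dappID, ip string) string { return dappID + "__" + ip }

// setSeenBlockFromKey mirrors the new subscription path: callers that only
// kept the combined key can still bump the seen block monotonically.
func (c *consistency) setSeenBlockFromKey(blockSeen int64, k string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.latest[k] < blockSeen {
		c.latest[k] = blockSeen
	}
}

// setSeenBlock is the original entry point, now a thin wrapper over the keyed one.
func (c *consistency) setSeenBlock(blockSeen int64, dappID, ip string) {
	c.setSeenBlockFromKey(blockSeen, key(dappID, ip))
}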
if rpccs.sharedState { sharedStateId = rpccs.consumerConsistency.Key(dappID, consumerIp) // use same key as we use for consistency, (for better consistency :-D) @@ -665,12 +731,39 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( // set relay sent metric go rpccs.rpcConsumerLogs.SetRelaySentToProviderMetric(chainId, apiInterface) - if isSubscription { - errResponse = rpccs.relaySubscriptionInner(goroutineCtx, endpointClient, singleConsumerSession, localRelayResult) - if errResponse != nil { - utils.LavaFormatError("Failed relaySubscriptionInner", errResponse, utils.LogAttr("Request data", localRelayRequestData)) + if chainlib.IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { + utils.LavaFormatTrace("inside sendRelayToProvider, relay is subscription", utils.LogAttr("requestData", localRelayRequestData.Data)) + + params, err := json.Marshal(chainMessage.GetRPCMessage().GetParams()) + if err != nil { + utils.LavaFormatError("could not marshal params", err) return } + + hashedParams := rpcclient.CreateHashFromParams(params) + cancellableCtx, cancelFunc := context.WithCancel(utils.WithUniqueIdentifier(context.Background(), utils.GenerateUniqueIdentifier())) + + ctxHolder := func() *CancelableContextHolder { + rpccs.connectedSubscriptionsLock.Lock() + defer rpccs.connectedSubscriptionsLock.Unlock() + + ctxHolder := &CancelableContextHolder{ + Ctx: cancellableCtx, + CancelFunc: cancelFunc, + } + rpccs.connectedSubscriptionsContexts[hashedParams] = ctxHolder + return ctxHolder + }() + + errResponse = rpccs.relaySubscriptionInner(ctxHolder.Ctx, hashedParams, endpointClient, singleConsumerSession, localRelayResult) + if errResponse != nil { + utils.LavaFormatError("Failed relaySubscriptionInner", errResponse, + utils.LogAttr("Request", localRelayRequestData), + utils.LogAttr("Request data", string(localRelayRequestData.Data)), + ) + } + + return } // unique per dappId and ip @@ -805,7 +898,7 @@ func (rpccs *RPCConsumerServer) relayInner(ctx context.Context, singleConsumerSe callRelay := func() (reply *pairingtypes.RelayReply, relayLatency time.Duration, err error, backoff bool) { relaySentTime := time.Now() connectCtx, connectCtxCancel := context.WithTimeout(ctx, relayTimeout) - metadataAdd := metadata.New(map[string]string{common.IP_FORWARDING_HEADER_NAME: consumerToken}) + metadataAdd := metadata.New(map[string]string{common.IP_FORWARDING_HEADER_NAME: consumerToken, common.LAVA_CONSUMER_PROCESS_GUID: rpccs.consumerProcessGuid}) connectCtx = metadata.NewOutgoingContext(connectCtx, metadataAdd) defer connectCtxCancel() @@ -893,7 +986,7 @@ func (rpccs *RPCConsumerServer) relayInner(ctx context.Context, singleConsumerSe // TODO: DETECTION instead of existingSessionLatestBlock, we need proof of last reply to send the previous reply and the current reply finalizedBlocks, finalizationConflict, err := lavaprotocol.VerifyFinalizationData(reply, relayRequest, providerPublicAddress, rpccs.ConsumerAddress, existingSessionLatestBlock, blockDistanceForFinalizedData) if err != nil { - if sdkerrors.IsOf(err, lavaprotocol.ProviderFinzalizationDataAccountabilityError) && finalizationConflict != nil { + if sdkerrors.IsOf(err, lavaprotocol.ProviderFinalizationDataAccountabilityError) && finalizationConflict != nil { go rpccs.consumerTxSender.TxConflictDetection(ctx, finalizationConflict, nil, nil, singleConsumerSession.Parent) } return 0, err, false @@ -909,25 +1002,118 @@ func (rpccs *RPCConsumerServer) relayInner(ctx context.Context, singleConsumerSe return relayLatency, nil, false } -func 
(rpccs *RPCConsumerServer) relaySubscriptionInner(ctx context.Context, endpointClient pairingtypes.RelayerClient, singleConsumerSession *lavasession.SingleConsumerSession, relayResult *common.RelayResult) (err error) { - // relaySentTime := time.Now() +func (rpccs *RPCConsumerServer) relaySubscriptionInner(ctx context.Context, hashedParams string, endpointClient pairingtypes.RelayerClient, singleConsumerSession *lavasession.SingleConsumerSession, relayResult *common.RelayResult) (err error) { + // add consumer guid to relay request. + metadataAdd := metadata.Pairs(common.LAVA_CONSUMER_PROCESS_GUID, rpccs.consumerProcessGuid) + ctx = metadata.NewOutgoingContext(ctx, metadataAdd) + replyServer, err := endpointClient.RelaySubscribe(ctx, relayResult.Request) - // relayLatency := time.Since(relaySentTime) // TODO: use subscription QoS if err != nil { errReport := rpccs.consumerSessionManager.OnSessionFailure(singleConsumerSession, err) if errReport != nil { - return utils.LavaFormatError("subscribe relay failed onSessionFailure errored", errReport, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "original error", Value: err.Error()}) + return utils.LavaFormatError("subscribe relay failed onSessionFailure errored", errReport, + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("originalError", err.Error()), + ) } + return err } - // TODO: need to check that if provider fails and returns error, this is reflected here and we run onSessionDone - // my thoughts are that this fails if the grpc fails not if the provider fails, and if the provider returns an error this is reflected by the Recv function on the chainListener calling us here - // and this is too late - relayResult.ReplyServer = &replyServer - err = rpccs.consumerSessionManager.OnSessionDoneIncreaseCUOnly(singleConsumerSession) + + reply, err := rpccs.getFirstSubscriptionReply(ctx, hashedParams, replyServer) + if err != nil { + errReport := rpccs.consumerSessionManager.OnSessionFailure(singleConsumerSession, err) + if errReport != nil { + return utils.LavaFormatError("subscribe relay failed onSessionFailure errored", errReport, + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("originalError", err.Error()), + ) + } + return err + } + + utils.LavaFormatTrace("subscribe relay succeeded", + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + + relayResult.ReplyServer = replyServer + relayResult.Reply = reply + latestBlock := relayResult.Reply.LatestBlock + err = rpccs.consumerSessionManager.OnSessionDoneIncreaseCUOnly(singleConsumerSession, latestBlock) return err } +func (rpccs *RPCConsumerServer) getFirstSubscriptionReply(ctx context.Context, hashedParams string, replyServer pairingtypes.Relayer_RelaySubscribeClient) (*pairingtypes.RelayReply, error) { + var reply pairingtypes.RelayReply + gotFirstReplyChanOrErr := make(chan struct{}) + + // Cancel the context after SubscriptionFirstReplyTimeout duration, so we won't hang forever + go func() { + for { + select { + case <-time.After(common.SubscriptionFirstReplyTimeout): + if reply.Data == nil { + utils.LavaFormatError("Timeout exceeded when waiting for first reply message from subscription, cancelling the context with the provider", nil, + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + rpccs.CancelSubscriptionContext(hashedParams) // Cancel the context with 
the provider, which will trigger the replyServer's context to be cancelled + } + case <-gotFirstReplyChanOrErr: + return + } + } + }() + + select { + case <-replyServer.Context().Done(): // Make sure the reply server is open + return nil, utils.LavaFormatError("reply server context canceled before first time read", nil, + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + default: + err := replyServer.RecvMsg(&reply) + gotFirstReplyChanOrErr <- struct{}{} + if err != nil { + return nil, utils.LavaFormatError("Could not read reply from reply server", err, + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + ) + } + } + + utils.LavaFormatTrace("successfully got first reply", + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("reply", string(reply.Data)), + ) + + // Make sure we can parse the reply + var replyJson rpcclient.JsonrpcMessage + err := json.Unmarshal(reply.Data, &replyJson) + if err != nil { + return nil, utils.LavaFormatError("could not parse reply into json", err, + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("reply", reply.Data), + ) + } + + if replyJson.Error != nil { + // Node error, subscription was not initialized, triggering OnSessionFailure + return nil, utils.LavaFormatError("error in reply from subscription", nil, + utils.LogAttr("GUID", ctx), + utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), + utils.LogAttr("reply", replyJson), + ) + } + + return &reply, nil +} + func (rpccs *RPCConsumerServer) sendDataReliabilityRelayIfApplicable(ctx context.Context, dappID string, consumerIp string, chainMessage chainlib.ChainMessage, dataReliabilityThreshold uint32, relayProcessor *RelayProcessor) error { processingTimeout, expectedRelayTimeout := rpccs.getProcessingTimeout(chainMessage) // Wait another relayTimeout duration to maybe get additional relay results diff --git a/protocol/rpcprovider/provider_listener.go b/protocol/rpcprovider/provider_listener.go index 41250fae8b..981ab30e8a 100644 --- a/protocol/rpcprovider/provider_listener.go +++ b/protocol/rpcprovider/provider_listener.go @@ -3,6 +3,7 @@ package rpcprovider import ( "context" "errors" + "fmt" "net/http" "strings" "sync" @@ -10,6 +11,7 @@ import ( "github.com/gogo/status" "github.com/improbable-eng/grpc-web/go/grpcweb" "github.com/lavanet/lava/v2/protocol/chainlib" + "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/protocol/lavaprotocol" "github.com/lavanet/lava/v2/protocol/lavasession" "github.com/lavanet/lava/v2/utils" @@ -70,7 +72,7 @@ func NewProviderListener(ctx context.Context, networkAddress lavasession.Network handler := func(resp http.ResponseWriter, req *http.Request) { // Set CORS headers resp.Header().Set("Access-Control-Allow-Origin", "*") - resp.Header().Set("Access-Control-Allow-Headers", "Content-Type, x-grpc-web, lava-sdk-relay-timeout") + resp.Header().Set("Access-Control-Allow-Headers", fmt.Sprintf("Content-Type, x-grpc-web, lava-sdk-relay-timeout, %s", common.LAVA_CONSUMER_PROCESS_GUID)) if req.URL.Path == healthCheckPath && req.Method == http.MethodGet { resp.WriteHeader(http.StatusOK) diff --git a/protocol/rpcprovider/reliabilitymanager/reliability_manager_test.go b/protocol/rpcprovider/reliabilitymanager/reliability_manager_test.go index 3d6059919b..225f03e3b6 100644 --- 
a/protocol/rpcprovider/reliabilitymanager/reliability_manager_test.go +++ b/protocol/rpcprovider/reliabilitymanager/reliability_manager_test.go @@ -185,7 +185,7 @@ func TestFullFlowReliabilityConflict(t *testing.T) { w.WriteHeader(http.StatusOK) fmt.Fprint(w, string(replyDataBuf)) }) - chainParser, chainProxy, chainFetcher, closeServer, _, err := chainlib.CreateChainLibMocks(ts.Ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../../", nil) + chainParser, chainProxy, chainFetcher, closeServer, _, err := chainlib.CreateChainLibMocks(ts.Ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../../", nil) if closeServer != nil { defer closeServer() } diff --git a/protocol/rpcprovider/rpcprovider.go b/protocol/rpcprovider/rpcprovider.go index f3404c4bab..19e81ffab0 100644 --- a/protocol/rpcprovider/rpcprovider.go +++ b/protocol/rpcprovider/rpcprovider.go @@ -35,6 +35,7 @@ import ( epochstorage "github.com/lavanet/lava/v2/x/epochstorage/types" pairingtypes "github.com/lavanet/lava/v2/x/pairing/types" protocoltypes "github.com/lavanet/lava/v2/x/protocol/types" + spectypes "github.com/lavanet/lava/v2/x/spec/types" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -215,6 +216,7 @@ func (rpcp *RPCProvider) Start(options *rpcProviderStartOptions) (err error) { } specValidator := NewSpecValidator() + utils.LavaFormatTrace("Running setup for RPCProvider endpoints", utils.LogAttr("endpoints", options.rpcProviderEndpoints)) disabledEndpointsList := rpcp.SetupProviderEndpoints(options.rpcProviderEndpoints, specValidator, true) rpcp.relaysMonitorAggregator.StartMonitoring(ctx) specValidator.Start(ctx) @@ -456,7 +458,14 @@ func (rpcp *RPCProvider) SetupEndpoint(ctx context.Context, rpcProviderEndpoint } rpcProviderServer := &RPCProviderServer{} - rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rpcp.rewardServer, providerSessionManager, reliabilityManager, rpcp.privKey, rpcp.cache, chainRouter, rpcp.providerStateTracker, rpcp.addr, rpcp.lavaChainID, DEFAULT_ALLOWED_MISSING_CU, providerMetrics, relaysMonitor) + + var providerNodeSubscriptionManager *chainlib.ProviderNodeSubscriptionManager + if rpcProviderEndpoint.ApiInterface == spectypes.APIInterfaceTendermintRPC || rpcProviderEndpoint.ApiInterface == spectypes.APIInterfaceJsonRPC { + utils.LavaFormatTrace("Creating provider node subscription manager", utils.LogAttr("rpcProviderEndpoint", rpcProviderEndpoint)) + providerNodeSubscriptionManager = chainlib.NewProviderNodeSubscriptionManager(chainRouter, chainParser, rpcProviderServer, rpcp.privKey) + } + + rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rpcp.rewardServer, providerSessionManager, reliabilityManager, rpcp.privKey, rpcp.cache, chainRouter, rpcp.providerStateTracker, rpcp.addr, rpcp.lavaChainID, DEFAULT_ALLOWED_MISSING_CU, providerMetrics, relaysMonitor, providerNodeSubscriptionManager) // set up grpc listener var listener *ProviderListener func() { @@ -471,13 +480,16 @@ func (rpcp *RPCProvider) SetupEndpoint(ctx context.Context, rpcProviderEndpoint rpcp.rpcProviderListeners[rpcProviderEndpoint.NetworkAddress.Address] = listener } }() + if listener == nil { utils.LavaFormatFatal("listener not defined, cant register RPCProviderServer", nil, utils.Attribute{Key: "RPCProviderEndpoint", Value: rpcProviderEndpoint.String()}) } + err = listener.RegisterReceiver(rpcProviderServer, rpcProviderEndpoint) if err != nil { utils.LavaFormatError("error in register receiver", err) } + utils.LavaFormatDebug("provider finished setting 
up endpoint", utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint.Key()}) // prevents these objects form being overrun later chainParser.Activate() @@ -729,6 +741,7 @@ rpcprovider 127.0.0.1:3333 OSMOSIS tendermintrpc "wss://www.node-path.com:80,htt cmdRPCProvider.Flags().Duration(common.RelayHealthIntervalFlag, RelayHealthIntervalFlagDefault, "interval between relay health checks") cmdRPCProvider.Flags().String(HealthCheckURLPathFlagName, HealthCheckURLPathFlagDefault, "the url path for the provider's grpc health check") cmdRPCProvider.Flags().DurationVar(&updaters.TimeOutForFetchingLavaBlocks, common.TimeOutForFetchingLavaBlocksFlag, time.Second*5, "setting the timeout for fetching lava blocks") + cmdRPCProvider.Flags().BoolVar(&chainlib.IgnoreSubscriptionNotConfiguredError, chainlib.IgnoreSubscriptionNotConfiguredErrorFlag, chainlib.IgnoreSubscriptionNotConfiguredError, "ignore webSocket node url not configured error, when subscription is enabled in spec") common.AddRollingLogConfig(cmdRPCProvider) return cmdRPCProvider diff --git a/protocol/rpcprovider/rpcprovider_server.go b/protocol/rpcprovider/rpcprovider_server.go index feaea258f9..3773c83186 100644 --- a/protocol/rpcprovider/rpcprovider_server.go +++ b/protocol/rpcprovider/rpcprovider_server.go @@ -3,8 +3,10 @@ package rpcprovider import ( "bytes" "context" + "errors" "strconv" "strings" + "sync" "time" "github.com/goccy/go-json" @@ -14,8 +16,6 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/gogo/status" "github.com/lavanet/lava/v2/protocol/chainlib" - "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" - "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" "github.com/lavanet/lava/v2/protocol/chaintracker" "github.com/lavanet/lava/v2/protocol/common" @@ -48,20 +48,21 @@ const ( ) type RPCProviderServer struct { - cache *performance.Cache - chainRouter chainlib.ChainRouter - privKey *btcec.PrivateKey - reliabilityManager ReliabilityManagerInf - providerSessionManager *lavasession.ProviderSessionManager - rewardServer RewardServerInf - chainParser chainlib.ChainParser - rpcProviderEndpoint *lavasession.RPCProviderEndpoint - stateTracker StateTrackerInf - providerAddress sdk.AccAddress - lavaChainID string - allowedMissingCUThreshold float64 - metrics *metrics.ProviderMetrics - relaysMonitor *metrics.RelaysMonitor + cache *performance.Cache + chainRouter chainlib.ChainRouter + privKey *btcec.PrivateKey + reliabilityManager ReliabilityManagerInf + providerSessionManager *lavasession.ProviderSessionManager + rewardServer RewardServerInf + chainParser chainlib.ChainParser + rpcProviderEndpoint *lavasession.RPCProviderEndpoint + stateTracker StateTrackerInf + providerAddress sdk.AccAddress + lavaChainID string + allowedMissingCUThreshold float64 + metrics *metrics.ProviderMetrics + relaysMonitor *metrics.RelaysMonitor + providerNodeSubscriptionManager *chainlib.ProviderNodeSubscriptionManager } type ReliabilityManagerInf interface { @@ -97,6 +98,7 @@ func (rpcps *RPCProviderServer) ServeRPCRequests( allowedMissingCUThreshold float64, providerMetrics *metrics.ProviderMetrics, relaysMonitor *metrics.RelaysMonitor, + providerNodeSubscriptionManager *chainlib.ProviderNodeSubscriptionManager, ) { rpcps.cache = cache rpcps.chainRouter = chainRouter @@ -112,6 +114,7 @@ func (rpcps *RPCProviderServer) ServeRPCRequests( rpcps.allowedMissingCUThreshold = allowedMissingCUThreshold rpcps.metrics = providerMetrics 
rpcps.relaysMonitor = relaysMonitor + rpcps.providerNodeSubscriptionManager = providerNodeSubscriptionManager rpcps.initRelaysMonitor(ctx) } @@ -135,13 +138,14 @@ func (rpcps *RPCProviderServer) initRelaysMonitor(ctx context.Context) { } func (rpcps *RPCProviderServer) craftChainMessage() (chainMessage chainlib.ChainMessage, err error) { - parsing, collectionData, ok := rpcps.chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) + parsing, apiCollection, ok := rpcps.chainParser.GetParsingByTag(spectypes.FUNCTION_TAG_GET_BLOCKNUM) if !ok { return nil, utils.LavaFormatWarning("did not send initial relays because the spec does not contain "+spectypes.FUNCTION_TAG_GET_BLOCKNUM.String(), nil, utils.LogAttr("chainID", rpcps.rpcProviderEndpoint.ChainID), utils.LogAttr("APIInterface", rpcps.rpcProviderEndpoint.ApiInterface), ) } + collectionData := apiCollection.CollectionData path := parsing.ApiName data := []byte(parsing.FunctionTemplate) @@ -188,11 +192,26 @@ func (rpcps *RPCProviderServer) Relay(ctx context.Context, request *pairingtypes // Init relay relaySession, consumerAddress, chainMessage, err := rpcps.initRelay(ctx, request) if err != nil { + utils.LavaFormatDebug("got error from init relay", utils.LogAttr("error", err)) return nil, rpcps.handleRelayErrorStatus(err) } - // Try sending relay - reply, err := rpcps.TryRelay(ctx, request, consumerAddress, chainMessage) + // Check that these are not subscription-related messages + if chainlib.IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { + return nil, errors.New("subscribe method is not supported through Relay") + } + + if chainlib.IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE_ALL) { + return nil, errors.New("unsubscribe_all method is not supported through Relay") + } + + var reply *pairingtypes.RelayReply + if chainlib.IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE) { + reply, err = rpcps.TryRelayUnsubscribe(ctx, request, consumerAddress, chainMessage) + } else { + // Try sending relay + reply, err = rpcps.TryRelay(ctx, request, consumerAddress, chainMessage) + } if err != nil || common.ContextOutOfTime(ctx) { // failed to send relay. we need to adjust session state. cuSum and relayNumber.
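The Relay entry point above now routes on function tags before attempting the relay: subscribe and unsubscribe_all are rejected on the unary path (subscriptions must go through the streaming RelaySubscribe), while unsubscribe gets its own handler. A condensed sketch of that routing, with stand-in types replacing pairingtypes.RelayRequest and chainlib.ChainMessage:

package example

import (
	"context"
	"errors"
)

// functionTag is a stand-in for the spec's FUNCTION_TAG_* values.
type functionTag int

const (
	tagNone functionTag = iota
	tagSubscribe
	tagUnsubscribe
	tagUnsubscribeAll
)

type reply struct{ data []byte }

// routeRelay mirrors the branching added to Relay(): reject subscription
// traffic on the unary RPC, route unsubscribe to its dedicated handler,
// and fall through to the normal relay path otherwise.
func routeRelay(ctx context.Context, tag functionTag, tryRelay, tryUnsubscribe func(context.Context) (*reply, error)) (*reply, error) {
	switch tag {
	case tagSubscribe:
		return nil, errors.New("subscribe method is not supported through Relay")
	case tagUnsubscribeAll:
		return nil, errors.New("unsubscribe_all method is not supported through Relay")
	case tagUnsubscribe:
		return tryUnsubscribe(ctx)
	default:
		return tryRelay(ctx)
	}
}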
@@ -333,38 +352,38 @@ func (rpcps *RPCProviderServer) RelaySubscribe(request *pairingtypes.RelayReques if request.RelayData == nil || request.RelaySession == nil { return utils.LavaFormatError("invalid relay subscribe request, internal fields are nil", nil) } + ctx := utils.AppendUniqueIdentifier(context.Background(), lavaprotocol.GetSalt(request.RelayData)) utils.LavaFormatDebug("Provider got relay subscribe request", - utils.Attribute{Key: "request.SessionId", Value: request.RelaySession.SessionId}, - utils.Attribute{Key: "request.relayNumber", Value: request.RelaySession.RelayNum}, - utils.Attribute{Key: "request.cu", Value: request.RelaySession.CuSum}, - utils.Attribute{Key: "GUID", Value: ctx}, + utils.LogAttr("request.SessionId", request.RelaySession.SessionId), + utils.LogAttr("request.relayNumber", request.RelaySession.RelayNum), + utils.LogAttr("request.cu", request.RelaySession.CuSum), + utils.LogAttr("GUID", ctx), ) + relaySession, consumerAddress, chainMessage, err := rpcps.initRelay(ctx, request) if err != nil { + utils.LavaFormatDebug("got error from init relay", utils.LogAttr("error", err)) return rpcps.handleRelayErrorStatus(err) } - subscribed, err := rpcps.TryRelaySubscribe(ctx, uint64(request.RelaySession.Epoch), srv, chainMessage, consumerAddress, relaySession, request.RelaySession.RelayNum) // this function does not return until subscription ends + + // TryRelaySubscribe is blocking until subscription ends + subscribed, err := rpcps.TryRelaySubscribe(ctx, uint64(request.RelaySession.Epoch), request, srv, chainMessage, consumerAddress, relaySession, request.RelaySession.RelayNum) if subscribed { - // meaning we created a subscription and used it for at least a message - pairingEpoch := relaySession.PairingEpoch - // no need to perform on session done as we did it in try relay subscribe - go rpcps.SendProof(ctx, pairingEpoch, request, consumerAddress, chainMessage.GetApiCollection().CollectionData.ApiInterface) utils.LavaFormatDebug("Provider Finished Relay Successfully", - utils.Attribute{Key: "request.SessionId", Value: request.RelaySession.SessionId}, - utils.Attribute{Key: "request.relayNumber", Value: request.RelaySession.RelayNum}, - utils.Attribute{Key: "GUID", Value: ctx}, + utils.LogAttr("request.SessionId", request.RelaySession.SessionId), + utils.LogAttr("request.relayNumber", request.RelaySession.RelayNum), + utils.LogAttr("GUID", ctx), ) err = nil // we don't want to return an error here } else { // we didn't even manage to subscribe - relayFailureError := rpcps.providerSessionManager.OnSessionFailure(relaySession, request.RelaySession.RelayNum) - if relayFailureError != nil { - err = utils.LavaFormatError("failed subscribing", lavasession.SubscriptionInitiationError, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "onSessionFailureError", Value: relayFailureError.Error()}, utils.Attribute{Key: "error", Value: err}) - } else { - err = utils.LavaFormatError("failed subscribing", lavasession.SubscriptionInitiationError, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "error", Value: err}) - } + err = utils.LavaFormatError("failed subscribing", lavasession.SubscriptionInitiationError, + utils.LogAttr("GUID", ctx), + utils.LogAttr("error", err), + ) } + return rpcps.handleRelayErrorStatus(err) } @@ -378,91 +397,101 @@ func (rpcps *RPCProviderServer) SendProof(ctx context.Context, epoch uint64, req return nil } -func (rpcps *RPCProviderServer) TryRelaySubscribe(ctx context.Context, requestBlockHeight uint64, srv 
pairingtypes.Relayer_RelaySubscribeServer, chainMessage chainlib.ChainMessage, consumerAddress sdk.AccAddress, relaySession *lavasession.SingleProviderSession, relayNumber uint64) (subscribed bool, errRet error) { - var reply *pairingtypes.RelayReply - var clientSub *rpcclient.ClientSubscription - var subscriptionID string - subscribeRepliesChan := make(chan interface{}) - replyWrapper, subscriptionID, clientSub, _, _, err := rpcps.chainRouter.SendNodeMsg(ctx, subscribeRepliesChan, chainMessage, nil) - if err != nil { - return false, utils.LavaFormatError("Subscription failed", err, utils.Attribute{Key: "GUID", Value: ctx}) - } - if replyWrapper == nil || replyWrapper.RelayReply == nil { - return false, utils.LavaFormatError("Subscription failed, relayWrapper or RelayReply are nil", nil, utils.Attribute{Key: "GUID", Value: ctx}) - } - reply = replyWrapper.RelayReply - reply.Metadata, _, _ = rpcps.chainParser.HandleHeaders(reply.Metadata, chainMessage.GetApiCollection(), spectypes.Header_pass_reply) - if clientSub == nil { - // failed subscription, but not an error. (probably a node error) - // return the response to the user, and close the session. - relayError := rpcps.providerSessionManager.OnSessionDone(relaySession, relayNumber) // subscription failed due to node error mark session as done and return - if relayError != nil { - utils.LavaFormatError("Error OnSessionDone", relayError) - } - err = srv.Send(reply) // this reply contains the error to the subscription - if err != nil { - utils.LavaFormatError("Error returning response", err) - } - return true, nil // we already returned the error to the user so no need to return another error. +func (rpcps *RPCProviderServer) TryRelaySubscribe(ctx context.Context, requestBlockHeight uint64, request *pairingtypes.RelayRequest, srv pairingtypes.Relayer_RelaySubscribeServer, chainMessage chainlib.ChainMessage, consumerAddress sdk.AccAddress, relaySession *lavasession.SingleProviderSession, relayNumber uint64) (subscribedSuccessfully bool, errRet error) { + subscribeRepliesChan := make(chan *pairingtypes.RelayReply) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + consumerProcessGuid, found := rpcps.fetchConsumerProcessGuidFromContext(srv.Context()) + if !found { + return false, utils.LavaFormatWarning("Could not find consumer process GUID in context, which is required for subscription relays", nil) } - // if we got a node error clientSub will be nil and also err will be nil. in that case we need to check for clientSub - subscription := &lavasession.RPCSubscription{ - Id: subscriptionID, - Sub: clientSub, - SubscribeRepliesChan: subscribeRepliesChan, - } - err = rpcps.providerSessionManager.ReleaseSessionAndCreateSubscription(relaySession, subscription, consumerAddress.String(), requestBlockHeight, relayNumber) - if err != nil { - return false, err - } - rpcps.rewardServer.SubscribeStarted(consumerAddress.String(), requestBlockHeight, subscriptionID) - processSubscribeMessages := func() (subscribed bool, errRet error) { - err = srv.Send(reply) // this reply contains the RPC ID - if err != nil { - utils.LavaFormatError("Error getting RPC ID", err, utils.Attribute{Key: "GUID", Value: ctx}) - } else { - subscribed = true - } + // The reason we have a wait group here and pass it to the goroutine is that we want to start reading from the channel before calling AddConsumer, + // because otherwise AddConsumer might get stuck writing to the channel, which would create a deadlock.
+ // But we still want to wait for the goroutine to finish before we return (because the gRPC stream will close on return), so we use the wait group to wait for it. + wg := sync.WaitGroup{} + wg.Add(1) + + // Process subscription messages + go func() { + defer wg.Done() for { select { - case <-clientSub.Err(): - utils.LavaFormatError("client sub", err, utils.Attribute{Key: "GUID", Value: ctx}) - // delete this connection from the subs map + case <-ctx.Done(): + case <-srv.Context().Done(): + utils.LavaFormatTrace("ctx or relay server context closed", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddress), + ) - return subscribed, err - case subscribeReply := <-subscribeRepliesChan: - data, err := json.Marshal(subscribeReply) + err := rpcps.providerNodeSubscriptionManager.RemoveConsumer(ctx, chainMessage, consumerAddress, true, consumerProcessGuid) if err != nil { - return subscribed, utils.LavaFormatError("client sub unmarshal", err, utils.Attribute{Key: "GUID", Value: ctx}) + errRet = utils.LavaFormatError("Error RemoveConsumer", err, utils.LogAttr("GUID", ctx)) + } + return + case subscribeReply, ok := <-subscribeRepliesChan: + if !ok { // channel is closed + errRet = utils.LavaFormatTrace("subscribeRepliesChan closed", + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddress), + ) + err := rpcps.providerNodeSubscriptionManager.RemoveConsumer(ctx, chainMessage, consumerAddress, false, consumerProcessGuid) // false because the channel is already closed + if err != nil { + errRet = utils.LavaFormatError("Error RemoveConsumer", err, utils.LogAttr("GUID", ctx)) + } + return } - err = srv.Send( - &pairingtypes.RelayReply{ - Data: data, - }, - ) - if err != nil { + errRet = srv.Send(subscribeReply) + + if errRet != nil { // usually triggered when client closes connection - if strings.Contains(err.Error(), "Canceled desc = context canceled") { - err = utils.LavaFormatWarning("Client closed connection", err, utils.Attribute{Key: "GUID", Value: ctx}) + if strings.Contains(errRet.Error(), "Canceled desc = context canceled") { + errRet = utils.LavaFormatWarning("Client closed connection", errRet, utils.Attribute{Key: "GUID", Value: ctx}) } else { - err = utils.LavaFormatError("srv.Send", err, utils.Attribute{Key: "GUID", Value: ctx}) + errRet = utils.LavaFormatError("Got error from srv.Send()", errRet, utils.Attribute{Key: "GUID", Value: ctx}) } - return subscribed, err - } else { - subscribed = true + + return } - utils.LavaFormatDebug("Sending data", utils.Attribute{Key: "data", Value: string(data)}, utils.Attribute{Key: "GUID", Value: ctx}) + subscribedSuccessfully = true + utils.LavaFormatTrace("Sending data to consumer", + utils.LogAttr("GUID", ctx), + utils.LogAttr("data", subscribeReply.Data), + utils.LogAttr("consumerAddr", consumerAddress), + ) } } + }() + + subscriptionId, err := rpcps.providerNodeSubscriptionManager.AddConsumer(ctx, request, chainMessage, consumerAddress, subscribeRepliesChan, consumerProcessGuid) + if err != nil { + // Subscription failed due to a node error; mark the session as failed and return + relayError := rpcps.providerSessionManager.OnSessionFailure(relaySession, relayNumber) + if relayError != nil { + utils.LavaFormatError("Error OnSessionFailure", relayError) + } + + return false, utils.LavaFormatWarning("RPCProviderServer: Subscription failed", err, + utils.LogAttr("GUID", ctx), + utils.LogAttr("consumerAddr", consumerAddress), + ) } - subscribed, errRet = processSubscribeMessages() -
rpcps.providerSessionManager.SubscriptionEnded(consumerAddress.String(), requestBlockHeight, subscriptionID) - rpcps.rewardServer.SubscribeEnded(consumerAddress.String(), requestBlockHeight, subscriptionID) - return subscribed, errRet + + relayError := rpcps.providerSessionManager.OnSessionDone(relaySession, relayNumber) + if relayError != nil { + utils.LavaFormatError("Error OnSessionDone", relayError) + } + + go rpcps.SendProof(ctx, relaySession.PairingEpoch, request, consumerAddress, chainMessage.GetApiCollection().CollectionData.ApiInterface) + + rpcps.rewardServer.SubscribeStarted(consumerAddress.String(), requestBlockHeight, subscriptionId) + wg.Wait() // Block until subscription is done + + rpcps.rewardServer.SubscribeEnded(consumerAddress.String(), requestBlockHeight, subscriptionId) + return subscribedSuccessfully, errRet } // verifies basic relay fields, and gets a provider session @@ -666,211 +695,321 @@ func (rpcps *RPCProviderServer) TryRelay(ctx context.Context, request *pairingty if errV != nil { return nil, errV } - // Send - var reqMsg *rpcInterfaceMessages.JsonrpcMessage - var reqParams interface{} - switch msg := chainMsg.GetRPCMessage().(type) { - case *rpcInterfaceMessages.JsonrpcMessage: - reqMsg = msg - reqParams = reqMsg.Params - default: - reqMsg = nil - } - var requestedBlockHash []byte = nil - finalized := false - dataReliabilityEnabled, _ := rpcps.chainParser.DataReliabilityParams() + var latestBlock int64 + var requestedBlockHash []byte var requestedHashes []*chaintracker.BlockStore var modifiedReqBlock int64 - var blocksInFinalizationData uint32 - var blockDistanceToFinalization uint32 - var averageBlockTime time.Duration + + finalized := false updatedChainMessage := false - var blockLagForQosSync int64 - blockLagForQosSync, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData = rpcps.chainParser.ChainBlockStats() + + dataReliabilityEnabled, _ := rpcps.chainParser.DataReliabilityParams() + blockLagForQosSync, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData := rpcps.chainParser.ChainBlockStats() relayTimeout := chainlib.GetRelayTimeout(chainMsg, averageBlockTime) + if dataReliabilityEnabled { var err error - specificBlock := request.RelayData.RequestBlock - if specificBlock < spectypes.LATEST_BLOCK { - // cases of EARLIEST, FINALIZED, SAFE - // GetLatestBlockData only supports latest relative queries or specific block numbers - specificBlock = spectypes.NOT_APPLICABLE - } - - // handle consistency, if the consumer requested information we do not have in the state tracker - - latestBlock, requestedHashes, _, err = rpcps.handleConsistency(ctx, relayTimeout, request.RelayData.GetSeenBlock(), request.RelayData.GetRequestBlock(), averageBlockTime, blockLagForQosSync, blockDistanceToFinalization, blocksInFinalizationData) + latestBlock, requestedBlockHash, requestedHashes, modifiedReqBlock, finalized, updatedChainMessage, err = rpcps.GetParametersForRelayDataReliability(ctx, request, chainMsg, relayTimeout, blockLagForQosSync, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData) if err != nil { return nil, err } - // get specific block data for caching - _, specificRequestedHashes, _, err := rpcps.reliabilityManager.GetLatestBlockData(spectypes.NOT_APPLICABLE, spectypes.NOT_APPLICABLE, specificBlock) - if err == nil && len(specificRequestedHashes) == 1 { - requestedBlockHash = []byte(specificRequestedHashes[0].Hash) - } - - // TODO: take latestBlock and lastSeenBlock and put the greater one of them - 
updatedChainMessage = chainMsg.UpdateLatestBlockInMessage(latestBlock, true) - - modifiedReqBlock = lavaprotocol.ReplaceRequestedBlock(request.RelayData.RequestBlock, latestBlock) - if modifiedReqBlock != request.RelayData.RequestBlock { - request.RelayData.RequestBlock = modifiedReqBlock - updatedChainMessage = true // meaning we can't bring a newer proof - } - // requestedBlockHash, finalizedBlockHashes = chaintracker.FindRequestedBlockHash(requestedHashes, request.RelayData.RequestBlock, toBlock, fromBlock, finalizedBlockHashes) - finalized = spectypes.IsFinalizedBlock(modifiedReqBlock, latestBlock, blockDistanceToFinalization) - if !finalized && requestedBlockHash == nil && modifiedReqBlock != spectypes.NOT_APPLICABLE { - // avoid using cache, but can still service - utils.LavaFormatWarning("no hash data for requested block", nil, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "requestedBlock", Value: request.RelayData.RequestBlock}, utils.Attribute{Key: "latestBlock", Value: latestBlock}, utils.Attribute{Key: "modifiedReqBlock", Value: modifiedReqBlock}, utils.Attribute{Key: "specificBlock", Value: specificBlock}) - } } - cache := rpcps.cache + // TODO: handle cache on fork for dataReliability = false - var reply *pairingtypes.RelayReply = nil - var err error = nil - ignoredMetadata := []pairingtypes.Metadata{} - if requestedBlockHash != nil || finalized { - var cacheReply *pairingtypes.CacheRelayReply - - hashKey, outPutFormatter, hashErr := chainlib.HashCacheRequest(request.RelayData, rpcps.rpcProviderEndpoint.ChainID) - if hashErr != nil { - utils.LavaFormatError("TryRelay Failed computing hash for cache request", hashErr) - } else { - cacheCtx, cancel := context.WithTimeout(ctx, common.CacheTimeout) - cacheReply, err = cache.GetEntry(cacheCtx, &pairingtypes.RelayCacheGet{ - RequestHash: hashKey, - RequestedBlock: request.RelayData.RequestBlock, - ChainId: rpcps.rpcProviderEndpoint.ChainID, - BlockHash: requestedBlockHash, - Finalized: finalized, - SeenBlock: request.RelayData.SeenBlock, - }) - cancel() - reply = cacheReply.GetReply() - if reply != nil { - reply.Data = outPutFormatter(reply.Data) // setting request id back to reply. 
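The outPutFormatter above (still used by the new tryGetRelayReplyFromCache below) exists because the cache key is computed over the request with its JSON-RPC id normalized, so the caller's original id has to be written back into any cached reply. A self-contained sketch of that normalize-then-restore idea; it assumes nothing about the real HashCacheRequest beyond this behavior, and every name below is illustrative:

package main

import (
	"crypto/sha256"
	"encoding/json"
	"fmt"
)

type rpcRequest struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      json.RawMessage `json:"id"`
	Method  string          `json:"method"`
}

// hashForCache zeroes the JSON-RPC id before hashing, so identical requests
// with different ids share one cache entry, and returns a formatter that
// restores the caller's original id on a cached reply.
func hashForCache(req rpcRequest) ([32]byte, func([]byte) []byte, error) {
	originalID := req.ID
	req.ID = json.RawMessage(`0`) // normalize the id before hashing
	keyBytes, err := json.Marshal(req)
	if err != nil {
		return [32]byte{}, nil, err
	}
	restoreID := func(cachedReply []byte) []byte {
		var reply map[string]json.RawMessage
		if json.Unmarshal(cachedReply, &reply) != nil {
			return cachedReply // leave malformed replies untouched
		}
		reply["id"] = originalID
		out, marshalErr := json.Marshal(reply)
		if marshalErr != nil {
			return cachedReply
		}
		return out
	}
	return sha256.Sum256(keyBytes), restoreID, nil
}

func main() {
	key, restoreID, err := hashForCache(rpcRequest{JSONRPC: "2.0", ID: json.RawMessage(`42`), Method: "eth_blockNumber"})
	if err != nil {
		panic(err)
	}
	fmt.Printf("cache key prefix: %x\n", key[:4])
	fmt.Println(string(restoreID([]byte(`{"jsonrpc":"2.0","id":0,"result":"0x10"}`))))
}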
- } - ignoredMetadata = cacheReply.GetOptionalMetadata() - if err != nil && performance.NotConnectedError.Is(err) { - utils.LavaFormatDebug("cache not connected", utils.LogAttr("err", err), utils.Attribute{Key: "GUID", Value: ctx}) - } - } + var reply *pairingtypes.RelayReply + var ignoredMetadata []pairingtypes.Metadata + var err error + if requestedBlockHash != nil || finalized { // try get reply from cache + reply, ignoredMetadata, err = rpcps.tryGetRelayReplyFromCache(ctx, request, requestedBlockHash, finalized) } + if err != nil || reply == nil { // we need to send relay, cache miss or invalid - sendTime := time.Now() - if debugLatency { - utils.LavaFormatDebug("sending relay to node", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) - } - // add stickiness header - chainMsg.AppendHeader([]pairingtypes.Metadata{{Name: RPCProviderStickinessHeaderName, Value: common.GetUniqueToken(consumerAddr.String(), common.GetTokenFromGrpcContext(ctx))}}) - chainMsg.AppendHeader([]pairingtypes.Metadata{{Name: RPCProviderAddressHeader, Value: rpcps.providerAddress.String()}}) - if debugConsistency { - utils.LavaFormatDebug("adding stickiness header", utils.LogAttr("tokenFromContext", common.GetTokenFromGrpcContext(ctx)), utils.LogAttr("unique_token", common.GetUniqueToken(consumerAddr.String(), common.GetIpFromGrpcContext(ctx)))) - } var replyWrapper *chainlib.RelayReplyWrapper - replyWrapper, _, _, _, _, err = rpcps.chainRouter.SendNodeMsg(ctx, nil, chainMsg, request.RelayData.Extensions) + replyWrapper, err = rpcps.sendRelayMessageToNode(ctx, request, chainMsg, consumerAddr) if err != nil { - return nil, utils.LavaFormatError("Sending chainMsg failed", err, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) - } - if replyWrapper == nil || replyWrapper.RelayReply == nil { - return nil, utils.LavaFormatError("Relay Wrapper returned nil without an error", nil, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) + return nil, err } reply = replyWrapper.RelayReply - if debugLatency { - utils.LavaFormatDebug("node reply received", utils.Attribute{Key: "timeTaken", Value: time.Since(sendTime)}, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) - } + reply.Metadata, _, ignoredMetadata = rpcps.chainParser.HandleHeaders(reply.Metadata, chainMsg.GetApiCollection(), spectypes.Header_pass_reply) // TODO: use overwriteReqBlock on the reply metadata to set the correct latest block - if cache.CacheActive() && (requestedBlockHash != nil || finalized) { - isNodeError, _ := chainMsg.CheckResponseError(reply.Data, replyWrapper.StatusCode) - // in case the error is a node error we don't want to cache - if !isNodeError { - // copy request and reply as they change later on and we call SetEntry in a routine. 
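The comment just above (kept verbatim in trySetRelayReplyInCache below) captures a general rule: data handed to a goroutine must be snapshotted first if the caller keeps mutating it. A dependency-free illustration of that copy-before-goroutine pattern; the relayReply type and the sleep standing in for slow cache I/O are assumptions of the example:

package main

import (
	"fmt"
	"sync"
	"time"
)

type relayReply struct{ data []byte }

func main() {
	var wg sync.WaitGroup
	original := &relayReply{data: []byte("cache me")}

	// Snapshot first: the goroutine must never observe the caller's later mutations.
	snapshot := &relayReply{data: append([]byte(nil), original.data...)}

	wg.Add(1)
	go func() { // stands in for the asynchronous cache.SetEntry call
		defer wg.Done()
		time.Sleep(10 * time.Millisecond) // simulate slow cache I/O
		fmt.Println("cached:", string(snapshot.data))
	}()

	original.data = []byte("mutated by caller") // safe: the snapshot is isolated

	wg.Wait()
}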
- requestedBlock := request.RelayData.RequestBlock // get requested block before removing it from the data - hashKey, _, hashErr := chainlib.HashCacheRequest(request.RelayData, rpcps.rpcProviderEndpoint.ChainID) // get the hash (this changes the data) - copyReply := &pairingtypes.RelayReply{} - copyReplyErr := protocopy.DeepCopyProtoObject(reply, copyReply) - go func() { - if hashErr != nil || copyReplyErr != nil { - utils.LavaFormatError("Failed copying relay private data on TryRelay", nil, utils.LogAttr("copyReplyErr", copyReplyErr), utils.LogAttr("hashErr", hashErr)) - return - } - new_ctx := context.Background() - new_ctx, cancel := context.WithTimeout(new_ctx, common.DataReliabilityTimeoutIncrease) - defer cancel() - if err != nil { - utils.LavaFormatError("TryRelay failed calculating hash for cach.SetEntry", err) - return - } - err = cache.SetEntry(new_ctx, &pairingtypes.RelayCacheSet{ - RequestHash: hashKey, - RequestedBlock: requestedBlock, - BlockHash: requestedBlockHash, - ChainId: rpcps.rpcProviderEndpoint.ChainID, - Response: copyReply, - Finalized: finalized, - OptionalMetadata: ignoredMetadata, - AverageBlockTime: int64(averageBlockTime), - SeenBlock: latestBlock, - IsNodeError: isNodeError, - }) - if err != nil && request.RelaySession.Epoch != spectypes.NOT_APPLICABLE { - utils.LavaFormatWarning("error updating cache with new entry", err, utils.Attribute{Key: "GUID", Value: ctx}) - } - }() - } + if rpcps.cache.CacheActive() && (requestedBlockHash != nil || finalized) { + rpcps.trySetRelayReplyInCache(ctx, request, chainMsg, replyWrapper, latestBlock, averageBlockTime, requestedBlockHash, finalized, ignoredMetadata) } } - apiName := chainMsg.GetApi().Name - if reqMsg != nil && strings.Contains(apiName, "unsubscribe") { - err := rpcps.processUnsubscribe(ctx, apiName, consumerAddr, reqParams, uint64(request.RelayData.RequestBlock)) + if dataReliabilityEnabled { + err := rpcps.BuildRelayFinalizedBlockHashes(ctx, request, reply, latestBlock, requestedHashes, updatedChainMessage, relayTimeout, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData, modifiedReqBlock) if err != nil { return nil, err } } - if dataReliabilityEnabled { - // now we need to provide the proof for the response - proofBlock := latestBlock - if !updatedChainMessage || len(requestedHashes) == 0 { - // we can fetch a more advanced finalization proof, than we fetched previously - proofBlock, requestedHashes, _, err = rpcps.GetLatestBlockData(ctx, blockDistanceToFinalization, blocksInFinalizationData) - if err != nil { - return nil, err - } - } // else: we updated the chain message to request the specific latestBlock we fetched earlier, so use the previously fetched latest block and hashes - if proofBlock < modifiedReqBlock && proofBlock < request.RelayData.SeenBlock { - // we requested with a newer block, but don't necessarily have the finaliziation proof, chaintracker might be behind - proofBlock = lavaslices.Min([]int64{modifiedReqBlock, request.RelayData.SeenBlock}) - proofBlock, requestedHashes, err = rpcps.GetBlockDataForOptimisticFetch(ctx, relayTimeout, proofBlock, blockDistanceToFinalization, blocksInFinalizationData, averageBlockTime) - if err != nil { - return nil, utils.LavaFormatError("error getting block range for finalization proof", err) + // utils.LavaFormatDebug("response signing", utils.LogAttr("request block", request.RelayData.RequestBlock), utils.LogAttr("GUID", ctx), utils.LogAttr("latestBlock", reply.LatestBlock)) + reply, err = 
lavaprotocol.SignRelayResponse(consumerAddr, *request, rpcps.privKey, reply, dataReliabilityEnabled) + if err != nil { + return nil, err + } + + reply.Metadata = append(reply.Metadata, ignoredMetadata...) // appended here only after signing + + // return reply to user + return reply, nil +} + +func (rpcps *RPCProviderServer) tryGetRelayReplyFromCache(ctx context.Context, request *pairingtypes.RelayRequest, requestedBlockHash []byte, finalized bool) (*pairingtypes.RelayReply, []pairingtypes.Metadata, error) { + cache := rpcps.cache + hashKey, outPutFormatter, hashErr := chainlib.HashCacheRequest(request.RelayData, rpcps.rpcProviderEndpoint.ChainID) + if hashErr != nil { + utils.LavaFormatError("TryRelay Failed computing hash for cache request", hashErr) + return nil, nil, nil + } + cacheCtx, cancel := context.WithTimeout(ctx, common.CacheTimeout) + cacheReply, err := cache.GetEntry(cacheCtx, &pairingtypes.RelayCacheGet{ + RequestHash: hashKey, + RequestedBlock: request.RelayData.RequestBlock, + ChainId: rpcps.rpcProviderEndpoint.ChainID, + BlockHash: requestedBlockHash, + Finalized: finalized, + SeenBlock: request.RelayData.SeenBlock, + }) + cancel() + + if err != nil && performance.NotConnectedError.Is(err) { + utils.LavaFormatDebug("cache not connected", utils.LogAttr("err", err), utils.Attribute{Key: "GUID", Value: ctx}) + return nil, nil, err + } + + reply := cacheReply.GetReply() + if reply != nil { + reply.Data = outPutFormatter(reply.Data) // setting request id back to reply. + } + + ignoredMetadata := cacheReply.GetOptionalMetadata() + + return reply, ignoredMetadata, err +} + +func (rpcps *RPCProviderServer) trySetRelayReplyInCache(ctx context.Context, request *pairingtypes.RelayRequest, chainMsg chainlib.ChainMessage, replyWrapper *chainlib.RelayReplyWrapper, latestBlock int64, averageBlockTime time.Duration, requestedBlockHash []byte, finalized bool, ignoredMetadata []pairingtypes.Metadata) { + cache := rpcps.cache + reply := replyWrapper.RelayReply + + isNodeError, _ := chainMsg.CheckResponseError(reply.Data, replyWrapper.StatusCode) + // in case the error is a node error we don't want to cache + if !isNodeError { + // copy request and reply as they change later on and we call SetEntry in a routine. 
+ requestedBlock := request.RelayData.RequestBlock // get requested block before removing it from the data + hashKey, _, hashErr := chainlib.HashCacheRequest(request.RelayData, rpcps.rpcProviderEndpoint.ChainID) // get the hash (this changes the data) + copyReply := &pairingtypes.RelayReply{} + copyReplyErr := protocopy.DeepCopyProtoObject(reply, copyReply) + go func() { + if hashErr != nil || copyReplyErr != nil { + utils.LavaFormatError("Failed copying relay private data on TryRelay", nil, utils.LogAttr("copyReplyErr", copyReplyErr), utils.LogAttr("hashErr", hashErr)) + return + } + new_ctx := context.Background() + new_ctx, cancel := context.WithTimeout(new_ctx, common.DataReliabilityTimeoutIncrease) + defer cancel() + err := cache.SetEntry(new_ctx, &pairingtypes.RelayCacheSet{ + RequestHash: hashKey, + RequestedBlock: requestedBlock, + BlockHash: requestedBlockHash, + ChainId: rpcps.rpcProviderEndpoint.ChainID, + Response: copyReply, + Finalized: finalized, + OptionalMetadata: ignoredMetadata, + AverageBlockTime: int64(averageBlockTime), + SeenBlock: latestBlock, + IsNodeError: isNodeError, + }) + if err != nil && request.RelaySession.Epoch != spectypes.NOT_APPLICABLE { + utils.LavaFormatWarning("error updating cache with new entry", err, utils.Attribute{Key: "GUID", Value: ctx}) } + }() + } +} + +func (rpcps *RPCProviderServer) sendRelayMessageToNode(ctx context.Context, request *pairingtypes.RelayRequest, chainMsg chainlib.ChainMessage, consumerAddr sdk.AccAddress) (*chainlib.RelayReplyWrapper, error) { + sendTime := time.Now() + if debugLatency { + utils.LavaFormatDebug("sending relay to node", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) + } + // add stickiness header + chainMsg.AppendHeader([]pairingtypes.Metadata{{Name: RPCProviderStickinessHeaderName, Value: common.GetUniqueToken(consumerAddr.String(), common.GetTokenFromGrpcContext(ctx))}}) + chainMsg.AppendHeader([]pairingtypes.Metadata{{Name: RPCProviderAddressHeader, Value: rpcps.providerAddress.String()}}) + if debugConsistency { + utils.LavaFormatDebug("adding stickiness header", utils.LogAttr("tokenFromContext", common.GetTokenFromGrpcContext(ctx)), utils.LogAttr("unique_token", common.GetUniqueToken(consumerAddr.String(), common.GetIpFromGrpcContext(ctx)))) + } + + replyWrapper, _, _, _, _, err := rpcps.chainRouter.SendNodeMsg(ctx, nil, chainMsg, request.RelayData.Extensions) + if err != nil { + return nil, utils.LavaFormatError("Sending chainMsg failed", err, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) + } + + if replyWrapper == nil || replyWrapper.RelayReply == nil { + return nil, utils.LavaFormatError("Relay Wrapper returned nil without an error", nil, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) + } + + if debugLatency { + utils.LavaFormatDebug("node reply received", utils.Attribute{Key: "timeTaken", Value: time.Since(sendTime)}, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) + } + + return replyWrapper, nil +} + +func (rpcps *RPCProviderServer) TryRelayUnsubscribe(ctx context.Context, request *pairingtypes.RelayRequest, consumerAddress sdk.AccAddress, chainMessage chainlib.ChainMessage) (*pairingtypes.RelayReply, error) { + errV := rpcps.ValidateRequest(chainMessage, request, ctx) + if errV != nil { + return nil, errV + } + + 
utils.LavaFormatDebug("Provider got unsubscribe request", utils.LogAttr("GUID", ctx)) + + consumerProcessGuid, found := rpcps.fetchConsumerProcessGuidFromContext(ctx) + if !found { + return nil, utils.LavaFormatWarning("Could not find consumer process GUID in context, which is required for unsubscribe relays", nil) + } + + // Remove the consumer from the connected consumers list of the subscription + err := rpcps.providerNodeSubscriptionManager.RemoveConsumer(ctx, chainMessage, consumerAddress, true, consumerProcessGuid) + if err != nil { + return nil, err + } + + rpcResponse, err := lavaprotocol.CraftEmptyRPCResponseFromGenericMessage(chainMessage.GetRPCMessage()) + if err != nil { + return nil, utils.LavaFormatError("failed crafting empty rpc response", err) + } + + dataToSend, err := json.Marshal(rpcResponse) + if err != nil { + return nil, utils.LavaFormatError("failed marshaling json response", err) + } + + reply := &pairingtypes.RelayReply{ + Data: dataToSend, + } + + dataReliabilityEnabled, _ := rpcps.chainParser.DataReliabilityParams() + if dataReliabilityEnabled { + blockLagForQosSync, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData := rpcps.chainParser.ChainBlockStats() + relayTimeout := chainlib.GetRelayTimeout(chainMessage, averageBlockTime) + latestBlock, _, requestedHashes, modifiedReqBlock, _, updatedChainMessage, err := rpcps.GetParametersForRelayDataReliability(ctx, request, chainMessage, relayTimeout, blockLagForQosSync, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData) + if err != nil { + return nil, err } - finalizedBlockHashes := chaintracker.BuildProofFromBlocks(requestedHashes) - jsonStr, err := json.Marshal(finalizedBlockHashes) + err = rpcps.BuildRelayFinalizedBlockHashes(ctx, request, reply, latestBlock, requestedHashes, updatedChainMessage, relayTimeout, averageBlockTime, blockDistanceToFinalization, blocksInFinalizationData, modifiedReqBlock) if err != nil { - return nil, utils.LavaFormatError("failed unmarshaling finalizedBlockHashes", err, utils.Attribute{Key: "GUID", Value: ctx}, - utils.Attribute{Key: "finalizedBlockHashes", Value: finalizedBlockHashes}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) + return nil, err } - reply.FinalizedBlocksHashes = jsonStr - reply.LatestBlock = proofBlock } - // utils.LavaFormatDebug("response signing", utils.LogAttr("request block", request.RelayData.RequestBlock), utils.LogAttr("GUID", ctx), utils.LogAttr("latestBlock", reply.LatestBlock)) - reply, err = lavaprotocol.SignRelayResponse(consumerAddr, *request, rpcps.privKey, reply, dataReliabilityEnabled) + + var ignoredMetadata []pairingtypes.Metadata + reply.Metadata, _, ignoredMetadata = rpcps.chainParser.HandleHeaders(reply.Metadata, chainMessage.GetApiCollection(), spectypes.Header_pass_reply) + reply, err = lavaprotocol.SignRelayResponse(consumerAddress, *request, rpcps.privKey, reply, dataReliabilityEnabled) if err != nil { return nil, err } reply.Metadata = append(reply.Metadata, ignoredMetadata...) 
// appended here only after signing - // return reply to user + return reply, nil } +func (rpcps *RPCProviderServer) GetParametersForRelayDataReliability( + ctx context.Context, + request *pairingtypes.RelayRequest, + chainMsg chainlib.ChainMessage, + relayTimeout time.Duration, + blockLagForQosSync int64, + averageBlockTime time.Duration, + blockDistanceToFinalization, + blocksInFinalizationData uint32, +) (latestBlock int64, requestedBlockHash []byte, requestedHashes []*chaintracker.BlockStore, modifiedReqBlock int64, finalized, updatedChainMessage bool, err error) { + specificBlock := request.RelayData.RequestBlock + if specificBlock < spectypes.LATEST_BLOCK { + // cases of EARLIEST, FINALIZED, SAFE + // GetLatestBlockData only supports latest relative queries or specific block numbers + specificBlock = spectypes.NOT_APPLICABLE + } + + // handle consistency, if the consumer requested information we do not have in the state tracker + + latestBlock, requestedHashes, _, err = rpcps.handleConsistency(ctx, relayTimeout, request.RelayData.GetSeenBlock(), request.RelayData.GetRequestBlock(), averageBlockTime, blockLagForQosSync, blockDistanceToFinalization, blocksInFinalizationData) + if err != nil { + return 0, nil, nil, 0, false, false, err + } + + // get specific block data for caching + _, specificRequestedHashes, _, getLatestBlockErr := rpcps.reliabilityManager.GetLatestBlockData(spectypes.NOT_APPLICABLE, spectypes.NOT_APPLICABLE, specificBlock) + if getLatestBlockErr == nil && len(specificRequestedHashes) == 1 { + requestedBlockHash = []byte(specificRequestedHashes[0].Hash) + } + + // TODO: take latestBlock and lastSeenBlock and put the greater one of them + updatedChainMessage = chainMsg.UpdateLatestBlockInMessage(latestBlock, true) + + modifiedReqBlock = lavaprotocol.ReplaceRequestedBlock(request.RelayData.RequestBlock, latestBlock) + if modifiedReqBlock != request.RelayData.RequestBlock { + request.RelayData.RequestBlock = modifiedReqBlock + updatedChainMessage = true // meaning we can't bring a newer proof + } + // requestedBlockHash, finalizedBlockHashes = chaintracker.FindRequestedBlockHash(requestedHashes, request.RelayData.RequestBlock, toBlock, fromBlock, finalizedBlockHashes) + finalized = spectypes.IsFinalizedBlock(modifiedReqBlock, latestBlock, blockDistanceToFinalization) + if !finalized && requestedBlockHash == nil && modifiedReqBlock != spectypes.NOT_APPLICABLE { + // avoid using cache, but can still service + utils.LavaFormatWarning("no hash data for requested block", nil, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "requestedBlock", Value: request.RelayData.RequestBlock}, utils.Attribute{Key: "latestBlock", Value: latestBlock}, utils.Attribute{Key: "modifiedReqBlock", Value: modifiedReqBlock}, utils.Attribute{Key: "specificBlock", Value: specificBlock}) + } + + return latestBlock, requestedBlockHash, requestedHashes, modifiedReqBlock, finalized, updatedChainMessage, nil +} + +func (rpcps *RPCProviderServer) BuildRelayFinalizedBlockHashes( + ctx context.Context, + request *pairingtypes.RelayRequest, + reply *pairingtypes.RelayReply, + latestBlock int64, + requestedHashes []*chaintracker.BlockStore, + updatedChainMessage bool, + relayTimeout time.Duration, + averageBlockTime time.Duration, + blockDistanceToFinalization uint32, + blocksInFinalizationData uint32, + modifiedReqBlock int64, +) (err error) { + // now we need to provide the proof for the response + proofBlock := 
latestBlock + if !updatedChainMessage || len(requestedHashes) == 0 { + // we can fetch a more advanced finalization proof than we fetched previously + proofBlock, requestedHashes, _, err = rpcps.GetLatestBlockData(ctx, blockDistanceToFinalization, blocksInFinalizationData) + if err != nil { + return err + } + } // else: we updated the chain message to request the specific latestBlock we fetched earlier, so use the previously fetched latest block and hashes + if proofBlock < modifiedReqBlock && proofBlock < request.RelayData.SeenBlock { + // we requested with a newer block, but don't necessarily have the finalization proof, chaintracker might be behind + proofBlock = lavaslices.Min([]int64{modifiedReqBlock, request.RelayData.SeenBlock}) + + proofBlock, requestedHashes, err = rpcps.GetBlockDataForOptimisticFetch(ctx, relayTimeout, proofBlock, blockDistanceToFinalization, blocksInFinalizationData, averageBlockTime) + if err != nil { + return utils.LavaFormatError("error getting block range for finalization proof", err) + } + } + + finalizedBlockHashes := chaintracker.BuildProofFromBlocks(requestedHashes) + jsonStr, err := json.Marshal(finalizedBlockHashes) + if err != nil { + return utils.LavaFormatError("failed marshaling finalizedBlockHashes", err, utils.Attribute{Key: "GUID", Value: ctx}, + utils.Attribute{Key: "finalizedBlockHashes", Value: finalizedBlockHashes}, utils.Attribute{Key: "specID", Value: rpcps.rpcProviderEndpoint.ChainID}) + } + reply.FinalizedBlocksHashes = jsonStr + reply.LatestBlock = proofBlock + return nil +} + func (rpcps *RPCProviderServer) GetBlockDataForOptimisticFetch(ctx context.Context, relayBaseTimeout time.Duration, requiredProofBlock int64, blockDistanceToFinalization uint32, blocksInFinalizationData uint32, averageBlockTime time.Duration) (latestBlock int64, requestedHashes []*chaintracker.BlockStore, err error) { utils.LavaFormatDebug("getting new blockData for optimistic fetch", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "requiredProofBlock", Value: requiredProofBlock}) proofBlock := requiredProofBlock @@ -1028,27 +1167,6 @@ func (rpcps *RPCProviderServer) GetLatestBlockData(ctx context.Context, blockDis return } -func (rpcps *RPCProviderServer) processUnsubscribe(ctx context.Context, apiName string, consumerAddr sdk.AccAddress, reqParams interface{}, epoch uint64) error { - var subscriptionID string - switch reqParamsCasted := reqParams.(type) { - case []interface{}: - var ok bool - subscriptionID, ok = reqParamsCasted[0].(string) - if !ok { - return utils.LavaFormatError("processUnsubscribe - p[0].(string) - type assertion failed", nil, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "type", Value: reqParamsCasted[0]}) - } - case map[string]interface{}: - if apiName == "unsubscribe" { - var ok bool - subscriptionID, ok = reqParamsCasted["query"].(string) - if !ok { - return utils.LavaFormatError("processUnsubscribe - p['query'].(string) - type assertion failed", nil, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "type", Value: reqParamsCasted["query"]}) - } - } - } - return rpcps.providerSessionManager.ProcessUnsubscribe(apiName, subscriptionID, consumerAddr.String(), epoch) -} - func (rpcps *RPCProviderServer) Probe(ctx context.Context, probeReq *pairingtypes.ProbeRequest) (*pairingtypes.ProbeReply, error) { latestB, _ := rpcps.reliabilityManager.GetLatestBlockNum() probeReply := &pairingtypes.ProbeReply{ @@ -1063,6 +1181,23 @@ func (rpcps *RPCProviderServer) Probe(ctx context.Context, probeReq
*pairingtype return probeReply, nil } +func (rpcps *RPCProviderServer) fetchConsumerProcessGuidFromContext(ctx context.Context) (string, bool) { + incomingMetaData, found := metadata.FromIncomingContext(ctx) + if !found { + utils.LavaFormatDebug("fetchConsumerProcessGuidFromContext: no incoming meta found in context") + return "", false + } + for key, value := range incomingMetaData { + if key == common.LAVA_CONSUMER_PROCESS_GUID { + for _, metaDataValue := range value { + return metaDataValue, true + } + } + } + utils.LavaFormatDebug("incoming meta data does not contain process guid", utils.LogAttr("incoming_meta_data", incomingMetaData)) + return "", false +} + func (rpcps *RPCProviderServer) tryGetTimeoutFromRequest(ctx context.Context) (time.Duration, bool, error) { incomingMetaData, found := metadata.FromIncomingContext(ctx) if !found { diff --git a/protocol/rpcprovider/rpcprovider_server_test.go b/protocol/rpcprovider/rpcprovider_server_test.go index e8ff57fa6a..cde8d9263b 100644 --- a/protocol/rpcprovider/rpcprovider_server_test.go +++ b/protocol/rpcprovider/rpcprovider_server_test.go @@ -222,7 +222,7 @@ func TestHandleConsistency(t *testing.T) { w.WriteHeader(http.StatusOK) fmt.Fprint(w, string(replyDataBuf)) }) - chainParser, chainProxy, _, closeServer, _, err := chainlib.CreateChainLibMocks(ts.Ctx, specId, spectypes.APIInterfaceRest, serverHandler, "../../", nil) + chainParser, chainProxy, _, closeServer, _, err := chainlib.CreateChainLibMocks(ts.Ctx, specId, spectypes.APIInterfaceRest, serverHandler, nil, "../../", nil) if closeServer != nil { defer closeServer() } diff --git a/protocol/statetracker/consumer_state_tracker.go b/protocol/statetracker/consumer_state_tracker.go index e7789971f1..121019422f 100644 --- a/protocol/statetracker/consumer_state_tracker.go +++ b/protocol/statetracker/consumer_state_tracker.go @@ -8,7 +8,7 @@ import ( "github.com/cosmos/cosmos-sdk/client/tx" "github.com/lavanet/lava/v2/protocol/chaintracker" "github.com/lavanet/lava/v2/protocol/common" - "github.com/lavanet/lava/v2/protocol/lavaprotocol" + "github.com/lavanet/lava/v2/protocol/lavaprotocol/finalizationconsensus" "github.com/lavanet/lava/v2/protocol/lavasession" "github.com/lavanet/lava/v2/protocol/metrics" updaters "github.com/lavanet/lava/v2/protocol/statetracker/updaters" @@ -93,7 +93,7 @@ func (cst *ConsumerStateTracker) RegisterForPairingUpdates(ctx context.Context, } } -func (cst *ConsumerStateTracker) RegisterFinalizationConsensusForUpdates(ctx context.Context, finalizationConsensus *lavaprotocol.FinalizationConsensus) { +func (cst *ConsumerStateTracker) RegisterFinalizationConsensusForUpdates(ctx context.Context, finalizationConsensus *finalizationconsensus.FinalizationConsensus) { finalizationConsensusUpdater := updaters.NewFinalizationConsensusUpdater(cst.stateQuery, finalizationConsensus.SpecId) finalizationConsensusUpdaterRaw := cst.StateTracker.RegisterForUpdates(ctx, finalizationConsensusUpdater) finalizationConsensusUpdater, ok := finalizationConsensusUpdaterRaw.(*updaters.FinalizationConsensusUpdater) diff --git a/protocol/statetracker/updaters/finalization_consensus_updater.go b/protocol/statetracker/updaters/finalization_consensus_updater.go index 5efeeca430..384829df76 100644 --- a/protocol/statetracker/updaters/finalization_consensus_updater.go +++ b/protocol/statetracker/updaters/finalization_consensus_updater.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/lavanet/lava/v2/protocol/lavaprotocol" + 
"github.com/lavanet/lava/v2/protocol/lavaprotocol/finalizationconsensus" "github.com/lavanet/lava/v2/utils" ) @@ -15,17 +15,17 @@ const ( type FinalizationConsensusUpdater struct { lock sync.RWMutex - registeredFinalizationConsensuses []*lavaprotocol.FinalizationConsensus + registeredFinalizationConsensuses []*finalizationconsensus.FinalizationConsensus nextBlockForUpdate uint64 stateQuery *ConsumerStateQuery specId string } func NewFinalizationConsensusUpdater(stateQuery *ConsumerStateQuery, specId string) *FinalizationConsensusUpdater { - return &FinalizationConsensusUpdater{registeredFinalizationConsensuses: []*lavaprotocol.FinalizationConsensus{}, stateQuery: stateQuery, specId: specId} + return &FinalizationConsensusUpdater{registeredFinalizationConsensuses: []*finalizationconsensus.FinalizationConsensus{}, stateQuery: stateQuery, specId: specId} } -func (fcu *FinalizationConsensusUpdater) RegisterFinalizationConsensus(finalizationConsensus *lavaprotocol.FinalizationConsensus) { +func (fcu *FinalizationConsensusUpdater) RegisterFinalizationConsensus(finalizationConsensus *finalizationconsensus.FinalizationConsensus) { // TODO: also update here for the first time fcu.lock.Lock() defer fcu.lock.Unlock() diff --git a/scripts/init_chain.sh b/scripts/init_chain.sh index 99f555f9f7..0cd5b37947 100755 --- a/scripts/init_chain.sh +++ b/scripts/init_chain.sh @@ -73,16 +73,15 @@ echo $(cat "$path$genesis") os_name=$(uname) case "$(uname)" in Darwin) - SED_INLINE="-i ''" ;; + SED_INLINE=(-i '') ;; Linux) - SED_INLINE="-i" ;; + SED_INLINE=(-i) ;; *) echo "unknown system: $(uname)" exit 1 ;; esac - -sed $SED_INLINE \ +sed "${SED_INLINE[@]}" \ -e 's/timeout_propose = .*/timeout_propose = "1s"/' \ -e 's/timeout_propose_delta = .*/timeout_propose_delta = "500ms"/' \ -e 's/timeout_prevote = .*/timeout_prevote = "1s"/' \ @@ -93,8 +92,8 @@ sed $SED_INLINE \ -e 's/skip_timeout_commit = .*/skip_timeout_commit = false/' "$path$config" # Edit app.toml file -sed $SED_INLINE -e "s/enable = .*/enable = true/" "$path$app" -sed $SED_INLINE -e "/Enable defines if the Rosetta API server should be enabled.*/{n;s/enable = .*/enable = false/}" "$path$app" +sed "${SED_INLINE[@]}" -e "s/enable = .*/enable = true/" "$path$app" +sed "${SED_INLINE[@]}" -e "/Enable defines if the Rosetta API server should be enabled.*/{n;s/enable = .*/enable = false/;}" "$path$app" # Add users diff --git a/scripts/pre_setups/init_lava_only_with_node.sh b/scripts/pre_setups/init_lava_only_with_node.sh index 9cc82feefd..4a5c91f227 100755 --- a/scripts/pre_setups/init_lava_only_with_node.sh +++ b/scripts/pre_setups/init_lava_only_with_node.sh @@ -49,15 +49,15 @@ wait_next_block screen -d -m -S provider1 bash -c "source ~/.bashrc; lavap rpcprovider \ $PROVIDER1_LISTENER LAV1 rest '$LAVA_REST' \ -$PROVIDER1_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC' \ +$PROVIDER1_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC_WS' \ $PROVIDER1_LISTENER LAV1 grpc '$LAVA_GRPC' \ -$EXTRA_PROVIDER_FLAGS --geolocation 1 --log_level debug --from servicer1 --chain-id lava --metrics-listen-address ":7776" 2>&1 | tee $LOGS_DIR/PROVIDER1.log" && sleep 0.25 +$EXTRA_PROVIDER_FLAGS --geolocation 1 --log_level trace --from servicer1 --chain-id lava --metrics-listen-address ":7776" 2>&1 | tee $LOGS_DIR/PROVIDER1.log" && sleep 0.25 wait_next_block screen -d -m -S consumers bash -c "source ~/.bashrc; lavap rpcconsumer \ 127.0.0.1:3360 LAV1 rest 127.0.0.1:3361 LAV1 tendermintrpc 127.0.0.1:3362 LAV1 grpc \ -$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level debug --from 
user1 --chain-id lava --add-api-method-metrics --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25 +$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level trace --from user1 --chain-id lava --add-api-method-metrics --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25 echo "--- setting up screens done ---" screen -ls \ No newline at end of file diff --git a/scripts/pre_setups/init_lava_only_with_node_protocol_only.sh b/scripts/pre_setups/init_lava_only_with_node_protocol_only.sh new file mode 100755 index 0000000000..8080e9dd93 --- /dev/null +++ b/scripts/pre_setups/init_lava_only_with_node_protocol_only.sh @@ -0,0 +1,36 @@ +#!/bin/bash +__dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source "$__dir"/../useful_commands.sh +. "${__dir}"/../vars/variables.sh + +LOGS_DIR=${__dir}/../../testutil/debugging/logs +mkdir -p $LOGS_DIR +rm $LOGS_DIR/*.log +echo "[Test Setup] killing lavap" +killall lavap +sleep 1 + +echo "[Test Setup] installing all binaries" +make install-all + +GASPRICE="0.000000001ulava" + +CLIENTSTAKE="500000000000ulava" +PROVIDERSTAKE="500000000000ulava" + +PROVIDER1_LISTENER="127.0.0.1:2220" + +screen -d -m -S provider1 bash -c "source ~/.bashrc; lavap rpcprovider \ +$PROVIDER1_LISTENER LAV1 rest '$LAVA_REST' \ +$PROVIDER1_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC_WS' \ +$PROVIDER1_LISTENER LAV1 grpc '$LAVA_GRPC' \ +$EXTRA_PROVIDER_FLAGS --geolocation 1 --log_level trace --from servicer1 --chain-id lava --metrics-listen-address ":7776" 2>&1 | tee $LOGS_DIR/PROVIDER1.log" && sleep 0.25 + +wait_next_block + +screen -d -m -S consumers bash -c "source ~/.bashrc; lavap rpcconsumer \ +127.0.0.1:3360 LAV1 rest 127.0.0.1:3361 LAV1 tendermintrpc 127.0.0.1:3362 LAV1 grpc \ +$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level trace --from user1 --chain-id lava --add-api-method-metrics --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25 + +echo "--- setting up screens done ---" +screen -ls \ No newline at end of file diff --git a/scripts/pre_setups/init_lava_only_with_node_two_consumers.sh b/scripts/pre_setups/init_lava_only_with_node_two_consumers.sh new file mode 100755 index 0000000000..96f47d65dc --- /dev/null +++ b/scripts/pre_setups/init_lava_only_with_node_two_consumers.sh @@ -0,0 +1,69 @@ +#!/bin/bash +__dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source "$__dir"/../useful_commands.sh +. 
"${__dir}"/../vars/variables.sh + +LOGS_DIR=${__dir}/../../testutil/debugging/logs +mkdir -p $LOGS_DIR +rm $LOGS_DIR/*.log + +killall screen +screen -wipe + +echo "[Test Setup] installing all binaries" +make install-all + +echo "[Test Setup] setting up a new lava node" +screen -d -m -S node bash -c "./scripts/start_env_dev.sh" +screen -ls +echo "[Test Setup] sleeping 20 seconds for node to finish setup (if its not enough increase timeout)" +sleep 5 +wait_for_lava_node_to_start + +GASPRICE="0.000000001ulava" +lavad tx gov submit-legacy-proposal spec-add ./cookbook/specs/ibc.json,./cookbook/specs/cosmoswasm.json,./cookbook/specs/tendermint.json,./cookbook/specs/cosmossdk.json,./cookbook/specs/cosmossdk_45.json,./cookbook/specs/cosmossdk_full.json,./cookbook/specs/ethermint.json,./cookbook/specs/ethereum.json,./cookbook/specs/cosmoshub.json,./cookbook/specs/lava.json,./cookbook/specs/osmosis.json,./cookbook/specs/fantom.json,./cookbook/specs/celo.json,./cookbook/specs/optimism.json,./cookbook/specs/arbitrum.json,./cookbook/specs/starknet.json,./cookbook/specs/aptos.json,./cookbook/specs/juno.json,./cookbook/specs/polygon.json,./cookbook/specs/evmos.json,./cookbook/specs/base.json,./cookbook/specs/canto.json,./cookbook/specs/sui.json,./cookbook/specs/solana.json,./cookbook/specs/bsc.json,./cookbook/specs/axelar.json,./cookbook/specs/avalanche.json,./cookbook/specs/fvm.json --lava-dev-test -y --from alice --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE & +wait_next_block +wait_next_block +lavad tx gov vote 1 yes -y --from alice --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE +sleep 4 + +# Plans proposal +lavad tx gov submit-legacy-proposal plans-add ./cookbook/plans/test_plans/default.json,./cookbook/plans/test_plans/temporary-add.json -y --from alice --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE +wait_next_block +wait_next_block +lavad tx gov vote 2 yes -y --from alice --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE + +sleep 4 + +CLIENTSTAKE="500000000000ulava" +PROVIDERSTAKE="500000000000ulava" + +PROVIDER1_LISTENER="127.0.0.1:2220" + +lavad tx subscription buy DefaultPlan $(lavad keys show user1 -a) -y --from user1 --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE +wait_next_block +lavad tx subscription buy DefaultPlan $(lavad keys show user2 -a) -y --from user2 --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE +wait_next_block +lavad tx pairing stake-provider "LAV1" $PROVIDERSTAKE "$PROVIDER1_LISTENER,1" 1 $(operator_address) -y --from servicer1 --delegate-limit 1000ulava --provider-moniker "dummyMoniker" --gas-adjustment "1.5" --gas "auto" --gas-prices $GASPRICE + +sleep_until_next_epoch +wait_next_block + +screen -d -m -S provider1 bash -c "source ~/.bashrc; lavap rpcprovider \ +$PROVIDER1_LISTENER LAV1 rest '$LAVA_REST' \ +$PROVIDER1_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC_WS' \ +$PROVIDER1_LISTENER LAV1 grpc '$LAVA_GRPC' \ +$EXTRA_PROVIDER_FLAGS --geolocation 1 --log_level trace --from servicer1 --chain-id lava --metrics-listen-address ":7776" 2>&1 | tee $LOGS_DIR/PROVIDER1.log" && sleep 0.25 + +wait_next_block + +screen -d -m -S consumers bash -c "source ~/.bashrc; lavap rpcconsumer \ +127.0.0.1:3360 LAV1 rest 127.0.0.1:3361 LAV1 tendermintrpc 127.0.0.1:3362 LAV1 grpc \ +$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level trace --from user1 --chain-id lava --add-api-method-metrics --allow-insecure-provider-dialing --metrics-listen-address ":7779" 2>&1 | tee $LOGS_DIR/CONSUMERS.log" && sleep 0.25 + +screen -d -m -S 
consumers2 bash -c "source ~/.bashrc; lavap rpcconsumer \ +127.0.0.1:3350 LAV1 rest 127.0.0.1:3351 LAV1 tendermintrpc 127.0.0.1:3352 LAV1 grpc \ +$EXTRA_PORTAL_FLAGS --geolocation 1 --log_level trace --from user2 --chain-id lava --add-api-method-metrics --allow-insecure-provider-dialing --metrics-listen-address ":7773" 2>&1 | tee $LOGS_DIR/CONSUMERS2.log" && sleep 0.25 + +echo "--- setting up screens done ---" +screen -ls \ No newline at end of file diff --git a/scripts/setup_providers.sh b/scripts/setup_providers.sh index d4c0f35c75..df06e65760 100755 --- a/scripts/setup_providers.sh +++ b/scripts/setup_providers.sh @@ -44,7 +44,7 @@ $PROVIDER1_LISTENER OSMOSIS rest '$OSMO_REST' \ $PROVIDER1_LISTENER OSMOSIS tendermintrpc '$OSMO_RPC,$OSMO_RPC' \ $PROVIDER1_LISTENER OSMOSIS grpc '$OSMO_GRPC' \ $PROVIDER1_LISTENER LAV1 rest '$LAVA_REST' \ -$PROVIDER1_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC' \ +$PROVIDER1_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC_WS' \ $PROVIDER1_LISTENER LAV1 grpc '$LAVA_GRPC' \ $PROVIDER1_LISTENER COSMOSHUB rest '$GAIA_REST' \ $PROVIDER1_LISTENER COSMOSHUB tendermintrpc '$GAIA_RPC,$GAIA_RPC' \ @@ -84,7 +84,7 @@ echo; echo "#### Starting provider 2 ####" screen -d -m -S provider2 bash -c "source ~/.bashrc; lavap rpcprovider \ $PROVIDER2_LISTENER ETH1 jsonrpc '$ETH_RPC_WS' \ $PROVIDER2_LISTENER LAV1 rest '$LAVA_REST' \ -$PROVIDER2_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC' \ +$PROVIDER2_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC_WS' \ $PROVIDER2_LISTENER LAV1 grpc '$LAVA_GRPC' \ $EXTRA_PROVIDER_FLAGS --geolocation "$GEOLOCATION" --log_level debug --from servicer2 --chain-id lava 2>&1 | tee $LOGS_DIR/PROVIDER2.log" && sleep 0.25 # $PROVIDER2_LISTENER MANTLE jsonrpc '$MANTLE_JRPC' \ @@ -93,7 +93,7 @@ echo; echo "#### Starting provider 3 ####" screen -d -m -S provider3 bash -c "source ~/.bashrc; lavap rpcprovider \ $PROVIDER3_LISTENER ETH1 jsonrpc '$ETH_RPC_WS' \ $PROVIDER3_LISTENER LAV1 rest '$LAVA_REST' \ -$PROVIDER3_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC' \ +$PROVIDER3_LISTENER LAV1 tendermintrpc '$LAVA_RPC,$LAVA_RPC_WS' \ $PROVIDER3_LISTENER LAV1 grpc '$LAVA_GRPC' \ $EXTRA_PROVIDER_FLAGS --geolocation "$GEOLOCATION" --log_level debug --from servicer3 --chain-id lava 2>&1 | tee $LOGS_DIR/PROVIDER3.log" && sleep 0.25 # $PROVIDER3_LISTENER MANTLE jsonrpc '$MANTLE_JRPC' \ diff --git a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider1.yml b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider1.yml index 801cc178a9..b68e6c812c 100644 --- a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider1.yml +++ b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider1.yml @@ -7,6 +7,12 @@ endpoints: - url: http://127.0.0.1:1111 Addons: - debug + - url: ws://127.0.0.1:1111/ws + Addons: + - debug - url: http://127.0.0.1:1111 Addons: - - archive + - archive + - url: ws://127.0.0.1:1111/ws + Addons: + - archive diff --git a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider2.yml b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider2.yml index 41f3992d6f..b3429ed76c 100644 --- a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider2.yml +++ b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider2.yml @@ -5,3 +5,4 @@ endpoints: address: 127.0.0.1:2222 node-urls: - url: http://127.0.0.1:1111 + - url: ws://127.0.0.1:1111/ws diff --git a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider3.yml b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider3.yml index ca4553b9db..78f216c720 100644 --- a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider3.yml +++ 
b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider3.yml @@ -5,3 +5,4 @@ endpoints: address: 127.0.0.1:2223 node-urls: - url: http://127.0.0.1:1111 + - url: ws://127.0.0.1:1111/ws diff --git a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider4.yml b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider4.yml index 54f7cf46b0..f95f7081fd 100644 --- a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider4.yml +++ b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider4.yml @@ -5,3 +5,4 @@ endpoints: address: 127.0.0.1:2224 node-urls: - url: http://127.0.0.1:1111 + - url: ws://127.0.0.1:1111/ws diff --git a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider5.yml b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider5.yml index 8b9a0c76ce..09585aef14 100644 --- a/testutil/e2e/e2eConfigs/provider/jsonrpcProvider5.yml +++ b/testutil/e2e/e2eConfigs/provider/jsonrpcProvider5.yml @@ -5,3 +5,4 @@ endpoints: address: 127.0.0.1:2225 node-urls: - url: http://127.0.0.1:1111 + - url: ws://127.0.0.1:1111/ws diff --git a/testutil/e2e/protocolE2E.go b/testutil/e2e/protocolE2E.go index 50a27d5bb8..7c32fe0c40 100644 --- a/testutil/e2e/protocolE2E.go +++ b/testutil/e2e/protocolE2E.go @@ -20,6 +20,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" tmclient "github.com/cometbft/cometbft/rpc/client/http" @@ -28,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/rpc" + "github.com/gorilla/websocket" commonconsts "github.com/lavanet/lava/v2/testutil/common/consts" "github.com/lavanet/lava/v2/testutil/e2e/sdk" "github.com/lavanet/lava/v2/utils" @@ -804,9 +806,22 @@ func (lt *lavaTest) saveLogs() { panic(err) } writer := bufio.NewWriter(file) - writer.Write(logBuffer.Bytes()) - writer.Flush() - utils.LavaFormatDebug("writing file", []utils.Attribute{{Key: "fileName", Value: fileName}, {Key: "lines", Value: len(logBuffer.Bytes())}}...) 
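The replacement below stops discarding the bufio errors: with a buffered writer, a Write can succeed while the bytes still sit in memory, so Flush must be checked as well before reporting success. A minimal standalone version of that write-then-flush pattern, using only the standard library:

package main

import (
	"bufio"
	"fmt"
	"os"
	"path/filepath"
)

// writeAll demonstrates checked buffered writes: Write may report success
// while data is still buffered, so Flush must also be checked before
// declaring the file written.
func writeAll(path string, data []byte) error {
	file, err := os.Create(path)
	if err != nil {
		return err
	}
	defer file.Close()

	writer := bufio.NewWriter(file)
	bytesWritten, err := writer.Write(data)
	if err != nil {
		return fmt.Errorf("write failed after %d bytes: %w", bytesWritten, err)
	}
	if err := writer.Flush(); err != nil { // buffered bytes reach the file here
		return fmt.Errorf("flush failed: %w", err)
	}
	return nil
}

func main() {
	path := filepath.Join(os.TempDir(), "e2e_example.log")
	if err := writeAll(path, []byte("log line\n")); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("wrote", path)
}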
+ var bytesWritten int + bytesWritten, err = writer.Write(logBuffer.Bytes()) + if err != nil { + utils.LavaFormatError("Error writing to file", err) + } else { + err = writer.Flush() + if err != nil { + utils.LavaFormatError("Error flushing writer", err) + } else { + utils.LavaFormatDebug("success writing to file", + utils.LogAttr("fileName", fileName), + utils.LogAttr("bytesWritten", bytesWritten), + utils.LogAttr("bufferBytes", len(logBuffer.Bytes())), + ) + } + } file.Close() lines := strings.Split(logBuffer.String(), "\n") @@ -1192,6 +1207,168 @@ func (lt *lavaTest) getKeyAddress(key string) string { return string(output) } +func (lt *lavaTest) runWebSocketSubscriptionTest(tendermintConsumerWebSocketURL string) { + utils.LavaFormatInfo("Starting WebSocket Subscription Test") + + subscriptionsCount := 5 + + createWebSocketClient := func() *websocket.Conn { + websocketDialer := websocket.Dialer{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + header := make(http.Header) + + webSocketClient, resp, err := websocketDialer.DialContext(context.Background(), tendermintConsumerWebSocketURL, header) + if err != nil { + panic(err) + } + utils.LavaFormatDebug("Dialed WebSocket Successfully", + utils.LogAttr("url", tendermintConsumerWebSocketURL), + utils.LogAttr("response", resp), + ) + + return webSocketClient + } + + const ( + SUBSCRIBE = "subscribe" + UNSUBSCRIBE = "unsubscribe" + ) + + createSubscriptionJsonRpcMessage := func(method string) map[string]interface{} { + return map[string]interface{}{ + "jsonrpc": "2.0", + "method": method, + "id": 1, + "params": map[string]interface{}{ + "query": "tm.event = 'NewBlock'", + }, + } + } + + subscribeToNewBlockEvents := func(webSocketClient *websocket.Conn) { + msgData := createSubscriptionJsonRpcMessage(SUBSCRIBE) + err := webSocketClient.WriteJSON(msgData) + if err != nil { + panic(err) + } + } + + webSocketShouldListen := true + defer func() { + webSocketShouldListen = false + }() + + type subscriptionContainer struct { + newBlockMessageCount int32 + webSocketClient *websocket.Conn + } + + incrementNewBlockMessageCount := func(sc *subscriptionContainer) { + atomic.AddInt32(&sc.newBlockMessageCount, 1) + } + + readNewBlockMessageCount := func(subscriptionContainer *subscriptionContainer) int32 { + return atomic.LoadInt32(&subscriptionContainer.newBlockMessageCount) + } + + startWebSocketReader := func(webSocketName string, webSocketClient *websocket.Conn, subscriptionContainer *subscriptionContainer) { + for { + _, message, err := webSocketClient.ReadMessage() + if err != nil { + if webSocketShouldListen { + panic(err) + } + + // Once the test is done, we can safely ignore the error + return + } + + if strings.Contains(string(message), "NewBlock") { + incrementNewBlockMessageCount(subscriptionContainer) + } + } + } + + startSubscriptions := func(count int) []*subscriptionContainer { + subscriptionContainers := []*subscriptionContainer{} + // Start websocket clients and connect them to the tendermint consumer endpoint + for i := 0; i < count; i++ { + utils.LavaFormatInfo("Setting up web socket client " + strconv.Itoa(i+1)) + webSocketClient := createWebSocketClient() + + subscriptionContainer := &subscriptionContainer{ + webSocketClient: webSocketClient, + newBlockMessageCount: 0, + } + + // Start a reader for each client to count the number of NewBlock messages received + utils.LavaFormatInfo("Start listening for NewBlock messages on web socket " + strconv.Itoa(i+1)) + + go startWebSocketReader("webSocketClient"+strconv.Itoa(i+1), webSocketClient,
subscriptionContainer)
+
+			// Subscribe to new block events
+			utils.LavaFormatInfo("Subscribing to NewBlock events on web socket " + strconv.Itoa(i+1))
+
+			subscribeToNewBlockEvents(webSocketClient)
+			subscriptionContainers = append(subscriptionContainers, subscriptionContainer)
+		}
+
+		return subscriptionContainers
+	}
+
+	subscriptions := startSubscriptions(subscriptionsCount)
+
+	// Wait long enough to receive at least 10 blocks
+	utils.LavaFormatInfo("Sleeping for 12 seconds to receive blocks")
+	time.Sleep(12 * time.Second)
+
+	utils.LavaFormatDebug("Looping through subscription containers",
+		utils.LogAttr("subscriptionContainers", subscriptions),
+	)
+	// Check that all web socket clients received at least 10 blocks
+	for i := 0; i < subscriptionsCount; i++ {
+		utils.LavaFormatInfo("Making sure client " + strconv.Itoa(i+1) + " received at least 10 blocks")
+		if subscriptions[i] == nil {
+			panic("subscriptionContainers[" + strconv.Itoa(i) + "] is nil")
+		}
+		newBlockMessageCount := readNewBlockMessageCount(subscriptions[i])
+		if newBlockMessageCount < 10 {
+			panic(fmt.Sprintf("subscription should have received at least 10 blocks, got: %d", newBlockMessageCount))
+		}
+	}
+
+	// Unsubscribe one client
+	utils.LavaFormatInfo("Unsubscribing from NewBlock events on web socket 1")
+	msgData := createSubscriptionJsonRpcMessage(UNSUBSCRIBE)
+	err := subscriptions[0].webSocketClient.WriteJSON(msgData)
+	if err != nil {
+		panic(err)
+	}
+
+	// Make sure that the unsubscribed client stops receiving blocks
+	webSocketClient1NewBlockMsgCountAfterUnsubscribe := readNewBlockMessageCount(subscriptions[0])
+
+	utils.LavaFormatInfo("Sleeping for 7 seconds to make sure unsubscribed client stops receiving blocks")
+	time.Sleep(7 * time.Second)
+
+	if readNewBlockMessageCount(subscriptions[0]) != webSocketClient1NewBlockMsgCountAfterUnsubscribe {
+		panic("unsubscribed client should not receive new blocks")
+	}
+
+	webSocketShouldListen.Store(false)
+
+	// Disconnect all websocket clients
+	for i := 0; i < subscriptionsCount; i++ {
+		utils.LavaFormatInfo("Closing web socket " + strconv.Itoa(i+1))
+		subscriptions[i].webSocketClient.Close()
+	}
+
+	utils.LavaFormatInfo("WebSocket Subscription Test OK")
+}
+
 func calculateProviderCU(pairingClient pairingTypes.QueryClient) (map[string]uint64, error) {
 	providerCU := make(map[string]uint64)
 	res, err := pairingClient.ProvidersEpochCu(context.Background(), &pairingTypes.QueryProvidersEpochCuRequest{})
@@ -1348,6 +1525,8 @@ func runProtocolE2E(timeout time.Duration) {
 
 	lt.checkResponse("http://127.0.0.1:3340", "http://127.0.0.1:3341", "127.0.0.1:3342")
 
+	lt.runWebSocketSubscriptionTest("ws://127.0.0.1:3340/websocket")
+
 	lt.checkQoS()
 
 	utils.LavaFormatInfo("Sleeping Until All Rewards are collected")
diff --git a/testutil/e2e/proxy/proxy.go b/testutil/e2e/proxy/proxy.go
index f69f125fb1..1ad6c09412 100644
--- a/testutil/e2e/proxy/proxy.go
+++ b/testutil/e2e/proxy/proxy.go
@@ -10,6 +10,7 @@ import (
 	"strings"
 	"time"
 
+	"github.com/gorilla/websocket"
 	"github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages"
 	"github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient"
 )
@@ -143,6 +144,15 @@ func main(host, port *string) {
 	startProxyProcess(process)
}
 
+// upgrader upgrades incoming HTTP requests to websocket connections; CheckOrigin allows all origins since this is a test proxy
+var upgrader = websocket.Upgrader{
+	ReadBufferSize:  1024,
+	WriteBufferSize: 1024,
+	CheckOrigin: func(r *http.Request) bool {
+		return true
+	},
+}
+
 func startProxyProcess(process proxyProcess) {
 	process.mock.requests = jsonFileToMap(process.mockfile)
 	if process.mock.requests == nil {
@@ -168,7 +178,34 @@ func
startProxyProcess(process proxyProcess) { fmt.Print(fmt.Sprintf(" ::: Proxy Started! ::: ID: %s", process.id) + "\n") fmt.Print(fmt.Sprintf(" ::: Listening On ::: %s", "http://0.0.0.0:"+process.port+"/") + "\n") + // Define the WebSocket handler + wsHandler := func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println("Upgrade error:", err) + return + } + defer conn.Close() + + for { + // Read message from browser + msgType, msg, err := conn.ReadMessage() + if err != nil { + log.Println("Read error:", err) + break + } + // Print the message to the console + log.Printf("Received: %s\n", msg) + // Write message back to browser + if err = conn.WriteMessage(msgType, msg); err != nil { + log.Println("Write error:", err) + break + } + } + } + http.HandleFunc("/", process.handler) + http.HandleFunc("/ws", wsHandler) err := http.ListenAndServe(":"+process.port, nil) if err != nil { log.Fatal(err.Error()) @@ -243,12 +280,12 @@ func idInstertedResponse(val string, replyMessage *rpcInterfaceMessages.JsonrpcM const dotsStr = " ::: " -func (p proxyProcess) LavaTestProxy(rw http.ResponseWriter, req *http.Request) { +func (p proxyProcess) LavaTestProxy(responseWriter http.ResponseWriter, request *http.Request) { host := p.host mock := p.mock // Get request body - rawBody := getDataFromIORead(&req.Body, true) + rawBody := getDataFromIORead(&request.Body, true) // TODO: set all ids to 1 rawBodyS := string(rawBody) // sep := "id\":" @@ -261,7 +298,7 @@ func (p proxyProcess) LavaTestProxy(rw http.ResponseWriter, req *http.Request) { println(dotsStr+p.port+dotsStr+p.id+" ::: INCOMING PROXY MSG :::", rawBodyS) var respmsg rpcclient.JsonrpcMessage - if err := json.NewDecoder(req.Body).Decode(&respmsg); err != nil { + if err := json.NewDecoder(request.Body).Decode(&respmsg); err != nil { println(err.Error()) } replyMessage, err := rpcInterfaceMessages.ConvertJsonRPCMsg(&respmsg) @@ -275,8 +312,7 @@ func (p proxyProcess) LavaTestProxy(rw http.ResponseWriter, req *http.Request) { if fakeResponse && strings.Contains(rawBodyS, "blockNumber") { println("!!!!!!!!!!!!!! 
block number") - rw.WriteHeader(200) - rw.Write([]byte(fmt.Sprintf("{\"jsonrpc\":\"2.0\",\"id\":%s,\"result\":\"%s\"}", respId, getMockBlockNumber()))) + returnResponse(responseWriter, http.StatusOK, []byte(fmt.Sprintf("{\"jsonrpc\":\"2.0\",\"id\":%s,\"result\":\"%s\"}", respId, getMockBlockNumber()))) } else { // Return Cached data if found in history and fromCache is set on jStruct := &jsonStruct{} @@ -296,11 +332,10 @@ func (p proxyProcess) LavaTestProxy(rw http.ResponseWriter, req *http.Request) { fakeCount += 1 } time.Sleep(500 * time.Millisecond) - rw.WriteHeader(200) - rw.Write([]byte(orderedJSON)) + returnResponse(responseWriter, http.StatusOK, []byte(orderedJSON)) } else { // Recreating Request - proxyRequest, err := createProxyRequest(req, host, rawBodyS) + proxyRequest, err := createProxyRequest(request, host, rawBodyS) if err != nil { println(err.Error()) } else { @@ -327,7 +362,7 @@ func (p proxyProcess) LavaTestProxy(rw http.ResponseWriter, req *http.Request) { println("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX Got error in response - retrying request") // Recreating Request - proxyRequest, err = createProxyRequest(req, host, rawBodyS) + proxyRequest, err = createProxyRequest(request, host, rawBodyS) if err != nil { println(err.Error()) respBody = []byte(err.Error()) @@ -370,7 +405,7 @@ func (p proxyProcess) LavaTestProxy(rw http.ResponseWriter, req *http.Request) { respBody = []byte("error") } // time.Sleep(500 * time.Millisecond) - returnResponse(rw, status, respBody) + returnResponse(responseWriter, status, respBody) } } } diff --git a/utils/lavalog.go b/utils/lavalog.go index 870890e78c..1e1e81393b 100644 --- a/utils/lavalog.go +++ b/utils/lavalog.go @@ -343,3 +343,7 @@ func FormatLongString(msg string, maxCharacters int) string { } return msg } + +func ToHexString(hash string) string { + return fmt.Sprintf("%x", hash) +} diff --git a/x/pairing/types/relay_mock.pb.go b/x/pairing/types/relay_mock.pb.go new file mode 100644 index 0000000000..e49f7212e9 --- /dev/null +++ b/x/pairing/types/relay_mock.pb.go @@ -0,0 +1,410 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: x/pairing/types/relay.pb.go +// +// Generated by this command: +// +// mockgen -source=x/pairing/types/relay.pb.go -destination x/pairing/types/relay_mock.pb.go -package types +// +// Package types is a generated GoMock package. +package types + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + grpc "google.golang.org/grpc" + metadata "google.golang.org/grpc/metadata" +) + +// MockRelayerClient is a mock of RelayerClient interface. +type MockRelayerClient struct { + ctrl *gomock.Controller + recorder *MockRelayerClientMockRecorder +} + +// MockRelayerClientMockRecorder is the mock recorder for MockRelayerClient. +type MockRelayerClientMockRecorder struct { + mock *MockRelayerClient +} + +// NewMockRelayerClient creates a new mock instance. +func NewMockRelayerClient(ctrl *gomock.Controller) *MockRelayerClient { + mock := &MockRelayerClient{ctrl: ctrl} + mock.recorder = &MockRelayerClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRelayerClient) EXPECT() *MockRelayerClientMockRecorder { + return m.recorder +} + +// Probe mocks base method. 
+func (m *MockRelayerClient) Probe(ctx context.Context, in *ProbeRequest, opts ...grpc.CallOption) (*ProbeReply, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, in} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Probe", varargs...) + ret0, _ := ret[0].(*ProbeReply) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Probe indicates an expected call of Probe. +func (mr *MockRelayerClientMockRecorder) Probe(ctx, in any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, in}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockRelayerClient)(nil).Probe), varargs...) +} + +// Relay mocks base method. +func (m *MockRelayerClient) Relay(ctx context.Context, in *RelayRequest, opts ...grpc.CallOption) (*RelayReply, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, in} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Relay", varargs...) + ret0, _ := ret[0].(*RelayReply) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Relay indicates an expected call of Relay. +func (mr *MockRelayerClientMockRecorder) Relay(ctx, in any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, in}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Relay", reflect.TypeOf((*MockRelayerClient)(nil).Relay), varargs...) +} + +// RelaySubscribe mocks base method. +func (m *MockRelayerClient) RelaySubscribe(ctx context.Context, in *RelayRequest, opts ...grpc.CallOption) (Relayer_RelaySubscribeClient, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, in} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "RelaySubscribe", varargs...) + ret0, _ := ret[0].(Relayer_RelaySubscribeClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RelaySubscribe indicates an expected call of RelaySubscribe. +func (mr *MockRelayerClientMockRecorder) RelaySubscribe(ctx, in any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, in}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RelaySubscribe", reflect.TypeOf((*MockRelayerClient)(nil).RelaySubscribe), varargs...) +} + +// MockRelayer_RelaySubscribeClient is a mock of Relayer_RelaySubscribeClient interface. +type MockRelayer_RelaySubscribeClient struct { + ctrl *gomock.Controller + recorder *MockRelayer_RelaySubscribeClientMockRecorder +} + +// MockRelayer_RelaySubscribeClientMockRecorder is the mock recorder for MockRelayer_RelaySubscribeClient. +type MockRelayer_RelaySubscribeClientMockRecorder struct { + mock *MockRelayer_RelaySubscribeClient +} + +// NewMockRelayer_RelaySubscribeClient creates a new mock instance. +func NewMockRelayer_RelaySubscribeClient(ctrl *gomock.Controller) *MockRelayer_RelaySubscribeClient { + mock := &MockRelayer_RelaySubscribeClient{ctrl: ctrl} + mock.recorder = &MockRelayer_RelaySubscribeClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRelayer_RelaySubscribeClient) EXPECT() *MockRelayer_RelaySubscribeClientMockRecorder { + return m.recorder +} + +// CloseSend mocks base method. +func (m *MockRelayer_RelaySubscribeClient) CloseSend() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseSend") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseSend indicates an expected call of CloseSend. 
+func (mr *MockRelayer_RelaySubscribeClientMockRecorder) CloseSend() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseSend", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).CloseSend)) +} + +// Context mocks base method. +func (m *MockRelayer_RelaySubscribeClient) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).Context)) +} + +// Header mocks base method. +func (m *MockRelayer_RelaySubscribeClient) Header() (metadata.MD, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Header") + ret0, _ := ret[0].(metadata.MD) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Header indicates an expected call of Header. +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) Header() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Header", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).Header)) +} + +// Recv mocks base method. +func (m *MockRelayer_RelaySubscribeClient) Recv() (*RelayReply, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recv") + ret0, _ := ret[0].(*RelayReply) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Recv indicates an expected call of Recv. +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) Recv() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).Recv)) +} + +// RecvMsg mocks base method. +func (m_2 *MockRelayer_RelaySubscribeClient) RecvMsg(m any) error { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "RecvMsg", m) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) RecvMsg(m any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).RecvMsg), m) +} + +// SendMsg mocks base method. +func (m_2 *MockRelayer_RelaySubscribeClient) SendMsg(m any) error { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "SendMsg", m) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) SendMsg(m any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).SendMsg), m) +} + +// Trailer mocks base method. +func (m *MockRelayer_RelaySubscribeClient) Trailer() metadata.MD { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Trailer") + ret0, _ := ret[0].(metadata.MD) + return ret0 +} + +// Trailer indicates an expected call of Trailer. +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) Trailer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Trailer", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).Trailer)) +} + +// MockRelayerServer is a mock of RelayerServer interface. 
+type MockRelayerServer struct { + ctrl *gomock.Controller + recorder *MockRelayerServerMockRecorder +} + +// MockRelayerServerMockRecorder is the mock recorder for MockRelayerServer. +type MockRelayerServerMockRecorder struct { + mock *MockRelayerServer +} + +// NewMockRelayerServer creates a new mock instance. +func NewMockRelayerServer(ctrl *gomock.Controller) *MockRelayerServer { + mock := &MockRelayerServer{ctrl: ctrl} + mock.recorder = &MockRelayerServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRelayerServer) EXPECT() *MockRelayerServerMockRecorder { + return m.recorder +} + +// Probe mocks base method. +func (m *MockRelayerServer) Probe(arg0 context.Context, arg1 *ProbeRequest) (*ProbeReply, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Probe", arg0, arg1) + ret0, _ := ret[0].(*ProbeReply) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Probe indicates an expected call of Probe. +func (mr *MockRelayerServerMockRecorder) Probe(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockRelayerServer)(nil).Probe), arg0, arg1) +} + +// Relay mocks base method. +func (m *MockRelayerServer) Relay(arg0 context.Context, arg1 *RelayRequest) (*RelayReply, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Relay", arg0, arg1) + ret0, _ := ret[0].(*RelayReply) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Relay indicates an expected call of Relay. +func (mr *MockRelayerServerMockRecorder) Relay(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Relay", reflect.TypeOf((*MockRelayerServer)(nil).Relay), arg0, arg1) +} + +// RelaySubscribe mocks base method. +func (m *MockRelayerServer) RelaySubscribe(arg0 *RelayRequest, arg1 Relayer_RelaySubscribeServer) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RelaySubscribe", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RelaySubscribe indicates an expected call of RelaySubscribe. +func (mr *MockRelayerServerMockRecorder) RelaySubscribe(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RelaySubscribe", reflect.TypeOf((*MockRelayerServer)(nil).RelaySubscribe), arg0, arg1) +} + +// MockRelayer_RelaySubscribeServer is a mock of Relayer_RelaySubscribeServer interface. +type MockRelayer_RelaySubscribeServer struct { + ctrl *gomock.Controller + recorder *MockRelayer_RelaySubscribeServerMockRecorder +} + +// MockRelayer_RelaySubscribeServerMockRecorder is the mock recorder for MockRelayer_RelaySubscribeServer. +type MockRelayer_RelaySubscribeServerMockRecorder struct { + mock *MockRelayer_RelaySubscribeServer +} + +// NewMockRelayer_RelaySubscribeServer creates a new mock instance. +func NewMockRelayer_RelaySubscribeServer(ctrl *gomock.Controller) *MockRelayer_RelaySubscribeServer { + mock := &MockRelayer_RelaySubscribeServer{ctrl: ctrl} + mock.recorder = &MockRelayer_RelaySubscribeServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRelayer_RelaySubscribeServer) EXPECT() *MockRelayer_RelaySubscribeServerMockRecorder { + return m.recorder +} + +// Context mocks base method. 
+func (m *MockRelayer_RelaySubscribeServer) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).Context)) +} + +// RecvMsg mocks base method. +func (m_2 *MockRelayer_RelaySubscribeServer) RecvMsg(m any) error { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "RecvMsg", m) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) RecvMsg(m any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).RecvMsg), m) +} + +// Send mocks base method. +func (m *MockRelayer_RelaySubscribeServer) Send(arg0 *RelayReply) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send. +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) Send(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).Send), arg0) +} + +// SendHeader mocks base method. +func (m *MockRelayer_RelaySubscribeServer) SendHeader(arg0 metadata.MD) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendHeader", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendHeader indicates an expected call of SendHeader. +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendHeader(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeader", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SendHeader), arg0) +} + +// SendMsg mocks base method. +func (m_2 *MockRelayer_RelaySubscribeServer) SendMsg(m any) error { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "SendMsg", m) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendMsg(m any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SendMsg), m) +} + +// SetHeader mocks base method. +func (m *MockRelayer_RelaySubscribeServer) SetHeader(arg0 metadata.MD) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetHeader", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetHeader indicates an expected call of SetHeader. +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetHeader(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeader", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SetHeader), arg0) +} + +// SetTrailer mocks base method. +func (m *MockRelayer_RelaySubscribeServer) SetTrailer(arg0 metadata.MD) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetTrailer", arg0) +} + +// SetTrailer indicates an expected call of SetTrailer. 
+func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetTrailer(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTrailer", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SetTrailer), arg0) +} diff --git a/x/spec/types/api_collection.pb.go b/x/spec/types/api_collection.pb.go index 80991d5006..b4e9b59da0 100644 --- a/x/spec/types/api_collection.pb.go +++ b/x/spec/types/api_collection.pb.go @@ -58,6 +58,9 @@ const ( FUNCTION_TAG_SET_LATEST_IN_BODY FUNCTION_TAG = 4 FUNCTION_TAG_VERIFICATION FUNCTION_TAG = 5 FUNCTION_TAG_GET_EARLIEST_BLOCK FUNCTION_TAG = 6 + FUNCTION_TAG_SUBSCRIBE FUNCTION_TAG = 7 + FUNCTION_TAG_UNSUBSCRIBE FUNCTION_TAG = 8 + FUNCTION_TAG_UNSUBSCRIBE_ALL FUNCTION_TAG = 9 ) var FUNCTION_TAG_name = map[int32]string{ @@ -68,6 +71,9 @@ var FUNCTION_TAG_name = map[int32]string{ 4: "SET_LATEST_IN_BODY", 5: "VERIFICATION", 6: "GET_EARLIEST_BLOCK", + 7: "SUBSCRIBE", + 8: "UNSUBSCRIBE", + 9: "UNSUBSCRIBE_ALL", } var FUNCTION_TAG_value = map[string]int32{ @@ -78,6 +84,9 @@ var FUNCTION_TAG_value = map[string]int32{ "SET_LATEST_IN_BODY": 4, "VERIFICATION": 5, "GET_EARLIEST_BLOCK": 6, + "SUBSCRIBE": 7, + "UNSUBSCRIBE": 8, + "UNSUBSCRIBE_ALL": 9, } func (x FUNCTION_TAG) String() string {