From 5937abe86e917aab29f0b9fcb6abd925f8f56610 Mon Sep 17 00:00:00 2001 From: Omer <100387053+omerlavanet@users.noreply.github.com> Date: Wed, 21 Aug 2024 10:36:39 +0300 Subject: [PATCH] chore: consollidate-chain-message-data (#1636) * create protocolMessage class * fix bug * fix arg mismatch * fix test --------- Co-authored-by: Ran Mishael --- protocol/chainlib/chainlib.go | 6 +- protocol/chainlib/chainlib_mock.go | 141 ++++++++++++------ .../chainlib/consumer_websocket_manager.go | 24 +-- .../consumer_ws_subscription_manager.go | 86 +++++------ .../consumer_ws_subscription_manager_test.go | 87 ++++++----- protocol/chainlib/protocol_message.go | 53 +++++++ .../consumer_session_manager_test.go | 23 ++- protocol/lavasession/used_providers.go | 15 +- protocol/rpcconsumer/rpcconsumer_server.go | 115 +++++++------- x/pairing/types/relay_mock.pb.go | 48 +++--- 10 files changed, 357 insertions(+), 241 deletions(-) create mode 100644 protocol/chainlib/protocol_message.go diff --git a/protocol/chainlib/chainlib.go b/protocol/chainlib/chainlib.go index 83aa8e1e30..8ed037669d 100644 --- a/protocol/chainlib/chainlib.go +++ b/protocol/chainlib/chainlib.go @@ -125,15 +125,13 @@ type RelaySender interface { consumerIp string, analytics *metrics.RelayMetrics, metadata []pairingtypes.Metadata, - ) (ChainMessage, map[string]string, *pairingtypes.RelayPrivateData, error) + ) (ProtocolMessage, error) SendParsedRelay( ctx context.Context, dappID string, consumerIp string, analytics *metrics.RelayMetrics, - chainMessage ChainMessage, - directiveHeaders map[string]string, - relayRequestData *pairingtypes.RelayPrivateData, + protocolMessage ProtocolMessage, ) (relayResult *common.RelayResult, errRet error) CreateDappKey(dappID, consumerIp string) string CancelSubscriptionContext(subscriptionKey string) diff --git a/protocol/chainlib/chainlib_mock.go b/protocol/chainlib/chainlib_mock.go index 284a6c9b26..757c2cd9e0 100644 --- a/protocol/chainlib/chainlib_mock.go +++ b/protocol/chainlib/chainlib_mock.go @@ -1,10 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. // Source: protocol/chainlib/chainlib.go -// -// Generated by this command: -// -// mockgen -source protocol/chainlib/chainlib.go -destination protocol/chainlib/chainlib_mock.go -package chainlib -// // Package chainlib is a generated GoMock package. package chainlib @@ -14,6 +9,7 @@ import ( reflect "reflect" time "time" + gomock "github.com/golang/mock/gomock" rpcInterfaceMessages "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcInterfaceMessages" rpcclient "github.com/lavanet/lava/v2/protocol/chainlib/chainproxy/rpcclient" extensionslib "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" @@ -21,7 +17,6 @@ import ( metrics "github.com/lavanet/lava/v2/protocol/metrics" types "github.com/lavanet/lava/v2/x/pairing/types" types0 "github.com/lavanet/lava/v2/x/spec/types" - gomock "go.uber.org/mock/gomock" ) // MockChainParser is a mock of ChainParser interface. @@ -100,7 +95,7 @@ func (m *MockChainParser) CraftMessage(parser *types0.ParseDirective, connection } // CraftMessage indicates an expected call of CraftMessage. -func (mr *MockChainParserMockRecorder) CraftMessage(parser, connectionType, craftData, metadata any) *gomock.Call { +func (mr *MockChainParserMockRecorder) CraftMessage(parser, connectionType, craftData, metadata interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CraftMessage", reflect.TypeOf((*MockChainParser)(nil).CraftMessage), parser, connectionType, craftData, metadata) } @@ -145,7 +140,7 @@ func (m *MockChainParser) GetParsingByTag(tag types0.FUNCTION_TAG) (*types0.Pars } // GetParsingByTag indicates an expected call of GetParsingByTag. -func (mr *MockChainParserMockRecorder) GetParsingByTag(tag any) *gomock.Call { +func (mr *MockChainParserMockRecorder) GetParsingByTag(tag interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetParsingByTag", reflect.TypeOf((*MockChainParser)(nil).GetParsingByTag), tag) } @@ -174,7 +169,7 @@ func (m *MockChainParser) GetVerifications(supported []string) ([]VerificationCo } // GetVerifications indicates an expected call of GetVerifications. -func (mr *MockChainParserMockRecorder) GetVerifications(supported any) *gomock.Call { +func (mr *MockChainParserMockRecorder) GetVerifications(supported interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVerifications", reflect.TypeOf((*MockChainParser)(nil).GetVerifications), supported) } @@ -190,7 +185,7 @@ func (m *MockChainParser) HandleHeaders(metadata []types.Metadata, apiCollection } // HandleHeaders indicates an expected call of HandleHeaders. -func (mr *MockChainParserMockRecorder) HandleHeaders(metadata, apiCollection, headersDirection any) *gomock.Call { +func (mr *MockChainParserMockRecorder) HandleHeaders(metadata, apiCollection, headersDirection interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleHeaders", reflect.TypeOf((*MockChainParser)(nil).HandleHeaders), metadata, apiCollection, headersDirection) } @@ -205,7 +200,7 @@ func (m *MockChainParser) ParseMsg(url string, data []byte, connectionType strin } // ParseMsg indicates an expected call of ParseMsg. -func (mr *MockChainParserMockRecorder) ParseMsg(url, data, connectionType, metadata, extensionInfo any) *gomock.Call { +func (mr *MockChainParserMockRecorder) ParseMsg(url, data, connectionType, metadata, extensionInfo interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseMsg", reflect.TypeOf((*MockChainParser)(nil).ParseMsg), url, data, connectionType, metadata, extensionInfo) } @@ -221,7 +216,7 @@ func (m *MockChainParser) SeparateAddonsExtensions(supported []string) ([]string } // SeparateAddonsExtensions indicates an expected call of SeparateAddonsExtensions. -func (mr *MockChainParserMockRecorder) SeparateAddonsExtensions(supported any) *gomock.Call { +func (mr *MockChainParserMockRecorder) SeparateAddonsExtensions(supported interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SeparateAddonsExtensions", reflect.TypeOf((*MockChainParser)(nil).SeparateAddonsExtensions), supported) } @@ -235,7 +230,7 @@ func (m *MockChainParser) SetPolicy(policy PolicyInf, chainId, apiInterface stri } // SetPolicy indicates an expected call of SetPolicy. -func (mr *MockChainParserMockRecorder) SetPolicy(policy, chainId, apiInterface any) *gomock.Call { +func (mr *MockChainParserMockRecorder) SetPolicy(policy, chainId, apiInterface interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPolicy", reflect.TypeOf((*MockChainParser)(nil).SetPolicy), policy, chainId, apiInterface) } @@ -247,7 +242,7 @@ func (m *MockChainParser) SetSpec(spec types0.Spec) { } // SetSpec indicates an expected call of SetSpec. -func (mr *MockChainParserMockRecorder) SetSpec(spec any) *gomock.Call { +func (mr *MockChainParserMockRecorder) SetSpec(spec interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSpec", reflect.TypeOf((*MockChainParser)(nil).SetSpec), spec) } @@ -259,7 +254,7 @@ func (m *MockChainParser) UpdateBlockTime(newBlockTime time.Duration) { } // UpdateBlockTime indicates an expected call of UpdateBlockTime. -func (mr *MockChainParserMockRecorder) UpdateBlockTime(newBlockTime any) *gomock.Call { +func (mr *MockChainParserMockRecorder) UpdateBlockTime(newBlockTime interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateBlockTime", reflect.TypeOf((*MockChainParser)(nil).UpdateBlockTime), newBlockTime) } @@ -294,7 +289,7 @@ func (m *MockChainMessage) AppendHeader(metadata []types.Metadata) { } // AppendHeader indicates an expected call of AppendHeader. -func (mr *MockChainMessageMockRecorder) AppendHeader(metadata any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) AppendHeader(metadata interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendHeader", reflect.TypeOf((*MockChainMessage)(nil).AppendHeader), metadata) } @@ -309,7 +304,7 @@ func (m *MockChainMessage) CheckResponseError(data []byte, httpStatusCode int) ( } // CheckResponseError indicates an expected call of CheckResponseError. -func (mr *MockChainMessageMockRecorder) CheckResponseError(data, httpStatusCode any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) CheckResponseError(data, httpStatusCode interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckResponseError", reflect.TypeOf((*MockChainMessage)(nil).CheckResponseError), data, httpStatusCode) } @@ -410,6 +405,21 @@ func (mr *MockChainMessageMockRecorder) GetRPCMessage() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRPCMessage", reflect.TypeOf((*MockChainMessage)(nil).GetRPCMessage)) } +// GetRawRequestHash mocks base method. +func (m *MockChainMessage) GetRawRequestHash() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRawRequestHash") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRawRequestHash indicates an expected call of GetRawRequestHash. +func (mr *MockChainMessageMockRecorder) GetRawRequestHash() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawRequestHash", reflect.TypeOf((*MockChainMessage)(nil).GetRawRequestHash)) +} + // OverrideExtensions mocks base method. func (m *MockChainMessage) OverrideExtensions(extensionNames []string, extensionParser *extensionslib.ExtensionParser) { m.ctrl.T.Helper() @@ -417,7 +427,7 @@ func (m *MockChainMessage) OverrideExtensions(extensionNames []string, extension } // OverrideExtensions indicates an expected call of OverrideExtensions. -func (mr *MockChainMessageMockRecorder) OverrideExtensions(extensionNames, extensionParser any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) OverrideExtensions(extensionNames, extensionParser interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OverrideExtensions", reflect.TypeOf((*MockChainMessage)(nil).OverrideExtensions), extensionNames, extensionParser) } @@ -446,15 +456,29 @@ func (m *MockChainMessage) SetForceCacheRefresh(force bool) bool { } // SetForceCacheRefresh indicates an expected call of SetForceCacheRefresh. -func (mr *MockChainMessageMockRecorder) SetForceCacheRefresh(force any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) SetForceCacheRefresh(force interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetForceCacheRefresh", reflect.TypeOf((*MockChainMessage)(nil).SetForceCacheRefresh), force) } +// SubscriptionIdExtractor mocks base method. +func (m *MockChainMessage) SubscriptionIdExtractor(reply *rpcclient.JsonrpcMessage) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubscriptionIdExtractor", reply) + ret0, _ := ret[0].(string) + return ret0 +} + +// SubscriptionIdExtractor indicates an expected call of SubscriptionIdExtractor. +func (mr *MockChainMessageMockRecorder) SubscriptionIdExtractor(reply interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscriptionIdExtractor", reflect.TypeOf((*MockChainMessage)(nil).SubscriptionIdExtractor), reply) +} + // TimeoutOverride mocks base method. func (m *MockChainMessage) TimeoutOverride(arg0 ...time.Duration) time.Duration { m.ctrl.T.Helper() - varargs := []any{} + varargs := []interface{}{} for _, a := range arg0 { varargs = append(varargs, a) } @@ -464,7 +488,7 @@ func (m *MockChainMessage) TimeoutOverride(arg0 ...time.Duration) time.Duration } // TimeoutOverride indicates an expected call of TimeoutOverride. -func (mr *MockChainMessageMockRecorder) TimeoutOverride(arg0 ...any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) TimeoutOverride(arg0 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeoutOverride", reflect.TypeOf((*MockChainMessage)(nil).TimeoutOverride), arg0...) } @@ -478,7 +502,7 @@ func (m *MockChainMessage) UpdateLatestBlockInMessage(latestBlock int64, modifyC } // UpdateLatestBlockInMessage indicates an expected call of UpdateLatestBlockInMessage. -func (mr *MockChainMessageMockRecorder) UpdateLatestBlockInMessage(latestBlock, modifyContent any) *gomock.Call { +func (mr *MockChainMessageMockRecorder) UpdateLatestBlockInMessage(latestBlock, modifyContent interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLatestBlockInMessage", reflect.TypeOf((*MockChainMessage)(nil).UpdateLatestBlockInMessage), latestBlock, modifyContent) } @@ -506,6 +530,21 @@ func (m *MockChainMessageForSend) EXPECT() *MockChainMessageForSendMockRecorder return m.recorder } +// CheckResponseError mocks base method. +func (m *MockChainMessageForSend) CheckResponseError(data []byte, httpStatusCode int) (bool, string) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckResponseError", data, httpStatusCode) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(string) + return ret0, ret1 +} + +// CheckResponseError indicates an expected call of CheckResponseError. +func (mr *MockChainMessageForSendMockRecorder) CheckResponseError(data, httpStatusCode interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckResponseError", reflect.TypeOf((*MockChainMessageForSend)(nil).CheckResponseError), data, httpStatusCode) +} + // GetApi mocks base method. func (m *MockChainMessageForSend) GetApi() *types0.Api { m.ctrl.T.Helper() @@ -565,7 +604,7 @@ func (mr *MockChainMessageForSendMockRecorder) GetRPCMessage() *gomock.Call { // TimeoutOverride mocks base method. func (m *MockChainMessageForSend) TimeoutOverride(arg0 ...time.Duration) time.Duration { m.ctrl.T.Helper() - varargs := []any{} + varargs := []interface{}{} for _, a := range arg0 { varargs = append(varargs, a) } @@ -575,7 +614,7 @@ func (m *MockChainMessageForSend) TimeoutOverride(arg0 ...time.Duration) time.Du } // TimeoutOverride indicates an expected call of TimeoutOverride. -func (mr *MockChainMessageForSendMockRecorder) TimeoutOverride(arg0 ...any) *gomock.Call { +func (mr *MockChainMessageForSendMockRecorder) TimeoutOverride(arg0 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeoutOverride", reflect.TypeOf((*MockChainMessageForSend)(nil).TimeoutOverride), arg0...) } @@ -647,7 +686,7 @@ func (m *MockRelaySender) CancelSubscriptionContext(subscriptionKey string) { } // CancelSubscriptionContext indicates an expected call of CancelSubscriptionContext. -func (mr *MockRelaySenderMockRecorder) CancelSubscriptionContext(subscriptionKey any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) CancelSubscriptionContext(subscriptionKey interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelSubscriptionContext", reflect.TypeOf((*MockRelaySender)(nil).CancelSubscriptionContext), subscriptionKey) } @@ -661,41 +700,39 @@ func (m *MockRelaySender) CreateDappKey(dappID, consumerIp string) string { } // CreateDappKey indicates an expected call of CreateDappKey. -func (mr *MockRelaySenderMockRecorder) CreateDappKey(dappID, consumerIp any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) CreateDappKey(dappID, consumerIp interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDappKey", reflect.TypeOf((*MockRelaySender)(nil).CreateDappKey), dappID, consumerIp) } // ParseRelay mocks base method. -func (m *MockRelaySender) ParseRelay(ctx context.Context, url, req, connectionType, dappID, consumerIp string, analytics *metrics.RelayMetrics, metadata []types.Metadata) (ChainMessage, map[string]string, *types.RelayPrivateData, error) { +func (m *MockRelaySender) ParseRelay(ctx context.Context, url, req, connectionType, dappID, consumerIp string, analytics *metrics.RelayMetrics, metadata []types.Metadata) (ProtocolMessage, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ParseRelay", ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) - ret0, _ := ret[0].(ChainMessage) - ret1, _ := ret[1].(map[string]string) - ret2, _ := ret[2].(*types.RelayPrivateData) - ret3, _ := ret[3].(error) - return ret0, ret1, ret2, ret3 + ret0, _ := ret[0].(ProtocolMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 } // ParseRelay indicates an expected call of ParseRelay. -func (mr *MockRelaySenderMockRecorder) ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseRelay", reflect.TypeOf((*MockRelaySender)(nil).ParseRelay), ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) } // SendParsedRelay mocks base method. -func (m *MockRelaySender) SendParsedRelay(ctx context.Context, dappID, consumerIp string, analytics *metrics.RelayMetrics, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *types.RelayPrivateData) (*common.RelayResult, error) { +func (m *MockRelaySender) SendParsedRelay(ctx context.Context, dappID, consumerIp string, analytics *metrics.RelayMetrics, protocolMessage ProtocolMessage) (*common.RelayResult, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendParsedRelay", ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData) + ret := m.ctrl.Call(m, "SendParsedRelay", ctx, dappID, consumerIp, analytics, protocolMessage) ret0, _ := ret[0].(*common.RelayResult) ret1, _ := ret[1].(error) return ret0, ret1 } // SendParsedRelay indicates an expected call of SendParsedRelay. -func (mr *MockRelaySenderMockRecorder) SendParsedRelay(ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) SendParsedRelay(ctx, dappID, consumerIp, analytics, protocolMessage interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendParsedRelay", reflect.TypeOf((*MockRelaySender)(nil).SendParsedRelay), ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendParsedRelay", reflect.TypeOf((*MockRelaySender)(nil).SendParsedRelay), ctx, dappID, consumerIp, analytics, protocolMessage) } // SendRelay mocks base method. @@ -708,7 +745,7 @@ func (m *MockRelaySender) SendRelay(ctx context.Context, url, req, connectionTyp } // SendRelay indicates an expected call of SendRelay. -func (mr *MockRelaySenderMockRecorder) SendRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadataValues any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) SendRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadataValues interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRelay", reflect.TypeOf((*MockRelaySender)(nil).SendRelay), ctx, url, req, connectionType, dappID, consumerIp, analytics, metadataValues) } @@ -720,7 +757,7 @@ func (m *MockRelaySender) SetConsistencySeenBlock(blockSeen int64, key string) { } // SetConsistencySeenBlock indicates an expected call of SetConsistencySeenBlock. -func (mr *MockRelaySenderMockRecorder) SetConsistencySeenBlock(blockSeen, key any) *gomock.Call { +func (mr *MockRelaySenderMockRecorder) SetConsistencySeenBlock(blockSeen, key interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetConsistencySeenBlock", reflect.TypeOf((*MockRelaySender)(nil).SetConsistencySeenBlock), blockSeen, key) } @@ -748,6 +785,20 @@ func (m *MockChainListener) EXPECT() *MockChainListenerMockRecorder { return m.recorder } +// GetListeningAddress mocks base method. +func (m *MockChainListener) GetListeningAddress() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetListeningAddress") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetListeningAddress indicates an expected call of GetListeningAddress. +func (mr *MockChainListenerMockRecorder) GetListeningAddress() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetListeningAddress", reflect.TypeOf((*MockChainListener)(nil).GetListeningAddress)) +} + // Serve mocks base method. func (m *MockChainListener) Serve(ctx context.Context, cmdFlags common.ConsumerCmdFlags) { m.ctrl.T.Helper() @@ -755,7 +806,7 @@ func (m *MockChainListener) Serve(ctx context.Context, cmdFlags common.ConsumerC } // Serve indicates an expected call of Serve. -func (mr *MockChainListenerMockRecorder) Serve(ctx, cmdFlags any) *gomock.Call { +func (mr *MockChainListenerMockRecorder) Serve(ctx, cmdFlags interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockChainListener)(nil).Serve), ctx, cmdFlags) } @@ -792,13 +843,13 @@ func (m *MockChainRouter) ExtensionsSupported(arg0 []string) bool { } // ExtensionsSupported indicates an expected call of ExtensionsSupported. -func (mr *MockChainRouterMockRecorder) ExtensionsSupported(arg0 any) *gomock.Call { +func (mr *MockChainRouterMockRecorder) ExtensionsSupported(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtensionsSupported", reflect.TypeOf((*MockChainRouter)(nil).ExtensionsSupported), arg0) } // SendNodeMsg mocks base method. -func (m *MockChainRouter) SendNodeMsg(ctx context.Context, ch chan any, chainMessage ChainMessageForSend, extensions []string) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, common.NodeUrl, string, error) { +func (m *MockChainRouter) SendNodeMsg(ctx context.Context, ch chan interface{}, chainMessage ChainMessageForSend, extensions []string) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, common.NodeUrl, string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendNodeMsg", ctx, ch, chainMessage, extensions) ret0, _ := ret[0].(*RelayReplyWrapper) @@ -811,7 +862,7 @@ func (m *MockChainRouter) SendNodeMsg(ctx context.Context, ch chan any, chainMes } // SendNodeMsg indicates an expected call of SendNodeMsg. -func (mr *MockChainRouterMockRecorder) SendNodeMsg(ctx, ch, chainMessage, extensions any) *gomock.Call { +func (mr *MockChainRouterMockRecorder) SendNodeMsg(ctx, ch, chainMessage, extensions interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendNodeMsg", reflect.TypeOf((*MockChainRouter)(nil).SendNodeMsg), ctx, ch, chainMessage, extensions) } @@ -855,7 +906,7 @@ func (mr *MockChainProxyMockRecorder) GetChainProxyInformation() *gomock.Call { } // SendNodeMsg mocks base method. -func (m *MockChainProxy) SendNodeMsg(ctx context.Context, ch chan any, chainMessage ChainMessageForSend) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, error) { +func (m *MockChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{}, chainMessage ChainMessageForSend) (*RelayReplyWrapper, string, *rpcclient.ClientSubscription, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendNodeMsg", ctx, ch, chainMessage) ret0, _ := ret[0].(*RelayReplyWrapper) @@ -866,7 +917,7 @@ func (m *MockChainProxy) SendNodeMsg(ctx context.Context, ch chan any, chainMess } // SendNodeMsg indicates an expected call of SendNodeMsg. -func (mr *MockChainProxyMockRecorder) SendNodeMsg(ctx, ch, chainMessage any) *gomock.Call { +func (mr *MockChainProxyMockRecorder) SendNodeMsg(ctx, ch, chainMessage interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendNodeMsg", reflect.TypeOf((*MockChainProxy)(nil).SendNodeMsg), ctx, ch, chainMessage) } diff --git a/protocol/chainlib/consumer_websocket_manager.go b/protocol/chainlib/consumer_websocket_manager.go index 3017328db4..75dd3ca1ba 100644 --- a/protocol/chainlib/consumer_websocket_manager.go +++ b/protocol/chainlib/consumer_websocket_manager.go @@ -149,7 +149,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { metricsData := metrics.NewRelayAnalytics(dappID, cwm.chainId, cwm.apiInterface) - chainMessage, directiveHeaders, relayRequestData, err := cwm.relaySender.ParseRelay(webSocketCtx, "", string(msg), cwm.connectionType, dappID, userIp, metricsData, nil) + protocolMessage, err := cwm.relaySender.ParseRelay(webSocketCtx, "", string(msg), cwm.connectionType, dappID, userIp, metricsData, nil) if err != nil { formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), utils.LavaFormatError("could not parse message", err), msgSeed, msg, cwm.apiInterface, time.Since(startTime)) if formatterMsg != nil { @@ -159,9 +159,9 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } // check whether its a normal relay / unsubscribe / unsubscribe_all otherwise its a subscription flow. - if !IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { - if IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE) { - err := cwm.consumerWsSubscriptionManager.Unsubscribe(webSocketCtx, chainMessage, directiveHeaders, relayRequestData, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) + if !IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { + if IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE) { + err := cwm.consumerWsSubscriptionManager.Unsubscribe(webSocketCtx, protocolMessage, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) if err != nil { utils.LavaFormatWarning("error unsubscribing from subscription", err, utils.LogAttr("GUID", webSocketCtx)) if err == common.SubscriptionNotFoundError { @@ -174,7 +174,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } } continue - } else if IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE_ALL) { + } else if IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE_ALL) { err := cwm.consumerWsSubscriptionManager.UnsubscribeAll(webSocketCtx, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) if err != nil { utils.LavaFormatWarning("error unsubscribing from all subscription", err, utils.LogAttr("GUID", webSocketCtx)) @@ -182,7 +182,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { continue } else { // Normal relay over websocket. (not subscription related) - relayResult, err := cwm.relaySender.SendParsedRelay(webSocketCtx, dappID, userIp, metricsData, chainMessage, directiveHeaders, relayRequestData) + relayResult, err := cwm.relaySender.SendParsedRelay(webSocketCtx, dappID, userIp, metricsData, protocolMessage) if err != nil { formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), utils.LavaFormatError("could not send parsed relay", err), msgSeed, msg, cwm.apiInterface, time.Since(startTime)) if formatterMsg != nil { @@ -202,16 +202,16 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } // Subscription flow - inputFormatter, outputFormatter := formatter.FormatterForRelayRequestAndResponse(relayRequestData.ApiInterface) // we use this to preserve the original jsonrpc id - inputFormatter(relayRequestData.Data) // set the extracted jsonrpc id + inputFormatter, outputFormatter := formatter.FormatterForRelayRequestAndResponse(protocolMessage.GetApiCollection().CollectionData.ApiInterface) // we use this to preserve the original jsonrpc id + inputFormatter(protocolMessage.RelayPrivateData().Data) // set the extracted jsonrpc id - reply, subscriptionMsgsChan, err := cwm.consumerWsSubscriptionManager.StartSubscription(webSocketCtx, chainMessage, directiveHeaders, relayRequestData, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) + reply, subscriptionMsgsChan, err := cwm.consumerWsSubscriptionManager.StartSubscription(webSocketCtx, protocolMessage, dappID, userIp, cwm.WebsocketConnectionUID, metricsData) if err != nil { utils.LavaFormatWarning("StartSubscription returned an error", err, utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("dappID", dappID), utils.LogAttr("userIp", userIp), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), ) formatterMsg := logger.AnalyzeWebSocketErrorAndGetFormattedMessage(websocketConn.LocalAddr().String(), utils.LavaFormatError("could not start subscription", err), msgSeed, msg, cwm.apiInterface, time.Since(startTime)) @@ -239,7 +239,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("dappID", dappID), utils.LogAttr("userIp", userIp), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), ) for subscriptionMsgReply := range subscriptionMsgsChan { @@ -250,7 +250,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("dappID", dappID), utils.LogAttr("userIp", userIp), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), ) }() } diff --git a/protocol/chainlib/consumer_ws_subscription_manager.go b/protocol/chainlib/consumer_ws_subscription_manager.go index 6b993588cd..dda1405573 100644 --- a/protocol/chainlib/consumer_ws_subscription_manager.go +++ b/protocol/chainlib/consumer_ws_subscription_manager.go @@ -19,9 +19,7 @@ import ( ) type unsubscribeRelayData struct { - chainMessage ChainMessage - directiveHeaders map[string]string - relayRequestData *pairingtypes.RelayPrivateData + protocolMessage ProtocolMessage } type activeSubscriptionHolder struct { @@ -186,15 +184,13 @@ func (cwsm *ConsumerWSSubscriptionManager) checkForActiveSubscriptionWithLock( func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( webSocketCtx context.Context, - chainMessage ChainMessage, - directiveHeaders map[string]string, - relayRequestData *pairingtypes.RelayPrivateData, + protocolMessage ProtocolMessage, dappID string, consumerIp string, webSocketConnectionUniqueId string, metricsData *metrics.RelayMetrics, ) (firstReply *pairingtypes.RelayReply, repliesChan <-chan *pairingtypes.RelayReply, err error) { - hashedParams, _, err := cwsm.getHashedParams(chainMessage) + hashedParams, _, err := cwsm.getHashedParams(protocolMessage) if err != nil { return nil, nil, utils.LavaFormatError("could not marshal params", err) } @@ -229,7 +225,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( <-closeWebsocketRepliesChan utils.LavaFormatTrace("requested to close websocketRepliesChan", utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) @@ -242,7 +238,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( <-webSocketCtx.Done() utils.LavaFormatTrace("websocket context is done, removing websocket from active subscriptions", utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) @@ -258,7 +254,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( }() // Validated there are no active subscriptions that we can use. - firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, chainMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel) + firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, protocolMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel) if firstSubscriptionReply != nil { if returnWebsocketRepliesChan { return firstSubscriptionReply, websocketRepliesChan, nil @@ -279,7 +275,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( utils.LavaFormatTrace("Finished pending for subscription, have results", utils.LogAttr("success", res)) // Check res is valid, if not fall through logs and try again with a new client. if res { - firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, chainMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel) + firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, protocolMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel) if firstSubscriptionReply != nil { if returnWebsocketRepliesChan { return firstSubscriptionReply, websocketRepliesChan, nil @@ -300,12 +296,12 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( utils.LavaFormatTrace("could not find active subscription for given params, creating new one", utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) - relayResult, err := cwsm.relaySender.SendParsedRelay(webSocketCtx, dappID, consumerIp, metricsData, chainMessage, directiveHeaders, relayRequestData) + relayResult, err := cwsm.relaySender.SendParsedRelay(webSocketCtx, dappID, consumerIp, metricsData, protocolMessage) if err != nil { onSubscriptionFailure() return nil, nil, utils.LavaFormatError("could not send subscription relay", err) @@ -313,7 +309,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( utils.LavaFormatTrace("got relay result from SendParsedRelay", utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), utils.LogAttr("relayResult", relayResult), @@ -325,7 +321,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( onSubscriptionFailure() return nil, nil, utils.LavaFormatError("reply server is nil, probably an error with the subscription initiation", nil, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) @@ -336,7 +332,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( onSubscriptionFailure() return nil, nil, utils.LavaFormatError("Reply data is nil", nil, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) @@ -348,18 +344,18 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( onSubscriptionFailure() return nil, nil, utils.LavaFormatError("could not copy relay request", err, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), ) } - err = cwsm.verifySubscriptionMessage(hashedParams, chainMessage, relayResult.Request, &reply, relayResult.ProviderInfo.ProviderAddress) + err = cwsm.verifySubscriptionMessage(hashedParams, protocolMessage, relayResult.Request, &reply, relayResult.ProviderInfo.ProviderAddress) if err != nil { onSubscriptionFailure() return nil, nil, utils.LavaFormatError("Failed VerifyRelayReply on subscription message", err, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), utils.LogAttr("reply", string(reply.Data)), @@ -373,7 +369,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( onSubscriptionFailure() return nil, nil, utils.LavaFormatError("could not parse reply into json", err, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("dappKey", dappKey), utils.LogAttr("reply", reply.Data), @@ -391,7 +387,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( cwsm.lock.Lock() defer cwsm.lock.Unlock() - subscriptionId := chainMessage.SubscriptionIdExtractor(&replyJsonrpcMessage) + subscriptionId := protocolMessage.SubscriptionIdExtractor(&replyJsonrpcMessage) subscriptionId = common.UnSquareBracket(subscriptionId) if common.IsQuoted(subscriptionId) { subscriptionId, _ = strconv.Unquote(subscriptionId) @@ -404,7 +400,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription( firstSubscriptionReplyAsJsonrpcMessage: &replyJsonrpcMessage, replyServer: replyServer, subscriptionOriginalRequest: copiedRequest, - subscriptionOriginalRequestChainMessage: chainMessage, + subscriptionOriginalRequestChainMessage: protocolMessage, closeSubscriptionChan: closeSubscriptionChan, connectedDappKeys: map[string]struct{}{dappKey: {}}, subscriptionId: subscriptionId, @@ -458,9 +454,7 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( // we run the unsubscribe flow in an inner function so it wont prevent us from removing the activeSubscriptions at the end. func() { var err error - var chainMessage ChainMessage - var directiveHeaders map[string]string - var relayRequestData *pairingtypes.RelayPrivateData + var protocolMessage ProtocolMessage if unsubscribeData != nil { // This unsubscribe request was initiated by the user utils.LavaFormatTrace("unsubscribe request was made by the user", @@ -468,9 +462,7 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), ) - chainMessage = unsubscribeData.chainMessage - directiveHeaders = unsubscribeData.directiveHeaders - relayRequestData = unsubscribeData.relayRequestData + protocolMessage = unsubscribeData.protocolMessage } else { // This unsubscribe request was initiated by us utils.LavaFormatTrace("unsubscribe request was made automatically", @@ -478,13 +470,13 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), ) - chainMessage, directiveHeaders, relayRequestData, err = cwsm.craftUnsubscribeMessage(hashedParams, dappID, userIp, metricsData) + protocolMessage, err = cwsm.craftUnsubscribeMessage(hashedParams, dappID, userIp, metricsData) if err != nil { utils.LavaFormatError("could not craft unsubscribe message", err, utils.LogAttr("GUID", webSocketCtx)) return } - stringJson, err := gojson.Marshal(chainMessage.GetRPCMessage()) + stringJson, err := gojson.Marshal(protocolMessage.GetRPCMessage()) if err != nil { utils.LavaFormatError("could not marshal chain message", err, utils.LogAttr("GUID", webSocketCtx)) return @@ -498,16 +490,16 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages( } unsubscribeRelayCtx := utils.WithUniqueIdentifier(context.Background(), utils.GenerateUniqueIdentifier()) - err = cwsm.sendUnsubscribeMessage(unsubscribeRelayCtx, dappID, userIp, chainMessage, directiveHeaders, relayRequestData, metricsData) + err = cwsm.sendUnsubscribeMessage(unsubscribeRelayCtx, dappID, userIp, protocolMessage, metricsData) if err != nil { utils.LavaFormatError("could not send unsubscribe message due to a relay error", err, utils.LogAttr("GUID", webSocketCtx), - utils.LogAttr("relayRequestData", relayRequestData), + utils.LogAttr("relayRequestData", protocolMessage.RelayPrivateData()), utils.LogAttr("dappID", dappID), utils.LogAttr("userIp", userIp), - utils.LogAttr("api", chainMessage.GetApi().Name), - utils.LogAttr("params", chainMessage.GetRPCMessage().GetParams()), + utils.LogAttr("api", protocolMessage.GetApi().Name), + utils.LogAttr("params", protocolMessage.GetRPCMessage().GetParams()), ) } else { utils.LavaFormatTrace("success sending unsubscribe message, deleting hashed params from activeSubscriptions", @@ -645,7 +637,7 @@ func (cwsm *ConsumerWSSubscriptionManager) getHashedParams(chainMessage ChainMes return hashedParams, params, nil } -func (cwsm *ConsumerWSSubscriptionManager) Unsubscribe(webSocketCtx context.Context, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData, dappID, consumerIp string, webSocketConnectionUniqueId string, metricsData *metrics.RelayMetrics) error { +func (cwsm *ConsumerWSSubscriptionManager) Unsubscribe(webSocketCtx context.Context, protocolMessage ProtocolMessage, dappID, consumerIp string, webSocketConnectionUniqueId string, metricsData *metrics.RelayMetrics) error { utils.LavaFormatTrace("want to unsubscribe", utils.LogAttr("GUID", webSocketCtx), utils.LogAttr("dappID", dappID), @@ -657,16 +649,16 @@ func (cwsm *ConsumerWSSubscriptionManager) Unsubscribe(webSocketCtx context.Cont cwsm.lock.Lock() defer cwsm.lock.Unlock() - hashedParams, err := cwsm.findActiveSubscriptionHashedParamsFromChainMessage(chainMessage) + hashedParams, err := cwsm.findActiveSubscriptionHashedParamsFromChainMessage(protocolMessage) if err != nil { return err } return cwsm.verifyAndDisconnectDappFromSubscription(webSocketCtx, dappKey, hashedParams, func() (*unsubscribeRelayData, error) { - return &unsubscribeRelayData{chainMessage, directiveHeaders, relayRequestData}, nil + return &unsubscribeRelayData{protocolMessage}, nil }) } -func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, dappID, consumerIp string, metricsData *metrics.RelayMetrics) (ChainMessage, map[string]string, *pairingtypes.RelayPrivateData, error) { +func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, dappID, consumerIp string, metricsData *metrics.RelayMetrics) (ProtocolMessage, error) { request := cwsm.activeSubscriptions[hashedParams].subscriptionOriginalRequestChainMessage subscriptionId := cwsm.activeSubscriptions[hashedParams].subscriptionId @@ -682,14 +674,14 @@ func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, } if !found { - return nil, nil, nil, utils.LavaFormatError("could not find unsubscribe parse directive for given chain message", nil, + return nil, utils.LavaFormatError("could not find unsubscribe parse directive for given chain message", nil, utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("subscriptionId", subscriptionId), ) } if unsubscribeRequestData == "" { - return nil, nil, nil, utils.LavaFormatError("unsubscribe request data is empty", nil, + return nil, utils.LavaFormatError("unsubscribe request data is empty", nil, utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("subscriptionId", subscriptionId), ) @@ -697,9 +689,9 @@ func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, // Craft the unsubscribe chain message ctx := context.Background() - chainMessage, directiveHeaders, relayRequestData, err := cwsm.relaySender.ParseRelay(ctx, "", unsubscribeRequestData, cwsm.connectionType, dappID, consumerIp, metricsData, nil) + protocolMessage, err := cwsm.relaySender.ParseRelay(ctx, "", unsubscribeRequestData, cwsm.connectionType, dappID, consumerIp, metricsData, nil) if err != nil { - return nil, nil, nil, utils.LavaFormatError("could not craft unsubscribe chain message", err, + return nil, utils.LavaFormatError("could not craft unsubscribe chain message", err, utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)), utils.LogAttr("subscriptionId", subscriptionId), utils.LogAttr("unsubscribeRequestData", unsubscribeRequestData), @@ -707,10 +699,10 @@ func (cwsm *ConsumerWSSubscriptionManager) craftUnsubscribeMessage(hashedParams, ) } - return chainMessage, directiveHeaders, relayRequestData, nil + return protocolMessage, nil } -func (cwsm *ConsumerWSSubscriptionManager) sendUnsubscribeMessage(ctx context.Context, dappID, consumerIp string, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData, metricsData *metrics.RelayMetrics) error { +func (cwsm *ConsumerWSSubscriptionManager) sendUnsubscribeMessage(ctx context.Context, dappID, consumerIp string, protocolMessage ProtocolMessage, metricsData *metrics.RelayMetrics) error { // Send the crafted unsubscribe relay utils.LavaFormatTrace("sending unsubscribe relay", utils.LogAttr("GUID", ctx), @@ -718,7 +710,7 @@ func (cwsm *ConsumerWSSubscriptionManager) sendUnsubscribeMessage(ctx context.Co utils.LogAttr("consumerIp", consumerIp), ) - _, err := cwsm.relaySender.SendParsedRelay(ctx, dappID, consumerIp, metricsData, chainMessage, directiveHeaders, relayRequestData) + _, err := cwsm.relaySender.SendParsedRelay(ctx, dappID, consumerIp, metricsData, protocolMessage) if err != nil { return utils.LavaFormatError("could not send unsubscribe relay", err) } @@ -775,12 +767,12 @@ func (cwsm *ConsumerWSSubscriptionManager) UnsubscribeAll(webSocketCtx context.C ) unsubscribeRelayGetter := func() (*unsubscribeRelayData, error) { - chainMessage, directiveHeaders, relayRequestData, err := cwsm.craftUnsubscribeMessage(hashedParams, dappID, consumerIp, metricsData) + protocolMessage, err := cwsm.craftUnsubscribeMessage(hashedParams, dappID, consumerIp, metricsData) if err != nil { return nil, err } - return &unsubscribeRelayData{chainMessage, directiveHeaders, relayRequestData}, nil + return &unsubscribeRelayData{protocolMessage}, nil } cwsm.verifyAndDisconnectDappFromSubscription(webSocketCtx, dappKey, hashedParams, unsubscribeRelayGetter) diff --git a/protocol/chainlib/consumer_ws_subscription_manager_test.go b/protocol/chainlib/consumer_ws_subscription_manager_test.go index ed79e1bea5..81a59fea87 100644 --- a/protocol/chainlib/consumer_ws_subscription_manager_test.go +++ b/protocol/chainlib/consumer_ws_subscription_manager_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + gomock "github.com/golang/mock/gomock" "github.com/lavanet/lava/v2/protocol/chainlib/extensionslib" "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/protocol/lavaprotocol" @@ -20,7 +21,7 @@ import ( spectypes "github.com/lavanet/lava/v2/x/spec/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - gomock "go.uber.org/mock/gomock" + gomockuber "go.uber.org/mock/gomock" ) const ( @@ -66,7 +67,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes chainMessage1, err := chainParser.ParseMsg("", play.subscriptionRequestData1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) - + protocolMessage1 := NewProtocolMessage(chainMessage1, nil, nil) relaySender := NewMockRelaySender(ctrl) relaySender. EXPECT(). @@ -83,7 +84,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes relaySender. EXPECT(). ParseRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(chainMessage1, nil, nil, nil). + Return(protocolMessage1, nil). AnyTimes() mockRelayerClient1 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) @@ -128,7 +129,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). Times(1) // Should call SendParsedRelay, because it is the first time we subscribe @@ -139,6 +140,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes uniqueIdentifiers := make([]string, numberOfParallelSubscriptions) wg := sync.WaitGroup{} wg.Add(numberOfParallelSubscriptions) + // Start a new subscription for the first time, called SendParsedRelay once while in parallel calling 10 times subscribe with the same message // expected result is to have SendParsedRelay only once and 9 other messages waiting the broadcast. for i := 0; i < numberOfParallelSubscriptions; i++ { @@ -148,7 +150,8 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes ctx := utils.WithUniqueIdentifier(ts.Ctx, utils.GenerateUniqueIdentifier()) var repliesChan <-chan *pairingtypes.RelayReply var firstReply *pairingtypes.RelayReply - firstReply, repliesChan, err = manager.StartSubscription(ctx, chainMessage1, nil, nil, dapp, ip, uniqueIdentifiers[index], nil) + + firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[index], nil) go func() { for subMsg := range repliesChan { // utils.LavaFormatInfo("got reply for index", utils.LogAttr("index", index)) @@ -166,7 +169,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes // now we have numberOfParallelSubscriptions subscriptions currently running require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions) // remove one - err = manager.Unsubscribe(ts.Ctx, chainMessage1, nil, nil, dapp, ip, uniqueIdentifiers[0], nil) + err = manager.Unsubscribe(ts.Ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[0], nil) require.NoError(t, err) // now we have numberOfParallelSubscriptions - 1 require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions-1) @@ -221,7 +224,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { chainMessage1, err := chainParser.ParseMsg("", play.subscriptionRequestData1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) - + protocolMessage1 := NewProtocolMessage(chainMessage1, nil, nil) relaySender := NewMockRelaySender(ctrl) relaySender. EXPECT(). @@ -238,7 +241,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { relaySender. EXPECT(). ParseRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(chainMessage1, nil, nil, nil). + Return(protocolMessage1, nil). AnyTimes() mockRelayerClient1 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) @@ -283,7 +286,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). Times(1) // Should call SendParsedRelay, because it is the first time we subscribe @@ -302,7 +305,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { ctx := utils.WithUniqueIdentifier(ts.Ctx, utils.GenerateUniqueIdentifier()) var repliesChan <-chan *pairingtypes.RelayReply var firstReply *pairingtypes.RelayReply - firstReply, repliesChan, err = manager.StartSubscription(ctx, chainMessage1, nil, nil, dapp+strconv.Itoa(index), ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp+strconv.Itoa(index), ts.Consumer.Addr.String(), uniqueId, nil) go func() { for subMsg := range repliesChan { require.Equal(t, string(play.subscriptionFirstReply1), string(subMsg.Data)) @@ -427,16 +430,21 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { subscribeChainMessage1, err := chainParser.ParseMsg("", play.subscriptionRequestData1, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) + subscribeProtocolMessage1 := NewProtocolMessage(subscribeChainMessage1, nil, nil) + unsubscribeProtocolMessage1 := NewProtocolMessage(unsubscribeChainMessage1, nil, &pairingtypes.RelayPrivateData{ + Data: play.unsubscribeMessage1, + }) relaySender := NewMockRelaySender(ctrl) relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { - relayPrivateData, ok := x.(*pairingtypes.RelayPrivateData) - if !ok || relayPrivateData == nil { + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { + protocolMsg, ok := x.(ProtocolMessage) + require.True(t, ok) + require.NotNil(t, protocolMsg) + if protocolMsg.RelayPrivateData() == nil { return false } - - if strings.Contains(string(relayPrivateData.Data), "unsubscribe") { + if strings.Contains(string(protocolMsg.RelayPrivateData().Data), "unsubscribe") { unsubscribeMessageWg.Done() } @@ -460,26 +468,24 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { relaySender. EXPECT(). - ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { + ParseRelay(gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { reqData, ok := x.(string) require.True(t, ok) areEqual := reqData == string(play.unsubscribeMessage1) return areEqual }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(unsubscribeChainMessage1, nil, &pairingtypes.RelayPrivateData{ - Data: play.unsubscribeMessage1, - }, nil). + Return(unsubscribeProtocolMessage1, nil). AnyTimes() relaySender. EXPECT(). - ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { + ParseRelay(gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { reqData, ok := x.(string) require.True(t, ok) areEqual := reqData == string(play.subscriptionRequestData1) return areEqual }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subscribeChainMessage1, nil, nil, nil). + Return(subscribeProtocolMessage1, nil). AnyTimes() mockRelayerClient1 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) @@ -524,7 +530,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). Times(1) // Should call SendParsedRelay, because it is the first time we subscribe @@ -535,7 +541,8 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Start a new subscription for the first time, called SendParsedRelay once ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan1, err := manager.StartSubscription(ctx, subscribeChainMessage1, nil, nil, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + + firstReply, repliesChan1, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) assert.NoError(t, err) unsubscribeMessageWg.Add(1) assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) @@ -545,13 +552,13 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). Times(0) // Should not call SendParsedRelay, because it is already subscribed // Start a subscription again, same params, same dappKey, should not call SendParsedRelay ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan2, err := manager.StartSubscription(ctx, subscribeChainMessage1, nil, nil, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan2, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) assert.NoError(t, err) assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) assert.Nil(t, repliesChan2) // Same subscription, same dappKey, no need for a new channel @@ -560,7 +567,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Start a subscription again, same params, different dappKey, should not call SendParsedRelay ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan3, err := manager.StartSubscription(ctx, subscribeChainMessage1, nil, nil, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) + firstReply, repliesChan3, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) assert.NoError(t, err) assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data)) assert.NotNil(t, repliesChan3) // Same subscription, but different dappKey, so will create new channel @@ -571,32 +578,30 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Prepare for the next subscription unsubscribeChainMessage2, err := chainParser.ParseMsg("", play.unsubscribeMessage2, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) - + unsubscribeProtocolMessage2 := NewProtocolMessage(unsubscribeChainMessage2, nil, &pairingtypes.RelayPrivateData{Data: play.unsubscribeMessage2}) subscribeChainMessage2, err := chainParser.ParseMsg("", play.subscriptionRequestData2, play.connectionType, nil, extensionslib.ExtensionInfo{LatestBlock: 0}) require.NoError(t, err) - + subscribeProtocolMessage2 := NewProtocolMessage(subscribeChainMessage2, nil, nil) relaySender. EXPECT(). - ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { + ParseRelay(gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { reqData, ok := x.(string) require.True(t, ok) areEqual := reqData == string(play.unsubscribeMessage2) return areEqual }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(unsubscribeChainMessage2, nil, &pairingtypes.RelayPrivateData{ - Data: play.unsubscribeMessage2, - }, nil). + Return(unsubscribeProtocolMessage2, nil). AnyTimes() relaySender. EXPECT(). - ParseRelay(gomock.Any(), gomock.Any(), gomock.Cond(func(x any) bool { + ParseRelay(gomock.Any(), gomock.Any(), gomockuber.Cond(func(x any) bool { reqData, ok := x.(string) require.True(t, ok) areEqual := reqData == string(play.subscriptionRequestData2) return areEqual }), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(subscribeChainMessage2, nil, nil, nil). + Return(subscribeProtocolMessage2, nil). AnyTimes() mockRelayerClient2 := pairingtypes.NewMockRelayer_RelaySubscribeClient(ctrl) @@ -639,13 +644,14 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult2, nil). Times(1) // Should call SendParsedRelay, because it is the first time we subscribe // Start a subscription again, different params, same dappKey, should call SendParsedRelay ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - firstReply, repliesChan4, err := manager.StartSubscription(ctx, subscribeChainMessage2, nil, nil, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) + + firstReply, repliesChan4, err := manager.StartSubscription(ctx, subscribeProtocolMessage2, dapp1, ts.Consumer.Addr.String(), uniqueId, nil) assert.NoError(t, err) unsubscribeMessageWg.Add(1) assert.Equal(t, string(play.subscriptionFirstReply2), string(firstReply.Data)) @@ -658,12 +664,13 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Prepare for unsubscribe from the first subscription relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(relayResult1, nil). Times(0) // Should call SendParsedRelay, because it unsubscribed ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier()) - err = manager.Unsubscribe(ctx, unsubscribeChainMessage1, nil, relayResult1.Request.RelayData, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) + unsubProtocolMessage := NewProtocolMessage(unsubscribeChainMessage1, nil, relayResult1.Request.RelayData) + err = manager.Unsubscribe(ctx, unsubProtocolMessage, dapp2, ts.Consumer.Addr.String(), uniqueId, nil) require.NoError(t, err) listenForExpectedMessages(ctx, repliesChan1, string(play.subscriptionFirstReply1)) @@ -681,8 +688,8 @@ func TestConsumerWSSubscriptionManager(t *testing.T) { // Prepare for unsubscribe from the second subscription relaySender. EXPECT(). - SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, dappID string, consumerIp string, analytics *metrics.RelayMetrics, chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData) (relayResult *common.RelayResult, errRet error) { + SendParsedRelay(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, dappID string, consumerIp string, analytics *metrics.RelayMetrics, protocolMessage ProtocolMessage) (relayResult *common.RelayResult, errRet error) { wg.Done() return relayResult2, nil }). diff --git a/protocol/chainlib/protocol_message.go b/protocol/chainlib/protocol_message.go new file mode 100644 index 0000000000..c9ed2ea01d --- /dev/null +++ b/protocol/chainlib/protocol_message.go @@ -0,0 +1,53 @@ +package chainlib + +import ( + "strings" + + "github.com/lavanet/lava/v2/protocol/common" + pairingtypes "github.com/lavanet/lava/v2/x/pairing/types" +) + +type BaseProtocolMessage struct { + ChainMessage + directiveHeaders map[string]string + relayRequestData *pairingtypes.RelayPrivateData +} + +func (bpm *BaseProtocolMessage) GetDirectiveHeaders() map[string]string { + return bpm.directiveHeaders +} + +func (bpm *BaseProtocolMessage) RelayPrivateData() *pairingtypes.RelayPrivateData { + return bpm.relayRequestData +} + +func (bpm *BaseProtocolMessage) HashCacheRequest(chainId string) ([]byte, func([]byte) []byte, error) { + return HashCacheRequest(bpm.relayRequestData, chainId) +} + +func (bpm *BaseProtocolMessage) GetBlockedProviders() []string { + if bpm.directiveHeaders == nil { + return nil + } + blockedProviders, ok := bpm.directiveHeaders[common.BLOCK_PROVIDERS_ADDRESSES_HEADER_NAME] + if ok { + return strings.Split(blockedProviders, ",") + } + return nil +} + +func NewProtocolMessage(chainMessage ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData) ProtocolMessage { + return &BaseProtocolMessage{ + ChainMessage: chainMessage, + directiveHeaders: directiveHeaders, + relayRequestData: relayRequestData, + } +} + +type ProtocolMessage interface { + ChainMessage + GetDirectiveHeaders() map[string]string + RelayPrivateData() *pairingtypes.RelayPrivateData + HashCacheRequest(chainId string) ([]byte, func([]byte) []byte, error) + GetBlockedProviders() []string +} diff --git a/protocol/lavasession/consumer_session_manager_test.go b/protocol/lavasession/consumer_session_manager_test.go index 9bdf1b7fcf..ad44b5010e 100644 --- a/protocol/lavasession/consumer_session_manager_test.go +++ b/protocol/lavasession/consumer_session_manager_test.go @@ -8,6 +8,7 @@ import ( "net" "os" "strconv" + "strings" "testing" "time" @@ -221,6 +222,21 @@ func createGRPCServer(changeListener string, probeDelay time.Duration) error { const providerStr = "provider" +type DirectiveHeaders struct { + directiveHeaders map[string]string +} + +func (bpm DirectiveHeaders) GetBlockedProviders() []string { + if bpm.directiveHeaders == nil { + return nil + } + blockedProviders, ok := bpm.directiveHeaders[common.BLOCK_PROVIDERS_ADDRESSES_HEADER_NAME] + if ok { + return strings.Split(blockedProviders, ",") + } + return nil +} + func createPairingList(providerPrefixAddress string, enabled bool) map[uint64]*ConsumerSessionsWithProvider { cswpList := make(map[uint64]*ConsumerSessionsWithProvider, 0) pairingEndpoints := make([]*Endpoint, 1) @@ -322,7 +338,9 @@ func TestSecondChanceRecoveryFlow(t *testing.T) { timeLimit := time.Second * 30 loopStartTime := time.Now() for { - usedProviders := NewUsedProviders(map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress}) + // implement a struct that returns: map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress} in the implementation for the DirectiveHeadersInf interface + directiveHeaders := DirectiveHeaders{map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress}} + usedProviders := NewUsedProviders(directiveHeaders) css, err := csm.GetSessions(ctx, cuForFirstRequest, usedProviders, servicedBlockNumber, "", nil, common.NO_STATE, 0) // get a session require.NoError(t, err) _, expectedProviderAddress := css[pairingList[0].PublicLavaAddress] @@ -372,7 +390,8 @@ func TestSecondChanceRecoveryFlow(t *testing.T) { loopStartTime = time.Now() for { utils.LavaFormatDebug("Test", utils.LogAttr("csm.validAddresses", csm.validAddresses), utils.LogAttr("csm.currentlyBlockedProviderAddresses", csm.currentlyBlockedProviderAddresses), utils.LogAttr("csm.pairing[pairingList[0].PublicLavaAddress].blockedAndUsedWithChanceForRecoveryStatus", csm.pairing[pairingList[0].PublicLavaAddress].blockedAndUsedWithChanceForRecoveryStatus)) - usedProviders := NewUsedProviders(map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress}) + directiveHeaders := DirectiveHeaders{map[string]string{"lava-providers-block": pairingList[1].PublicLavaAddress}} + usedProviders := NewUsedProviders(directiveHeaders) require.Equal(t, BlockedProviderSessionUnusedStatus, csm.pairing[pairingList[0].PublicLavaAddress].blockedAndUsedWithChanceForRecoveryStatus) css, err := csm.GetSessions(ctx, cuForFirstRequest, usedProviders, servicedBlockNumber, "", nil, common.NO_STATE, 0) // get a session require.Equal(t, BlockedProviderSessionUnusedStatus, csm.pairing[pairingList[0].PublicLavaAddress].blockedAndUsedWithChanceForRecoveryStatus) diff --git a/protocol/lavasession/used_providers.go b/protocol/lavasession/used_providers.go index 854e4823ac..b1d72de953 100644 --- a/protocol/lavasession/used_providers.go +++ b/protocol/lavasession/used_providers.go @@ -2,22 +2,23 @@ package lavasession import ( "context" - "strings" "sync" "time" - "github.com/lavanet/lava/v2/protocol/common" "github.com/lavanet/lava/v2/utils" ) const MaximumNumberOfSelectionLockAttempts = 500 -func NewUsedProviders(directiveHeaders map[string]string) *UsedProviders { +type BlockedProvidersInf interface { + GetBlockedProviders() []string +} + +func NewUsedProviders(blockedProviders BlockedProvidersInf) *UsedProviders { unwantedProviders := map[string]struct{}{} - if len(directiveHeaders) > 0 { - blockedProviders, ok := directiveHeaders[common.BLOCK_PROVIDERS_ADDRESSES_HEADER_NAME] - if ok { - providerAddressesToBlock := strings.Split(blockedProviders, ",") + if blockedProviders != nil { + providerAddressesToBlock := blockedProviders.GetBlockedProviders() + if len(providerAddressesToBlock) > 0 { for _, providerAddress := range providerAddressesToBlock { unwantedProviders[providerAddress] = struct{}{} } diff --git a/protocol/rpcconsumer/rpcconsumer_server.go b/protocol/rpcconsumer/rpcconsumer_server.go index f00af0206f..595038b5a7 100644 --- a/protocol/rpcconsumer/rpcconsumer_server.go +++ b/protocol/rpcconsumer/rpcconsumer_server.go @@ -223,16 +223,16 @@ func (rpccs *RPCConsumerServer) craftRelay(ctx context.Context) (ok bool, relay return } -func (rpccs *RPCConsumerServer) sendRelayWithRetries(ctx context.Context, retries int, initialRelays bool, relay *pairingtypes.RelayPrivateData, chainMessage chainlib.ChainMessage) (bool, error) { +func (rpccs *RPCConsumerServer) sendRelayWithRetries(ctx context.Context, retries int, initialRelays bool, protocolMessage chainlib.ProtocolMessage) (bool, error) { success := false var err error - relayProcessor := NewRelayProcessor(ctx, lavasession.NewUsedProviders(nil), 1, chainMessage, rpccs.consumerConsistency, "-init-", "", rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) + relayProcessor := NewRelayProcessor(ctx, lavasession.NewUsedProviders(nil), 1, protocolMessage, rpccs.consumerConsistency, "-init-", "", rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) for i := 0; i < retries; i++ { - err = rpccs.sendRelayToProvider(ctx, chainMessage, relay, "-init-", "", relayProcessor, nil) + err = rpccs.sendRelayToProvider(ctx, protocolMessage, "-init-", "", relayProcessor, nil) if lavasession.PairingListEmptyError.Is(err) { // we don't have pairings anymore, could be related to unwanted providers relayProcessor.GetUsedProviders().ClearUnwanted() - err = rpccs.sendRelayToProvider(ctx, chainMessage, relay, "-init-", "", relayProcessor, nil) + err = rpccs.sendRelayToProvider(ctx, protocolMessage, "-init-", "", relayProcessor, nil) } if err != nil { utils.LavaFormatError("[-] failed sending init relay", err, []utils.Attribute{{Key: "chainID", Value: rpccs.listenEndpoint.ChainID}, {Key: "APIInterface", Value: rpccs.listenEndpoint.ApiInterface}, {Key: "relayProcessor", Value: relayProcessor}}...) @@ -285,8 +285,8 @@ func (rpccs *RPCConsumerServer) sendCraftedRelays(retries int, initialRelays boo } return false, err } - - return rpccs.sendRelayWithRetries(ctx, retries, initialRelays, relay, chainMessage) + protocolMessage := chainlib.NewProtocolMessage(chainMessage, nil, relay) + return rpccs.sendRelayWithRetries(ctx, retries, initialRelays, protocolMessage) } func (rpccs *RPCConsumerServer) getLatestBlock() uint64 { @@ -308,12 +308,12 @@ func (rpccs *RPCConsumerServer) SendRelay( analytics *metrics.RelayMetrics, metadata []pairingtypes.Metadata, ) (relayResult *common.RelayResult, errRet error) { - chainMessage, directiveHeaders, relayRequestData, err := rpccs.ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) + protocolMessage, err := rpccs.ParseRelay(ctx, url, req, connectionType, dappID, consumerIp, analytics, metadata) if err != nil { return nil, err } - return rpccs.SendParsedRelay(ctx, dappID, consumerIp, analytics, chainMessage, directiveHeaders, relayRequestData) + return rpccs.SendParsedRelay(ctx, dappID, consumerIp, analytics, protocolMessage) } func (rpccs *RPCConsumerServer) ParseRelay( @@ -325,16 +325,16 @@ func (rpccs *RPCConsumerServer) ParseRelay( consumerIp string, analytics *metrics.RelayMetrics, metadata []pairingtypes.Metadata, -) (chainMessage chainlib.ChainMessage, directiveHeaders map[string]string, relayRequestData *pairingtypes.RelayPrivateData, err error) { +) (protocolMessage chainlib.ProtocolMessage, err error) { // gets the relay request data from the ChainListener // parses the request into an APIMessage, and validating it corresponds to the spec currently in use // construct the common data for a relay message, common data is identical across multiple sends and data reliability // remove lava directive headers - metadata, directiveHeaders = rpccs.LavaDirectiveHeaders(metadata) - chainMessage, err = rpccs.chainParser.ParseMsg(url, []byte(req), connectionType, metadata, rpccs.getExtensionsFromDirectiveHeaders(directiveHeaders)) + metadata, directiveHeaders := rpccs.LavaDirectiveHeaders(metadata) + chainMessage, err := rpccs.chainParser.ParseMsg(url, []byte(req), connectionType, metadata, rpccs.getExtensionsFromDirectiveHeaders(directiveHeaders)) if err != nil { - return nil, nil, nil, err + return nil, err } rpccs.HandleDirectiveHeadersForMessage(chainMessage, directiveHeaders) @@ -346,8 +346,9 @@ func (rpccs *RPCConsumerServer) ParseRelay( seenBlock = 0 } - relayRequestData = lavaprotocol.NewRelayData(ctx, connectionType, url, []byte(req), seenBlock, reqBlock, rpccs.listenEndpoint.ApiInterface, chainMessage.GetRPCMessage().GetHeaders(), chainlib.GetAddon(chainMessage), common.GetExtensionNames(chainMessage.GetExtensions())) - return chainMessage, directiveHeaders, relayRequestData, nil + relayRequestData := lavaprotocol.NewRelayData(ctx, connectionType, url, []byte(req), seenBlock, reqBlock, rpccs.listenEndpoint.ApiInterface, chainMessage.GetRPCMessage().GetHeaders(), chainlib.GetAddon(chainMessage), common.GetExtensionNames(chainMessage.GetExtensions())) + protocolMessage = chainlib.NewProtocolMessage(chainMessage, directiveHeaders, relayRequestData) + return protocolMessage, nil } func (rpccs *RPCConsumerServer) SendParsedRelay( @@ -355,9 +356,7 @@ func (rpccs *RPCConsumerServer) SendParsedRelay( dappID string, consumerIp string, analytics *metrics.RelayMetrics, - chainMessage chainlib.ChainMessage, - directiveHeaders map[string]string, - relayRequestData *pairingtypes.RelayPrivateData, + protocolMessage chainlib.ProtocolMessage, ) (relayResult *common.RelayResult, errRet error) { // sends a relay message to a provider // compares the result with other providers if defined so @@ -365,7 +364,7 @@ func (rpccs *RPCConsumerServer) SendParsedRelay( // asynchronously sends data reliability if necessary relaySentTime := time.Now() - relayProcessor, err := rpccs.ProcessRelaySend(ctx, directiveHeaders, chainMessage, relayRequestData, dappID, consumerIp, analytics) + relayProcessor, err := rpccs.ProcessRelaySend(ctx, protocolMessage, dappID, consumerIp, analytics) if err != nil && !relayProcessor.HasResults() { // we can't send anymore, and we don't have any responses utils.LavaFormatError("failed getting responses from providers", err, utils.Attribute{Key: "GUID", Value: ctx}, utils.LogAttr("endpoint", rpccs.listenEndpoint.Key()), utils.LogAttr("userIp", consumerIp), utils.LogAttr("relayProcessor", relayProcessor)) @@ -383,11 +382,11 @@ func (rpccs *RPCConsumerServer) SendParsedRelay( if found { dataReliabilityContext = utils.WithUniqueIdentifier(dataReliabilityContext, guid) } - go rpccs.sendDataReliabilityRelayIfApplicable(dataReliabilityContext, dappID, consumerIp, chainMessage, dataReliabilityThreshold, relayProcessor) // runs asynchronously + go rpccs.sendDataReliabilityRelayIfApplicable(dataReliabilityContext, dappID, consumerIp, protocolMessage, dataReliabilityThreshold, relayProcessor) // runs asynchronously } returnedResult, err := relayProcessor.ProcessingResult() - rpccs.appendHeadersToRelayResult(ctx, returnedResult, relayProcessor.ProtocolErrors(), relayProcessor, directiveHeaders) + rpccs.appendHeadersToRelayResult(ctx, returnedResult, relayProcessor.ProtocolErrors(), relayProcessor, protocolMessage.GetDirectiveHeaders()) if err != nil { return returnedResult, utils.LavaFormatError("failed processing responses from providers", err, utils.Attribute{Key: "GUID", Value: ctx}, utils.LogAttr("endpoint", rpccs.listenEndpoint.Key())) } @@ -395,7 +394,7 @@ func (rpccs *RPCConsumerServer) SendParsedRelay( if analytics != nil { currentLatency := time.Since(relaySentTime) analytics.Latency = currentLatency.Milliseconds() - api := chainMessage.GetApi() + api := protocolMessage.GetApi() analytics.ComputeUnits = api.ComputeUnits analytics.ApiMethod = api.Name } @@ -407,11 +406,11 @@ func (rpccs *RPCConsumerServer) GetChainIdAndApiInterface() (string, string) { return rpccs.listenEndpoint.ChainID, rpccs.listenEndpoint.ApiInterface } -func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveHeaders map[string]string, chainMessage chainlib.ChainMessage, relayRequestData *pairingtypes.RelayPrivateData, dappID string, consumerIp string, analytics *metrics.RelayMetrics) (*RelayProcessor, error) { +func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, protocolMessage chainlib.ProtocolMessage, dappID string, consumerIp string, analytics *metrics.RelayMetrics) (*RelayProcessor, error) { // make sure all of the child contexts are cancelled when we exit ctx, cancel := context.WithCancel(ctx) defer cancel() - relayProcessor := NewRelayProcessor(ctx, lavasession.NewUsedProviders(directiveHeaders), rpccs.requiredResponses, chainMessage, rpccs.consumerConsistency, dappID, consumerIp, rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) + relayProcessor := NewRelayProcessor(ctx, lavasession.NewUsedProviders(protocolMessage), rpccs.requiredResponses, protocolMessage, rpccs.consumerConsistency, dappID, consumerIp, rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) var err error // try sending a relay 3 times. if failed return the error for retryFirstRelayAttempt := 0; retryFirstRelayAttempt < SendRelayAttempts; retryFirstRelayAttempt++ { @@ -419,7 +418,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH if analytics != nil && retryFirstRelayAttempt > 0 { analytics = nil } - err = rpccs.sendRelayToProvider(ctx, chainMessage, relayRequestData, dappID, consumerIp, relayProcessor, analytics) + err = rpccs.sendRelayToProvider(ctx, protocolMessage, dappID, consumerIp, relayProcessor, analytics) // check if we had an error. if we did, try again. if err == nil { @@ -434,7 +433,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH // a channel to be notified processing was done, true means we have results and can return gotResults := make(chan bool) - processingTimeout, relayTimeout := rpccs.getProcessingTimeout(chainMessage) + processingTimeout, relayTimeout := rpccs.getProcessingTimeout(protocolMessage) if rpccs.debugRelays { utils.LavaFormatDebug("Relay initiated with the following timeout schedule", utils.LogAttr("processingTimeout", processingTimeout), utils.LogAttr("newRelayTimeout", relayTimeout)) } @@ -486,7 +485,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH return relayProcessor, nil } // otherwise continue sending another relay - err := rpccs.sendRelayToProvider(processingCtx, chainMessage, relayRequestData, dappID, consumerIp, relayProcessor, nil) + err := rpccs.sendRelayToProvider(processingCtx, protocolMessage, dappID, consumerIp, relayProcessor, nil) go validateReturnCondition(err) go readResultsFromProcessor() // increase number of retries launched only if we still have pairing available, if we exhausted the list we don't want to break early @@ -499,7 +498,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH if relayProcessor.ShouldRetry(numberOfRetriesLaunched) { // limit the number of retries called from the new batch ticker flow. // if we pass the limit we just wait for the relays we sent to return. - err := rpccs.sendRelayToProvider(processingCtx, chainMessage, relayRequestData, dappID, consumerIp, relayProcessor, nil) + err := rpccs.sendRelayToProvider(processingCtx, protocolMessage, dappID, consumerIp, relayProcessor, nil) go validateReturnCondition(err) // add ticker launch metrics go rpccs.rpcConsumerLogs.SetRelaySentByNewBatchTickerMetric(rpccs.GetChainIdAndApiInterface()) @@ -524,7 +523,7 @@ func (rpccs *RPCConsumerServer) ProcessRelaySend(ctx context.Context, directiveH utils.LogAttr("processingTimeout", processingTimeout), utils.LogAttr("dappId", dappID), utils.LogAttr("consumerIp", consumerIp), - utils.LogAttr("chainMessage.GetApi().Name", chainMessage.GetApi().Name), + utils.LogAttr("protocolMessage.GetApi().Name", protocolMessage.GetApi().Name), utils.LogAttr("GUID", ctx), utils.LogAttr("relayProcessor", relayProcessor), ) @@ -553,8 +552,7 @@ func (rpccs *RPCConsumerServer) CancelSubscriptionContext(subscriptionKey string func (rpccs *RPCConsumerServer) sendRelayToProvider( ctx context.Context, - chainMessage chainlib.ChainMessage, - relayRequestData *pairingtypes.RelayPrivateData, + protocolMessage chainlib.ProtocolMessage, dappID string, consumerIp string, relayProcessor *RelayProcessor, @@ -581,27 +579,27 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( lavaChainID := rpccs.lavaChainID // Get Session. we get session here so we can use the epoch in the callbacks - reqBlock, _ := chainMessage.RequestedBlock() + reqBlock, _ := protocolMessage.RequestedBlock() // try using cache before sending relay var cacheError error if rpccs.cache.CacheActive() { // use cache only if its defined. - if !chainMessage.GetForceCacheRefresh() { // don't use cache if user specified + if !protocolMessage.GetForceCacheRefresh() { // don't use cache if user specified if reqBlock != spectypes.NOT_APPLICABLE { // don't use cache if requested block is not applicable var cacheReply *pairingtypes.CacheRelayReply - hashKey, outputFormatter, err := chainlib.HashCacheRequest(relayRequestData, chainId) + hashKey, outputFormatter, err := protocolMessage.HashCacheRequest(chainId) if err != nil { utils.LavaFormatError("sendRelayToProvider Failed getting Hash for cache request", err) } else { cacheCtx, cancel := context.WithTimeout(ctx, common.CacheTimeout) cacheReply, cacheError = rpccs.cache.GetEntry(cacheCtx, &pairingtypes.RelayCacheGet{ RequestHash: hashKey, - RequestedBlock: relayRequestData.RequestBlock, + RequestedBlock: reqBlock, ChainId: chainId, BlockHash: nil, Finalized: false, SharedStateId: sharedStateId, - SeenBlock: relayRequestData.SeenBlock, + SeenBlock: protocolMessage.RelayPrivateData().SeenBlock, }) // caching in the portal doesn't care about hashes, and we don't have data on finalization yet cancel() reply := cacheReply.GetReply() @@ -610,9 +608,9 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( cacheSeenBlock := cacheReply.GetSeenBlock() // check if the cache seen block is greater than my local seen block, this means the user requested this // request spoke with another consumer instance and use that block for inter consumer consistency. - if rpccs.sharedState && cacheSeenBlock > relayRequestData.SeenBlock { - utils.LavaFormatDebug("shared state seen block is newer", utils.LogAttr("cache_seen_block", cacheSeenBlock), utils.LogAttr("local_seen_block", relayRequestData.SeenBlock)) - relayRequestData.SeenBlock = cacheSeenBlock + if rpccs.sharedState && cacheSeenBlock > protocolMessage.RelayPrivateData().SeenBlock { + utils.LavaFormatDebug("shared state seen block is newer", utils.LogAttr("cache_seen_block", cacheSeenBlock), utils.LogAttr("local_seen_block", protocolMessage.RelayPrivateData().SeenBlock)) + protocolMessage.RelayPrivateData().SeenBlock = cacheSeenBlock // setting the fetched seen block from the cache server to our local cache as well. rpccs.consumerConsistency.SetSeenBlock(cacheSeenBlock, dappID, consumerIp) } @@ -625,7 +623,7 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( relayResult := common.RelayResult{ Reply: reply, Request: &pairingtypes.RelayRequest{ - RelayData: relayRequestData, + RelayData: protocolMessage.RelayPrivateData(), }, Finalized: false, // set false to skip data reliability StatusCode: 200, @@ -643,33 +641,33 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( } } } else { - utils.LavaFormatDebug("skipping cache due to requested block being NOT_APPLICABLE", utils.Attribute{Key: "api name", Value: chainMessage.GetApi().Name}) + utils.LavaFormatDebug("skipping cache due to requested block being NOT_APPLICABLE", utils.Attribute{Key: "api name", Value: protocolMessage.GetApi().Name}) } } } - if reqBlock == spectypes.LATEST_BLOCK && relayRequestData.SeenBlock != 0 { + if reqBlock == spectypes.LATEST_BLOCK && protocolMessage.RelayPrivateData().SeenBlock != 0 { // make optimizer select a provider that is likely to have the latest seen block - reqBlock = relayRequestData.SeenBlock + reqBlock = protocolMessage.RelayPrivateData().SeenBlock } // consumerEmergencyTracker always use latest virtual epoch virtualEpoch := rpccs.consumerTxSender.GetLatestVirtualEpoch() - addon := chainlib.GetAddon(chainMessage) - extensions := chainMessage.GetExtensions() + addon := chainlib.GetAddon(protocolMessage) + extensions := protocolMessage.GetExtensions() usedProviders := relayProcessor.GetUsedProviders() - sessions, err := rpccs.consumerSessionManager.GetSessions(ctx, chainlib.GetComputeUnits(chainMessage), usedProviders, reqBlock, addon, extensions, chainlib.GetStateful(chainMessage), virtualEpoch) + sessions, err := rpccs.consumerSessionManager.GetSessions(ctx, chainlib.GetComputeUnits(protocolMessage), usedProviders, reqBlock, addon, extensions, chainlib.GetStateful(protocolMessage), virtualEpoch) if err != nil { if lavasession.PairingListEmptyError.Is(err) { if addon != "" { return utils.LavaFormatError("No Providers For Addon", err, utils.LogAttr("addon", addon), utils.LogAttr("extensions", extensions), utils.LogAttr("userIp", consumerIp)) } else if len(extensions) > 0 && relayProcessor.GetAllowSessionDegradation() { // if we have no providers for that extension, use a regular provider, otherwise return the extension results - sessions, err = rpccs.consumerSessionManager.GetSessions(ctx, chainlib.GetComputeUnits(chainMessage), usedProviders, reqBlock, addon, []*spectypes.Extension{}, chainlib.GetStateful(chainMessage), virtualEpoch) + sessions, err = rpccs.consumerSessionManager.GetSessions(ctx, chainlib.GetComputeUnits(protocolMessage), usedProviders, reqBlock, addon, []*spectypes.Extension{}, chainlib.GetStateful(protocolMessage), virtualEpoch) if err != nil { return err } - relayProcessor.setSkipDataReliability(true) // disabling data reliability when disabling extensions. - relayRequestData.Extensions = []string{} // reset request data extensions - extensions = []*spectypes.Extension{} // reset extensions too so we wont hit SetDisallowDegradation + relayProcessor.setSkipDataReliability(true) // disabling data reliability when disabling extensions. + protocolMessage.RelayPrivateData().Extensions = []string{} // reset request data extensions + extensions = []*spectypes.Extension{} // reset extensions too so we wont hit SetDisallowDegradation } else { return err } @@ -726,7 +724,7 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( goroutineCtxCancel() }() - localRelayRequestData := *relayRequestData + localRelayRequestData := *protocolMessage.RelayPrivateData() // Extract fields from the sessionInfo singleConsumerSession := sessionInfo.Session @@ -744,10 +742,10 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( // set relay sent metric go rpccs.rpcConsumerLogs.SetRelaySentToProviderMetric(chainId, apiInterface) - if chainlib.IsFunctionTagOfType(chainMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { + if chainlib.IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) { utils.LavaFormatTrace("inside sendRelayToProvider, relay is subscription", utils.LogAttr("requestData", localRelayRequestData.Data)) - params, err := json.Marshal(chainMessage.GetRPCMessage().GetParams()) + params, err := json.Marshal(protocolMessage.GetRPCMessage().GetParams()) if err != nil { utils.LavaFormatError("could not marshal params", err) return @@ -781,7 +779,7 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( // unique per dappId and ip consumerToken := common.GetUniqueToken(dappID, consumerIp) - processingTimeout, expectedRelayTimeoutForQOS := rpccs.getProcessingTimeout(chainMessage) + processingTimeout, expectedRelayTimeoutForQOS := rpccs.getProcessingTimeout(protocolMessage) deadline, ok := ctx.Deadline() if ok { // we have ctx deadline. we cant go past it. processingTimeout = time.Until(deadline) @@ -796,7 +794,7 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( } } // send relay - relayLatency, errResponse, backoff := rpccs.relayInner(goroutineCtx, singleConsumerSession, localRelayResult, processingTimeout, chainMessage, consumerToken, analytics) + relayLatency, errResponse, backoff := rpccs.relayInner(goroutineCtx, singleConsumerSession, localRelayResult, processingTimeout, protocolMessage, consumerToken, analytics) if errResponse != nil { failRelaySession := func(origErr error, backoff_ bool) { backOffDuration := 0 * time.Second @@ -840,10 +838,10 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( ) } - errResponse = rpccs.consumerSessionManager.OnSessionDone(singleConsumerSession, latestBlock, chainlib.GetComputeUnits(chainMessage), relayLatency, singleConsumerSession.CalculateExpectedLatency(expectedRelayTimeoutForQOS), expectedBH, numOfProviders, pairingAddressesLen, chainMessage.GetApi().Category.HangingApi) // session done successfully + errResponse = rpccs.consumerSessionManager.OnSessionDone(singleConsumerSession, latestBlock, chainlib.GetComputeUnits(protocolMessage), relayLatency, singleConsumerSession.CalculateExpectedLatency(expectedRelayTimeoutForQOS), expectedBH, numOfProviders, pairingAddressesLen, protocolMessage.GetApi().Category.HangingApi) // session done successfully if rpccs.cache.CacheActive() && rpcclient.ValidateStatusCodes(localRelayResult.StatusCode, true) == nil { - isNodeError, _ := chainMessage.CheckResponseError(localRelayResult.Reply.Data, localRelayResult.StatusCode) + isNodeError, _ := protocolMessage.CheckResponseError(localRelayResult.Reply.Data, localRelayResult.StatusCode) // in case the error is a node error we don't want to cache if !isNodeError { // copy reply data so if it changes it doesn't panic mid async send @@ -863,7 +861,7 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( ) return } - chainMessageRequestedBlock, _ := chainMessage.RequestedBlock() + chainMessageRequestedBlock, _ := protocolMessage.RequestedBlock() if chainMessageRequestedBlock == spectypes.NOT_APPLICABLE { return } @@ -1234,8 +1232,9 @@ func (rpccs *RPCConsumerServer) sendDataReliabilityRelayIfApplicable(ctx context relayResult := results[0] if len(results) < 2 { relayRequestData := lavaprotocol.NewRelayData(ctx, relayResult.Request.RelayData.ConnectionType, relayResult.Request.RelayData.ApiUrl, relayResult.Request.RelayData.Data, relayResult.Request.RelayData.SeenBlock, reqBlock, relayResult.Request.RelayData.ApiInterface, chainMessage.GetRPCMessage().GetHeaders(), relayResult.Request.RelayData.Addon, relayResult.Request.RelayData.Extensions) + protocolMessage := chainlib.NewProtocolMessage(chainMessage, nil, relayRequestData) relayProcessorDataReliability := NewRelayProcessor(ctx, relayProcessor.usedProviders, 1, chainMessage, rpccs.consumerConsistency, dappID, consumerIp, rpccs.debugRelays, rpccs.rpcConsumerLogs, rpccs, rpccs.disableNodeErrorRetry, rpccs.relayRetriesManager) - err := rpccs.sendRelayToProvider(ctx, chainMessage, relayRequestData, dappID, consumerIp, relayProcessorDataReliability, nil) + err := rpccs.sendRelayToProvider(ctx, protocolMessage, dappID, consumerIp, relayProcessorDataReliability, nil) if err != nil { return utils.LavaFormatWarning("failed data reliability relay to provider", err, utils.LogAttr("relayProcessorDataReliability", relayProcessorDataReliability)) } diff --git a/x/pairing/types/relay_mock.pb.go b/x/pairing/types/relay_mock.pb.go index e49f7212e9..ad76b049fa 100644 --- a/x/pairing/types/relay_mock.pb.go +++ b/x/pairing/types/relay_mock.pb.go @@ -1,10 +1,6 @@ // Code generated by MockGen. DO NOT EDIT. // Source: x/pairing/types/relay.pb.go -// -// Generated by this command: -// -// mockgen -source=x/pairing/types/relay.pb.go -destination x/pairing/types/relay_mock.pb.go -package types -// + // Package types is a generated GoMock package. package types @@ -12,7 +8,7 @@ import ( context "context" reflect "reflect" - gomock "go.uber.org/mock/gomock" + gomock "github.com/golang/mock/gomock" grpc "google.golang.org/grpc" metadata "google.golang.org/grpc/metadata" ) @@ -43,7 +39,7 @@ func (m *MockRelayerClient) EXPECT() *MockRelayerClientMockRecorder { // Probe mocks base method. func (m *MockRelayerClient) Probe(ctx context.Context, in *ProbeRequest, opts ...grpc.CallOption) (*ProbeReply, error) { m.ctrl.T.Helper() - varargs := []any{ctx, in} + varargs := []interface{}{ctx, in} for _, a := range opts { varargs = append(varargs, a) } @@ -54,16 +50,16 @@ func (m *MockRelayerClient) Probe(ctx context.Context, in *ProbeRequest, opts .. } // Probe indicates an expected call of Probe. -func (mr *MockRelayerClientMockRecorder) Probe(ctx, in any, opts ...any) *gomock.Call { +func (mr *MockRelayerClientMockRecorder) Probe(ctx, in interface{}, opts ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, in}, opts...) + varargs := append([]interface{}{ctx, in}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockRelayerClient)(nil).Probe), varargs...) } // Relay mocks base method. func (m *MockRelayerClient) Relay(ctx context.Context, in *RelayRequest, opts ...grpc.CallOption) (*RelayReply, error) { m.ctrl.T.Helper() - varargs := []any{ctx, in} + varargs := []interface{}{ctx, in} for _, a := range opts { varargs = append(varargs, a) } @@ -74,16 +70,16 @@ func (m *MockRelayerClient) Relay(ctx context.Context, in *RelayRequest, opts .. } // Relay indicates an expected call of Relay. -func (mr *MockRelayerClientMockRecorder) Relay(ctx, in any, opts ...any) *gomock.Call { +func (mr *MockRelayerClientMockRecorder) Relay(ctx, in interface{}, opts ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, in}, opts...) + varargs := append([]interface{}{ctx, in}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Relay", reflect.TypeOf((*MockRelayerClient)(nil).Relay), varargs...) } // RelaySubscribe mocks base method. func (m *MockRelayerClient) RelaySubscribe(ctx context.Context, in *RelayRequest, opts ...grpc.CallOption) (Relayer_RelaySubscribeClient, error) { m.ctrl.T.Helper() - varargs := []any{ctx, in} + varargs := []interface{}{ctx, in} for _, a := range opts { varargs = append(varargs, a) } @@ -94,9 +90,9 @@ func (m *MockRelayerClient) RelaySubscribe(ctx context.Context, in *RelayRequest } // RelaySubscribe indicates an expected call of RelaySubscribe. -func (mr *MockRelayerClientMockRecorder) RelaySubscribe(ctx, in any, opts ...any) *gomock.Call { +func (mr *MockRelayerClientMockRecorder) RelaySubscribe(ctx, in interface{}, opts ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, in}, opts...) + varargs := append([]interface{}{ctx, in}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RelaySubscribe", reflect.TypeOf((*MockRelayerClient)(nil).RelaySubscribe), varargs...) } @@ -190,7 +186,7 @@ func (m_2 *MockRelayer_RelaySubscribeClient) RecvMsg(m any) error { } // RecvMsg indicates an expected call of RecvMsg. -func (mr *MockRelayer_RelaySubscribeClientMockRecorder) RecvMsg(m any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) RecvMsg(m interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).RecvMsg), m) } @@ -204,7 +200,7 @@ func (m_2 *MockRelayer_RelaySubscribeClient) SendMsg(m any) error { } // SendMsg indicates an expected call of SendMsg. -func (mr *MockRelayer_RelaySubscribeClientMockRecorder) SendMsg(m any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeClientMockRecorder) SendMsg(m interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeClient)(nil).SendMsg), m) } @@ -256,7 +252,7 @@ func (m *MockRelayerServer) Probe(arg0 context.Context, arg1 *ProbeRequest) (*Pr } // Probe indicates an expected call of Probe. -func (mr *MockRelayerServerMockRecorder) Probe(arg0, arg1 any) *gomock.Call { +func (mr *MockRelayerServerMockRecorder) Probe(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockRelayerServer)(nil).Probe), arg0, arg1) } @@ -271,7 +267,7 @@ func (m *MockRelayerServer) Relay(arg0 context.Context, arg1 *RelayRequest) (*Re } // Relay indicates an expected call of Relay. -func (mr *MockRelayerServerMockRecorder) Relay(arg0, arg1 any) *gomock.Call { +func (mr *MockRelayerServerMockRecorder) Relay(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Relay", reflect.TypeOf((*MockRelayerServer)(nil).Relay), arg0, arg1) } @@ -285,7 +281,7 @@ func (m *MockRelayerServer) RelaySubscribe(arg0 *RelayRequest, arg1 Relayer_Rela } // RelaySubscribe indicates an expected call of RelaySubscribe. -func (mr *MockRelayerServerMockRecorder) RelaySubscribe(arg0, arg1 any) *gomock.Call { +func (mr *MockRelayerServerMockRecorder) RelaySubscribe(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RelaySubscribe", reflect.TypeOf((*MockRelayerServer)(nil).RelaySubscribe), arg0, arg1) } @@ -336,7 +332,7 @@ func (m_2 *MockRelayer_RelaySubscribeServer) RecvMsg(m any) error { } // RecvMsg indicates an expected call of RecvMsg. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) RecvMsg(m any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) RecvMsg(m interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).RecvMsg), m) } @@ -350,7 +346,7 @@ func (m *MockRelayer_RelaySubscribeServer) Send(arg0 *RelayReply) error { } // Send indicates an expected call of Send. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) Send(arg0 any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) Send(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).Send), arg0) } @@ -364,7 +360,7 @@ func (m *MockRelayer_RelaySubscribeServer) SendHeader(arg0 metadata.MD) error { } // SendHeader indicates an expected call of SendHeader. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendHeader(arg0 any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendHeader(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeader", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SendHeader), arg0) } @@ -378,7 +374,7 @@ func (m_2 *MockRelayer_RelaySubscribeServer) SendMsg(m any) error { } // SendMsg indicates an expected call of SendMsg. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendMsg(m any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SendMsg(m interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SendMsg), m) } @@ -392,7 +388,7 @@ func (m *MockRelayer_RelaySubscribeServer) SetHeader(arg0 metadata.MD) error { } // SetHeader indicates an expected call of SetHeader. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetHeader(arg0 any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetHeader(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeader", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SetHeader), arg0) } @@ -404,7 +400,7 @@ func (m *MockRelayer_RelaySubscribeServer) SetTrailer(arg0 metadata.MD) { } // SetTrailer indicates an expected call of SetTrailer. -func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetTrailer(arg0 any) *gomock.Call { +func (mr *MockRelayer_RelaySubscribeServerMockRecorder) SetTrailer(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTrailer", reflect.TypeOf((*MockRelayer_RelaySubscribeServer)(nil).SetTrailer), arg0) }