diff --git a/hathor/p2p/dependencies/protocols.py b/hathor/p2p/dependencies/protocols.py index 662d61703..cda779fcc 100644 --- a/hathor/p2p/dependencies/protocols.py +++ b/hathor/p2p/dependencies/protocols.py @@ -12,12 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Protocol +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, Protocol from hathor.indexes.height_index import HeightInfo -from hathor.transaction import Block, Vertex +from hathor.p2p.peer_endpoint import PeerAddress +from hathor.transaction import BaseTransaction, Block, Vertex from hathor.types import VertexId +if TYPE_CHECKING: + from hathor.p2p.peer import PublicPeer, UnverifiedPeer + from hathor.p2p.peer_id import PeerId + from hathor.p2p.protocol import HathorProtocol + from hathor.p2p.sync_factory import SyncAgentFactory + from hathor.p2p.sync_version import SyncVersion + + +class P2PManagerProtocol(Protocol): + """Abstract the P2PManager as a Python protocol to be used in P2P classes.""" + + def is_peer_whitelisted(self, peer_id: PeerId) -> bool: ... + def get_enabled_sync_versions(self) -> set[SyncVersion]: ... + def get_sync_factory(self, sync_version: SyncVersion) -> SyncAgentFactory: ... + def get_verified_peers(self) -> Iterable[PublicPeer]: ... + def on_receive_peer(self, peer: UnverifiedPeer) -> None: ... + def on_peer_connect(self, protocol: HathorProtocol) -> None: ... + def on_peer_ready(self, protocol: HathorProtocol) -> None: ... + def on_handshake_disconnect(self, *, addr: PeerAddress) -> None: ... + def on_ready_disconnect(self, *, addr: PeerAddress, peer_id: PeerId) -> None: ... + def on_unknown_disconnect(self, *, addr: PeerAddress) -> None: ... + def get_randbytes(self, n: int) -> bytes: ... + def is_peer_ready(self, peer_id: PeerId) -> bool: ... + def send_tx_to_peers(self, tx: BaseTransaction) -> None: ... + class P2PVertexHandlerProtocol(Protocol): """Abstract the VertexHandler as a Python protocol to be used in P2P classes.""" diff --git a/hathor/p2p/manager.py b/hathor/p2p/manager.py index 352d6f4f2..4cffaa445 100644 --- a/hathor/p2p/manager.py +++ b/hathor/p2p/manager.py @@ -36,7 +36,6 @@ from hathor.p2p.peer_storage import UnverifiedPeerStorage, VerifiedPeerStorage from hathor.p2p.protocol import HathorProtocol from hathor.p2p.rate_limiter import RateLimiter -from hathor.p2p.states.ready import ReadyState from hathor.p2p.sync_factory import SyncAgentFactory from hathor.p2p.sync_version import SyncVersion from hathor.p2p.utils import parse_whitelist @@ -458,7 +457,7 @@ def is_peer_ready(self, peer_id: PeerId) -> bool: """ return self._connections.is_peer_ready(peer_id) - def on_receive_peer(self, peer: UnverifiedPeer, origin: Optional[ReadyState] = None) -> None: + def on_receive_peer(self, peer: UnverifiedPeer) -> None: """ Update a peer information in our storage, and instantly attempt to connect to it if it is not connected yet. """ @@ -623,7 +622,7 @@ def listen(self, description: str) -> None: if self.use_ssl: factory = TLSMemoryBIOFactory(self.my_peer.certificate_options, False, factory) - factory = NetfilterFactory(self, factory) + factory = NetfilterFactory(factory) self.log.info('trying to listen on', endpoint=description) deferred: Deferred[IListeningPort] = endpoint.listen(factory) diff --git a/hathor/p2p/netfilter/context.py b/hathor/p2p/netfilter/context.py index 5446f4f37..3ad53fbe2 100644 --- a/hathor/p2p/netfilter/context.py +++ b/hathor/p2p/netfilter/context.py @@ -17,15 +17,12 @@ if TYPE_CHECKING: from twisted.internet.interfaces import IAddress - from hathor.p2p.manager import ConnectionsManager from hathor.p2p.protocol import HathorProtocol class NetfilterContext: """Context sent to the targets when a match occurs.""" - def __init__(self, *, connections: Optional['ConnectionsManager'] = None, addr: Optional['IAddress'] = None, - protocol: Optional['HathorProtocol'] = None): + def __init__(self, *, addr: Optional['IAddress'] = None, protocol: Optional['HathorProtocol'] = None): """Initialize the context.""" self.addr = addr self.protocol = protocol - self.connections = connections diff --git a/hathor/p2p/netfilter/factory.py b/hathor/p2p/netfilter/factory.py index a42c3c1a3..9cf5b11a5 100644 --- a/hathor/p2p/netfilter/factory.py +++ b/hathor/p2p/netfilter/factory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Optional +from typing import Optional from twisted.internet.interfaces import IAddress, IProtocolFactory from twisted.internet.protocol import Protocol @@ -21,19 +21,14 @@ from hathor.p2p.netfilter import get_table from hathor.p2p.netfilter.context import NetfilterContext -if TYPE_CHECKING: - from hathor.p2p.manager import ConnectionsManager - class NetfilterFactory(WrappingFactory): """Wrapper factory to easily check new connections.""" - def __init__(self, connections: 'ConnectionsManager', wrappedFactory: 'IProtocolFactory'): + def __init__(self, wrappedFactory: 'IProtocolFactory'): super().__init__(wrappedFactory) - self.connections = connections def buildProtocol(self, addr: IAddress) -> Optional[Protocol]: context = NetfilterContext( - connections=self.connections, addr=addr, ) verdict = get_table('filter').get_chain('pre_conn').process(context) diff --git a/hathor/p2p/protocol.py b/hathor/p2p/protocol.py index 5159eb075..f11cff991 100644 --- a/hathor/p2p/protocol.py +++ b/hathor/p2p/protocol.py @@ -25,6 +25,7 @@ from twisted.python.failure import Failure from hathor.p2p import P2PDependencies +from hathor.p2p.dependencies.protocols import P2PManagerProtocol from hathor.p2p.messages import ProtocolMessages from hathor.p2p.peer import PrivatePeer, PublicPeer from hathor.p2p.peer_endpoint import PeerAddress @@ -75,8 +76,6 @@ class WarningFlags(str, Enum): NO_ENTRYPOINTS = 'no_entrypoints' my_peer: PrivatePeer - connections: 'ConnectionsManager' - node: 'HathorManager' app_version: str last_message: float _peer: Optional[PublicPeer] @@ -99,7 +98,7 @@ def peer(self) -> PublicPeer: def __init__( self, my_peer: PrivatePeer, - p2p_manager: 'ConnectionsManager', + p2p_manager: P2PManagerProtocol, *, dependencies: P2PDependencies, use_ssl: bool, @@ -109,14 +108,10 @@ def __init__( self.dependencies = dependencies self._settings = dependencies.settings self.my_peer = my_peer - self.connections = p2p_manager + self.p2p_manager: P2PManagerProtocol = p2p_manager self.addr = addr - assert p2p_manager.manager is not None - self.node = p2p_manager.manager - - assert self.connections.reactor is not None - self.reactor = self.connections.reactor + self.reactor = self.dependencies.reactor # Indicate whether it is an inbound connection (true) or an outbound connection (false). self.inbound = inbound @@ -250,8 +245,7 @@ def on_connect(self) -> None: # The initial state is HELLO. self.change_state(self.PeerState.HELLO) - if self.connections: - self.connections.on_peer_connect(self) + self.p2p_manager.on_peer_connect(self) def on_outbound_connect(self, peer_id: PeerId | None) -> None: """Called when we successfully establish an outbound connection to a peer.""" @@ -260,10 +254,9 @@ def on_outbound_connect(self, peer_id: PeerId | None) -> None: self.expected_peer_id = peer_id def on_peer_ready(self) -> None: - assert self.connections is not None assert self.peer is not None self.update_log_context() - self.connections.on_peer_ready(self) + self.p2p_manager.on_peer_ready(self) self.log.info('peer connected', peer_id=self.peer.id) def on_disconnect(self, reason: Failure) -> None: @@ -288,19 +281,19 @@ def on_disconnect(self, reason: Failure) -> None: addr=str(self.addr), peer_id=str(self.get_peer_id()), ) - self.connections.on_unknown_disconnect(addr=self.addr) + self.p2p_manager.on_unknown_disconnect(addr=self.addr) return self.state.on_exit() state_name = self.state.state_name if self.is_state(self.PeerState.HELLO) or self.is_state(self.PeerState.PEER_ID): self.state = None - self.connections.on_handshake_disconnect(addr=self.addr) + self.p2p_manager.on_handshake_disconnect(addr=self.addr) return if self.is_state(self.PeerState.READY): self.state = None - self.connections.on_ready_disconnect(addr=self.addr, peer_id=self.peer.id) + self.p2p_manager.on_ready_disconnect(addr=self.addr, peer_id=self.peer.id) return self.state = None diff --git a/hathor/p2p/states/hello.py b/hathor/p2p/states/hello.py index 9c034b7cb..20e45736f 100644 --- a/hathor/p2p/states/hello.py +++ b/hathor/p2p/states/hello.py @@ -65,9 +65,7 @@ def _get_hello_data(self) -> dict[str, Any]: def _get_sync_versions(self) -> set[SyncVersion]: """Shortcut to ConnectionManager.get_enabled_sync_versions""" - connections_manager = self.protocol.connections - assert connections_manager is not None - return connections_manager.get_enabled_sync_versions() + return self.protocol.p2p_manager.get_enabled_sync_versions() def on_enter(self) -> None: # After a connection is made, we just send a HELLO message. @@ -162,7 +160,6 @@ def handle_hello(self, payload: str) -> None: context = NetfilterContext( protocol=protocol, - connections=protocol.connections, addr=protocol.transport.getPeer(), ) verdict = get_table('filter').get_chain('post_hello').process(context) diff --git a/hathor/p2p/states/peer_id.py b/hathor/p2p/states/peer_id.py index 2ca93ea59..f305db386 100644 --- a/hathor/p2p/states/peer_id.py +++ b/hathor/p2p/states/peer_id.py @@ -110,10 +110,9 @@ async def handle_peer_id(self, payload: str) -> None: protocol.send_error_and_close_connection('Are you my clone?!') return - if protocol.connections is not None: - if protocol.connections.is_peer_ready(peer.id): - protocol.send_error_and_close_connection('We are already connected.') - return + if self.protocol.p2p_manager.is_peer_ready(peer.id): + protocol.send_error_and_close_connection('We are already connected.') + return entrypoint_valid = await peer.info.validate_entrypoint(protocol) if not entrypoint_valid: @@ -131,7 +130,6 @@ async def handle_peer_id(self, payload: str) -> None: context = NetfilterContext( protocol=protocol, - connections=protocol.connections, addr=protocol.transport.getPeer(), ) verdict = get_table('filter').get_chain('post_peerid').process(context) @@ -146,7 +144,7 @@ def _should_block_peer(self, peer_id: PeerId) -> bool: Currently this is only because the peer is not in a whitelist and whitelist blocking is active. """ - peer_is_whitelisted = self.protocol.connections.is_peer_whitelisted(peer_id) + peer_is_whitelisted = self.protocol.p2p_manager.is_peer_whitelisted(peer_id) # never block whitelisted peers if peer_is_whitelisted: return False diff --git a/hathor/p2p/states/ready.py b/hathor/p2p/states/ready.py index fb3db3ab9..f258cfada 100644 --- a/hathor/p2p/states/ready.py +++ b/hathor/p2p/states/ready.py @@ -96,22 +96,18 @@ def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies) ProtocolMessages.BEST_BLOCKCHAIN: self.handle_best_blockchain, }) - # Initialize sync manager and add its commands to the list of available commands. - connections = self.protocol.connections - assert connections is not None - - # Get the sync factory and create a sync manager from it + # Get the sync factory and create a sync agent from it sync_version = self.protocol.sync_version assert sync_version is not None self.log.debug(f'loading {sync_version}') - sync_factory = connections.get_sync_factory(sync_version) + sync_factory = self.protocol.p2p_manager.get_sync_factory(sync_version) + # Initialize sync agent and add its commands to the list of available commands. self.sync_agent: SyncAgent = sync_factory.create_sync_agent(self.protocol) self.cmd_map.update(self.sync_agent.get_cmd_dict()) def on_enter(self) -> None: - if self.protocol.connections: - self.protocol.on_peer_ready() + self.protocol.on_peer_ready() self.lc_ping.start(1, now=False) @@ -155,7 +151,7 @@ def handle_get_peers(self, payload: str) -> None: """ Executed when a GET-PEERS command is received. It just responds with a list of all known peers. """ - for peer in self.protocol.connections.get_verified_peers(): + for peer in self.protocol.p2p_manager.get_verified_peers(): self.send_peers([peer]) def send_peers(self, peer_list: Iterable[PublicPeer]) -> None: @@ -175,8 +171,7 @@ def handle_peers(self, payload: str) -> None: received_peers = json_loads(payload) for data in received_peers: peer = UnverifiedPeer.create_from_json(data) - if self.protocol.connections: - self.protocol.connections.on_receive_peer(peer, origin=self) + self.protocol.p2p_manager.on_receive_peer(peer) self.log.debug('received peers', payload=payload) def send_ping_if_necessary(self) -> None: @@ -195,7 +190,7 @@ def send_ping(self) -> None: """ # Add a salt number to prevent peers from faking rtt. self.ping_start_time = self.reactor.seconds() - self.ping_salt = self.protocol.connections.get_randbytes(self.ping_salt_size).hex() + self.ping_salt = self.protocol.p2p_manager.get_randbytes(self.ping_salt_size).hex() self.send_message(ProtocolMessages.PING, self.ping_salt) def send_pong(self, salt: str) -> None: diff --git a/hathor/p2p/sync_v1/agent.py b/hathor/p2p/sync_v1/agent.py index 7291db0e9..29a8575ff 100644 --- a/hathor/p2p/sync_v1/agent.py +++ b/hathor/p2p/sync_v1/agent.py @@ -23,6 +23,7 @@ from twisted.internet.interfaces import IDelayedCall from hathor.p2p import P2PDependencies +from hathor.p2p.manager import ConnectionsManager from hathor.p2p.messages import GetNextPayload, GetTipsPayload, NextPayload, ProtocolMessages, TipsPayload from hathor.p2p.sync_agent import SyncAgent from hathor.p2p.sync_v1.downloader import Downloader @@ -86,9 +87,9 @@ def __init__( self.tx_storage = self.dependencies.tx_storage # Rate limit for this connection. - assert protocol.connections is not None - self.global_rate_limiter: 'RateLimiter' = protocol.connections.rate_limiter - self.GlobalRateLimiter = protocol.connections.GlobalRateLimiter + assert isinstance(self.protocol.p2p_manager, ConnectionsManager) + self.global_rate_limiter: 'RateLimiter' = self.protocol.p2p_manager.rate_limiter + self.GlobalRateLimiter = self.protocol.p2p_manager.GlobalRateLimiter self.call_later_id: Optional[IDelayedCall] = None self.call_later_interval: int = 1 # seconds @@ -639,7 +640,7 @@ def handle_data(self, payload: str) -> None: # in the network, thus, we propagate it as well. success = self.dependencies.vertex_handler.on_new_vertex(tx) if success: - self.protocol.connections.send_tx_to_peers(tx) + self.protocol.p2p_manager.send_tx_to_peers(tx) self.update_received_stats(tx, success) def update_received_stats(self, tx: 'BaseTransaction', result: bool) -> None: @@ -690,7 +691,7 @@ def on_tx_success(self, tx: 'BaseTransaction') -> 'BaseTransaction': # Add tx to the DAG. success = self.dependencies.vertex_handler.on_new_vertex(tx) if success: - self.protocol.connections.send_tx_to_peers(tx) + self.protocol.p2p_manager.send_tx_to_peers(tx) # Updating stats data self.update_received_stats(tx, success) return tx diff --git a/hathor/p2p/sync_v2/agent.py b/hathor/p2p/sync_v2/agent.py index 2a3b75f64..766d676b9 100644 --- a/hathor/p2p/sync_v2/agent.py +++ b/hathor/p2p/sync_v2/agent.py @@ -270,7 +270,6 @@ def handle_not_found(self, payload: str) -> None: def handle_error(self, payload: str) -> None: """ Override protocols original handle_error so we can recover a sync in progress. """ - assert self.protocol.connections is not None # forward message to overloaded handle_error: self.protocol.handle_error(payload) @@ -320,7 +319,6 @@ def _run_sync(self) -> Generator[Any, Any, None]: """ assert not self.receiving_stream assert not self.mempool_manager.is_running() - assert self.protocol.connections is not None is_block_synced = yield self.run_sync_blocks() if is_block_synced: @@ -742,7 +740,6 @@ def handle_blocks_end(self, payload: str) -> None: response_code = StreamEnd(int(payload)) self.receiving_stream = False - assert self.protocol.connections is not None if self.state is not PeerState.SYNCING_BLOCKS: self.log.error('unexpected BLOCKS-END', state=self.state, response_code=response_code.name) @@ -761,8 +758,6 @@ def handle_blocks(self, payload: str) -> None: self.protocol.send_error_and_close_connection('Not expecting to receive BLOCK message') return - assert self.protocol.connections is not None - blk_bytes = base64.b64decode(payload) blk = self.vertex_parser.deserialize(blk_bytes) if not isinstance(blk, Block): @@ -998,7 +993,6 @@ def handle_transactions_end(self, payload: str) -> None: response_code = StreamEnd(int(payload)) self.receiving_stream = False - assert self.protocol.connections is not None if self.state is not PeerState.SYNCING_TRANSACTIONS: self.log.error('unexpected TRANSACTIONS-END', state=self.state, response_code=response_code.name) @@ -1012,9 +1006,6 @@ def handle_transactions_end(self, payload: str) -> None: def handle_transaction(self, payload: str) -> None: """ Handle a TRANSACTION message. """ - assert self.protocol.connections is not None - - # tx_bytes = bytes.fromhex(payload) tx_bytes = base64.b64decode(payload) tx = self.vertex_parser.deserialize(tx_bytes) if not isinstance(tx, Transaction): @@ -1157,7 +1148,7 @@ def handle_data(self, payload: str) -> None: try: success = self.dependencies.vertex_handler.on_new_vertex(tx, fails_silently=False) if success: - self.protocol.connections.send_tx_to_peers(tx) + self.protocol.p2p_manager.send_tx_to_peers(tx) except InvalidNewTransaction: self.protocol.send_error_and_close_connection('invalid vertex received') else: diff --git a/hathor/p2p/sync_v2/mempool.py b/hathor/p2p/sync_v2/mempool.py index c7ff2c363..e0de637bd 100644 --- a/hathor/p2p/sync_v2/mempool.py +++ b/hathor/p2p/sync_v2/mempool.py @@ -142,7 +142,7 @@ def _add_tx(self, tx: BaseTransaction) -> None: try: success = self.dependencies.vertex_handler.on_new_vertex(tx, fails_silently=False) if success: - self.sync_agent.protocol.connections.send_tx_to_peers(tx) + self.sync_agent.protocol.p2p_manager.send_tx_to_peers(tx) except InvalidNewTransaction: self.sync_agent.protocol.send_error_and_close_connection('invalid vertex received') raise diff --git a/tests/p2p/netfilter/test_factory.py b/tests/p2p/netfilter/test_factory.py index 53ca409c8..990aec74c 100644 --- a/tests/p2p/netfilter/test_factory.py +++ b/tests/p2p/netfilter/test_factory.py @@ -1,5 +1,3 @@ -from unittest.mock import Mock - from twisted.internet.address import IPv4Address from hathor.p2p.netfilter import get_table @@ -22,7 +20,7 @@ def test_factory(self) -> None: builder = TestBuilder() artifacts = builder.build() wrapped_factory = artifacts.p2p_manager.server_factory - factory = NetfilterFactory(connections=Mock(), wrappedFactory=wrapped_factory) + factory = NetfilterFactory(wrappedFactory=wrapped_factory) ret = factory.buildProtocol(IPv4Address('TCP', '192.168.0.1', 1234)) self.assertIsNone(ret) diff --git a/tests/p2p/test_sync.py b/tests/p2p/test_sync.py index 533d14192..bd23a9e87 100644 --- a/tests/p2p/test_sync.py +++ b/tests/p2p/test_sync.py @@ -267,7 +267,7 @@ def test_downloader(self) -> None: self.assertTrue(isinstance(conn.proto1.state, PeerIdState)) self.assertTrue(isinstance(conn.proto2.state, PeerIdState)) - downloader = conn.proto2.connections.get_sync_factory(SyncVersion.V1_1).get_downloader() + downloader = conn.proto2.p2p_manager.get_sync_factory(SyncVersion.V1_1).get_downloader() p2p_dependencies1 = P2PDependencies( reactor=self.manager1.reactor,