From 36da2f5b579699f37ac940afc646abe9f877c61b Mon Sep 17 00:00:00 2001 From: Gabriel Levcovitz Date: Wed, 6 Nov 2024 22:59:32 -0300 Subject: [PATCH] refactor(p2p): implement initial state --- hathor/p2p/dependencies/protocols.py | 2 +- hathor/p2p/factory.py | 13 +++-- hathor/p2p/manager.py | 15 ++++-- hathor/p2p/peer_connections.py | 46 +++++++++++----- hathor/p2p/peer_endpoint.py | 4 ++ hathor/p2p/protocol.py | 78 ++++++++-------------------- hathor/p2p/resources/status.py | 4 +- hathor/p2p/states/__init__.py | 3 +- hathor/p2p/states/base.py | 22 +++++--- hathor/p2p/states/hello.py | 19 +++++-- hathor/p2p/states/initial.py | 27 ++++++++++ hathor/p2p/states/peer_id.py | 21 ++++++-- hathor/p2p/states/ready.py | 16 ++++-- hathor/simulator/fake_connection.py | 8 +-- tests/others/test_metrics.py | 22 ++++---- tests/p2p/netfilter/test_match.py | 7 +-- tests/p2p/test_capabilities.py | 16 +++--- tests/p2p/test_protocol.py | 31 +++++++---- tests/p2p/test_sync_enabled.py | 3 +- tests/sysctl/test_p2p.py | 3 +- 20 files changed, 221 insertions(+), 139 deletions(-) create mode 100644 hathor/p2p/states/initial.py diff --git a/hathor/p2p/dependencies/protocols.py b/hathor/p2p/dependencies/protocols.py index cda779fcc..ba632d564 100644 --- a/hathor/p2p/dependencies/protocols.py +++ b/hathor/p2p/dependencies/protocols.py @@ -41,7 +41,7 @@ def on_peer_connect(self, protocol: HathorProtocol) -> None: ... def on_peer_ready(self, protocol: HathorProtocol) -> None: ... def on_handshake_disconnect(self, *, addr: PeerAddress) -> None: ... def on_ready_disconnect(self, *, addr: PeerAddress, peer_id: PeerId) -> None: ... - def on_unknown_disconnect(self, *, addr: PeerAddress) -> None: ... + def on_initial_disconnect(self, *, addr: PeerAddress) -> None: ... def get_randbytes(self, n: int) -> bytes: ... def is_peer_ready(self, peer_id: PeerId) -> bool: ... def send_tx_to_peers(self, tx: BaseTransaction) -> None: ... diff --git a/hathor/p2p/factory.py b/hathor/p2p/factory.py index 9b80fc5e7..f2fd8b64d 100644 --- a/hathor/p2p/factory.py +++ b/hathor/p2p/factory.py @@ -13,6 +13,7 @@ # limitations under the License. from abc import ABC +from typing import Callable from twisted.internet import protocol from twisted.internet.interfaces import IAddress @@ -21,7 +22,7 @@ from hathor.p2p.manager import ConnectionsManager from hathor.p2p.peer import PrivatePeer from hathor.p2p.peer_endpoint import PeerAddress -from hathor.p2p.protocol import HathorLineReceiver +from hathor.p2p.protocol import HathorLineReceiver, HathorProtocol class _HathorLineReceiverFactory(ABC, protocol.Factory): @@ -34,22 +35,28 @@ def __init__( *, dependencies: P2PDependencies, use_ssl: bool, + built_protocol_callback: Callable[[PeerAddress, HathorProtocol], None] | None, ): super().__init__() self.my_peer = my_peer self.p2p_manager = p2p_manager self.dependencies = dependencies self.use_ssl = use_ssl + self._built_protocol_callback = built_protocol_callback def buildProtocol(self, addr: IAddress) -> HathorLineReceiver: - return HathorLineReceiver( - addr=PeerAddress.from_address(addr), + peer_addr = PeerAddress.from_address(addr) + hathor_protocol = HathorLineReceiver( + addr=peer_addr, my_peer=self.my_peer, p2p_manager=self.p2p_manager, dependencies=self.dependencies, use_ssl=self.use_ssl, inbound=self.inbound, ) + if self._built_protocol_callback: + self._built_protocol_callback(peer_addr, hathor_protocol) + return hathor_protocol class HathorServerFactory(_HathorLineReceiverFactory, protocol.ServerFactory): diff --git a/hathor/p2p/manager.py b/hathor/p2p/manager.py index 4cffaa445..de15b1b5e 100644 --- a/hathor/p2p/manager.py +++ b/hathor/p2p/manager.py @@ -123,12 +123,14 @@ def __init__( p2p_manager=self, dependencies=dependencies, use_ssl=self.use_ssl, + built_protocol_callback=self._on_built_protocol, ) self.client_factory = HathorClientFactory( my_peer=self.my_peer, p2p_manager=self, dependencies=dependencies, use_ssl=self.use_ssl, + built_protocol_callback=self._on_built_protocol, ) # Global maximum number of connections. @@ -360,7 +362,7 @@ def on_peer_connect(self, protocol: HathorProtocol) -> None: protocol.disconnect(force=True) return - self._connections.on_connected(protocol=protocol) + self._connections.on_connected(addr=protocol.addr, inbound=protocol.inbound) self.pubsub.publish( HathorEvents.NETWORK_PEER_CONNECTED, protocol=protocol, @@ -423,9 +425,9 @@ def on_ready_disconnect(self, *, addr: PeerAddress, peer_id: PeerId) -> None: peers_count=self._get_peers_count() ) - def on_unknown_disconnect(self, *, addr: PeerAddress) -> None: - """Called when a peer disconnects from an unknown state (None).""" - self._connections.on_unknown_disconnect(addr=addr) + def on_initial_disconnect(self, *, addr: PeerAddress) -> None: + """Called when a peer disconnects from the initial state.""" + self._connections.on_initial_disconnect(addr=addr) self.pubsub.publish( HathorEvents.NETWORK_PEER_DISCONNECTED, peers_count=self._get_peers_count() @@ -575,7 +577,7 @@ def connect_to( if endpoint.peer_id is not None and peer is not None: assert endpoint.peer_id == peer.id, 'the entrypoint peer_id does not match the actual peer_id' - already_exists = self._connections.on_connecting(addr=endpoint.addr) + already_exists = self._connections.on_start_connection(addr=endpoint.addr) if already_exists: self.log.debug('skipping because we are already connected(ing) to this endpoint', endpoint=str(endpoint)) return None @@ -643,6 +645,9 @@ def _on_listen_success(self, listening_port: IListeningPort, description: str) - if self.manager.hostname: self._add_hostname_entrypoint(self.manager.hostname, address) + def _on_built_protocol(self, addr: PeerAddress, protocol: HathorProtocol) -> None: + self._connections.on_built_protocol(addr=addr, protocol=protocol) + def update_hostname_entrypoints(self, *, old_hostname: str | None, new_hostname: str) -> None: """Add new hostname entrypoints according to the listen addresses, and remove any old entrypoint.""" assert self.manager is not None diff --git a/hathor/p2p/peer_connections.py b/hathor/p2p/peer_connections.py index d42d4acc1..ec5475048 100644 --- a/hathor/p2p/peer_connections.py +++ b/hathor/p2p/peer_connections.py @@ -33,7 +33,7 @@ class PeerConnections: It's also responsible for reacting for state changes on those connections. """ - __slots__ = ('_connecting_outbound', '_handshaking', '_ready', '_addr_by_id') + __slots__ = ('_connecting_outbound', '_built', '_handshaking', '_ready', '_addr_by_id') def __init__(self) -> None: # Peers that are in the "connecting" state, between starting a connection and Twisted calling `connectionMade`. @@ -41,6 +41,9 @@ def __init__(self) -> None: # They're uniquely identified by the address we're connecting to. self._connecting_outbound: set[PeerAddress] = set() + # Peers that had their protocol instances built, before getting connected. + self._built: dict[PeerAddress, HathorProtocol] = {} + # Peers that are handshaking, in a state after being connected and before reaching the READY state. # They're uniquely identified by the address we're connected to. self._handshaking: dict[PeerAddress, HathorProtocol] = {} @@ -71,7 +74,7 @@ def ready_peers(self) -> dict[PeerAddress, HathorProtocol]: def not_ready_peers(self) -> list[PeerAddress]: """Get not ready peers, that is, peers that are either connecting or handshaking.""" - return list(self._connecting_outbound) + list(self._handshaking) + return list(self._built) + list(self._connecting_outbound) + list(self._handshaking) def connected_peers(self) -> dict[PeerAddress, HathorProtocol]: """ @@ -104,39 +107,49 @@ def is_peer_ready(self, peer_id: PeerId) -> bool: """Return whether a peer is ready, by its PeerId.""" return peer_id in self._addr_by_id - def on_connecting(self, *, addr: PeerAddress) -> bool: + def on_start_connection(self, *, addr: PeerAddress) -> bool: """ Callback for when an outbound connection is initiated. - Returns True if this address already exists, either connecting or connected, and False otherwise.""" + Returns True if this address already exists, either connecting or connected, and False otherwise. + """ if addr in self.all_peers(): return True self._connecting_outbound.add(addr) return False + def on_built_protocol(self, *, addr: PeerAddress, protocol: HathorProtocol) -> None: + """Callback for when a HathorProtocol instance is built.""" + assert addr not in self._built + assert addr not in self.connected_peers() + self._built[addr] = protocol + def on_failed_to_connect(self, *, addr: PeerAddress) -> None: """Callback for when an outbound connection fails before getting connected.""" assert addr in self._connecting_outbound assert addr not in self.connected_peers() self._connecting_outbound.remove(addr) - def on_connected(self, *, protocol: HathorProtocol) -> None: - """Callback for when an outbound connection gets connected.""" - assert protocol.addr not in self.connected_peers() + def on_connected(self, *, addr: PeerAddress, inbound: bool) -> None: + """Callback for when a connection is made from both inbound and outbound peers.""" + assert addr in self._built + assert addr not in self.connected_peers() - if protocol.inbound: - assert protocol.addr not in self._connecting_outbound + if inbound: + assert addr not in self._connecting_outbound else: - assert protocol.addr in self._connecting_outbound - self._connecting_outbound.remove(protocol.addr) + assert addr in self._connecting_outbound + self._connecting_outbound.remove(addr) - self._handshaking[protocol.addr] = protocol + protocol = self._built.pop(addr) + self._handshaking[addr] = protocol def on_handshake_disconnect(self, *, addr: PeerAddress) -> None: """ Callback for when a connection is closed during a handshaking state, that is, after getting connected and before getting READY. """ + assert addr not in self._built assert addr not in self._connecting_outbound assert addr in self._handshaking assert addr not in self._ready @@ -148,6 +161,7 @@ def on_ready(self, *, addr: PeerAddress, peer_id: PeerId) -> HathorProtocol | No If the PeerId of this connection is duplicate, return the protocol that we should disconnect. Return None otherwise. """ + assert addr not in self._built assert addr not in self._connecting_outbound assert addr in self._handshaking assert addr not in self._ready @@ -173,6 +187,7 @@ def on_ready(self, *, addr: PeerAddress, peer_id: PeerId) -> HathorProtocol | No def on_ready_disconnect(self, *, addr: PeerAddress, peer_id: PeerId) -> None: """Callback for when a connection is closed during the READY state.""" + assert addr not in self._built assert addr not in self._connecting_outbound assert addr not in self._handshaking assert addr in self._ready @@ -181,10 +196,13 @@ def on_ready_disconnect(self, *, addr: PeerAddress, peer_id: PeerId) -> None: if self._addr_by_id[peer_id] == addr: self._addr_by_id.pop(peer_id) - def on_unknown_disconnect(self, *, addr: PeerAddress) -> None: - """Callback for when a connection is closed during an unknown state.""" + def on_initial_disconnect(self, *, addr: PeerAddress) -> None: + """Callback for when a connection is closed during the initial state.""" + assert addr in self._built assert addr not in self._handshaking assert addr not in self._ready + + self._built.pop(addr) if addr in self._connecting_outbound: self._connecting_outbound.remove(addr) diff --git a/hathor/p2p/peer_endpoint.py b/hathor/p2p/peer_endpoint.py index 47ff422ed..22ebb2d7a 100644 --- a/hathor/p2p/peer_endpoint.py +++ b/hathor/p2p/peer_endpoint.py @@ -118,6 +118,10 @@ def __eq__(self, other: Any) -> bool: def __ne__(self, other: Any) -> bool: return not self == other + def __hash__(self): + host = 'localhost' if self.is_localhost() else self.host + return hash((self.protocol, host, self.port)) + @classmethod def parse(cls, description: str) -> Self: protocol, host, port, query = _parse_address_parts(description) diff --git a/hathor/p2p/protocol.py b/hathor/p2p/protocol.py index f11cff991..4b93ec137 100644 --- a/hathor/p2p/protocol.py +++ b/hathor/p2p/protocol.py @@ -31,7 +31,7 @@ from hathor.p2p.peer_endpoint import PeerAddress from hathor.p2p.peer_id import PeerId from hathor.p2p.rate_limiter import RateLimiter -from hathor.p2p.states import BaseState, HelloState, PeerIdState, ReadyState +from hathor.p2p.states import BaseState, HelloState, InitialState, PeerIdState, ReadyState from hathor.p2p.sync_version import SyncVersion from hathor.p2p.utils import format_address from hathor.profiler import get_cpu_profiler @@ -64,11 +64,6 @@ class HathorProtocol: The available commands are listed in the ProtocolMessages class. """ - class PeerState(Enum): - HELLO = HelloState - PEER_ID = PeerIdState - READY = ReadyState - class RateLimitKeys(str, Enum): GLOBAL = 'global' @@ -80,9 +75,7 @@ class WarningFlags(str, Enum): last_message: float _peer: Optional[PublicPeer] transport: Optional[ITransport] - state: Optional[BaseState] connection_time: float - _state_instances: dict[PeerState, BaseState] warning_flags: set[str] aborting: bool diff_timestamp: Optional[int] @@ -120,8 +113,6 @@ def __init__( self.idle_timeout = self._settings.PEER_IDLE_TIMEOUT self._idle_timeout_call_later: Optional[IDelayedCall] = None - self._state_instances = {} - self.app_version = 'Unknown' self.diff_timestamp = None @@ -139,7 +130,7 @@ def __init__( self.connection_time = 0.0 # The current state of the connection. - self.state: Optional[BaseState] = None + self.state: BaseState = InitialState(dependencies=dependencies) # Default rate limit self.ratelimit: RateLimiter = RateLimiter(self.reactor) @@ -165,24 +156,14 @@ def __init__( self.capabilities = set() - def change_state(self, state_enum: PeerState) -> None: - """Called to change the state of the connection.""" - if state_enum not in self._state_instances: - state_cls = state_enum.value - instance = state_cls(self, dependencies=self.dependencies) - instance.state_name = state_enum.name - self._state_instances[state_enum] = instance - new_state = self._state_instances[state_enum] - if new_state != self.state: - if self.state: - self.state.on_exit() - self.state = new_state - if self.state: - self.state.on_enter() - - def is_state(self, state_enum: PeerState) -> bool: - """Checks whether current state is `state_enum`.""" - return isinstance(self.state, state_enum.value) + def advance_state(self) -> None: + """Called to advance the state of the connection.""" + new_state_type = self.state.next_state_type() + assert new_state_type is not None + new_state: BaseState = new_state_type(dependencies=self.dependencies, protocol=self) + self.state.on_exit() + self.state = new_state + self.state.on_enter() def get_short_remote(self) -> str: """Get remote for logging.""" @@ -242,8 +223,8 @@ def on_connect(self) -> None: self.reset_idle_timeout() - # The initial state is HELLO. - self.change_state(self.PeerState.HELLO) + # The first state after INITIAL is HELLO. + self.advance_state() self.p2p_manager.on_peer_connect(self) @@ -262,7 +243,7 @@ def on_peer_ready(self) -> None: def on_disconnect(self, reason: Failure) -> None: """ Executed when the connection is lost. """ - if self.is_state(self.PeerState.READY): + if isinstance(self.state, ReadyState): self.log.info('disconnected', reason=reason.getErrorMessage()) else: self.log.debug('disconnected', reason=reason.getErrorMessage()) @@ -271,33 +252,21 @@ def on_disconnect(self, reason: Failure) -> None: self._idle_timeout_call_later = None self.aborting = True self.update_log_context() + self.state.on_exit() - if not self.state: - # TODO: This should never happen, it can only happen if an exception was raised in the middle of our - # connection callback (connectionMade/on_connect). In that case, we may have not initialized our state - # yet. We should improve this by making an initial non-None state. - self.log.error( - 'disconnecting protocol with no state. check for previous exceptions', - addr=str(self.addr), - peer_id=str(self.get_peer_id()), - ) - self.p2p_manager.on_unknown_disconnect(addr=self.addr) + if isinstance(self.state, InitialState): + self.p2p_manager.on_initial_disconnect(addr=self.addr) return - self.state.on_exit() - state_name = self.state.state_name - if self.is_state(self.PeerState.HELLO) or self.is_state(self.PeerState.PEER_ID): - self.state = None + if isinstance(self.state, HelloState) or isinstance(self.state, PeerIdState): self.p2p_manager.on_handshake_disconnect(addr=self.addr) return - if self.is_state(self.PeerState.READY): - self.state = None + if isinstance(self.state, ReadyState): self.p2p_manager.on_ready_disconnect(addr=self.addr, peer_id=self.peer.id) return - self.state = None - raise AssertionError(f'disconnected in unexpected state: {state_name or "unknown"}') + raise AssertionError(f'disconnected in unexpected state: {self.state.name}') def send_message(self, cmd: ProtocolMessages, payload: Optional[str] = None) -> None: """ A generic message which must be implemented to send a message @@ -338,13 +307,13 @@ def recv_message(self, cmd: ProtocolMessages, payload: str) -> None: .addErrback(self._on_cmd_handler_error, cmd) def _on_cmd_handler_error(self, failure: Failure, cmd: ProtocolMessages) -> None: - self.log.warn('recv_message processing error', reason=failure.getErrorMessage(), exc_info=True) + self.log.warn('recv_message processing error', reason=failure.getErrorMessage()) self.send_error_and_close_connection(f'Error processing "{cmd.value}" command') def send_error(self, msg: str) -> None: """ Send an error message to the peer. """ - if self.is_state(self.PeerState.READY): + if isinstance(self.state, ReadyState): self.log.warn('send error', msg=msg) else: self.log.debug('send error', msg=msg) @@ -386,21 +355,18 @@ def handle_error(self, payload: str) -> None: def is_sync_enabled(self) -> bool: """Return true if sync is enabled for this connection.""" - if not self.is_state(self.PeerState.READY): + if not isinstance(self.state, ReadyState): return False - assert isinstance(self.state, ReadyState) return self.state.sync_agent.is_sync_enabled() def enable_sync(self) -> None: """Enable sync for this connection.""" - assert self.is_state(self.PeerState.READY) assert isinstance(self.state, ReadyState) self.log.info('enable sync') self.state.sync_agent.enable_sync() def disable_sync(self) -> None: """Disable sync for this connection.""" - assert self.is_state(self.PeerState.READY) assert isinstance(self.state, ReadyState) self.log.info('disable sync') self.state.sync_agent.disable_sync() diff --git a/hathor/p2p/resources/status.py b/hathor/p2p/resources/status.py index 896220b9a..bb9ab4995 100644 --- a/hathor/p2p/resources/status.py +++ b/hathor/p2p/resources/status.py @@ -52,7 +52,7 @@ def render_GET(self, request: Request) -> bytes: for conn in self.manager.connections.iter_handshaking_peers(): handshaking_peers.append({ 'address': str(conn.addr), - 'state': conn.state.state_name, + 'state': conn.state.name, 'uptime': now - conn.connection_time, 'app_version': conn.app_version, }) @@ -67,7 +67,7 @@ def render_GET(self, request: Request) -> bytes: 'current_time': now, 'uptime': now - conn.connection_time, 'address': str(conn.addr), - 'state': conn.state.state_name, + 'state': conn.state.name, # 'received_bytes': conn.received_bytes, 'rtt': list(conn.state.rtt_window), 'last_message': now - conn.last_message, diff --git a/hathor/p2p/states/__init__.py b/hathor/p2p/states/__init__.py index 0e9be0913..97d4ba92b 100644 --- a/hathor/p2p/states/__init__.py +++ b/hathor/p2p/states/__init__.py @@ -14,7 +14,8 @@ from .base import BaseState from .hello import HelloState +from .initial import InitialState from .peer_id import PeerIdState from .ready import ReadyState -__all__ = ['BaseState', 'HelloState', 'PeerIdState', 'ReadyState'] +__all__ = ['BaseState', 'InitialState', 'HelloState', 'PeerIdState', 'ReadyState'] diff --git a/hathor/p2p/states/base.py b/hathor/p2p/states/base.py index 75a69140e..29e3ed418 100644 --- a/hathor/p2p/states/base.py +++ b/hathor/p2p/states/base.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Coroutine from typing import TYPE_CHECKING, Any, Callable, Optional @@ -29,24 +31,26 @@ class BaseState: - protocol: 'HathorProtocol' + name: str cmd_map: dict[ ProtocolMessages, Callable[[str], None] | Callable[[str], Deferred[None]] | Callable[[str], Coroutine[Deferred[None], Any, None]] ] - def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies): - self.log = logger.new(**protocol.get_logger_context()) + def __init__(self, dependencies: P2PDependencies, protocol: HathorProtocol | None = None): + self.log = logger.new(**protocol.get_logger_context()) if protocol else logger.new() self.dependencies = dependencies self._settings: HathorSettings = dependencies.settings - self.protocol = protocol + self._protocol = protocol self.cmd_map = { ProtocolMessages.ERROR: self.handle_error, ProtocolMessages.THROTTLE: self.handle_throttle, } - # This variable is set by HathorProtocol after instantiating the state - self.state_name = None + @property + def protocol(self) -> HathorProtocol: + assert self._protocol is not None + return self._protocol def handle_error(self, payload: str) -> None: self.protocol.handle_error(payload) @@ -66,7 +70,7 @@ def send_throttle(self, key: str) -> None: self.protocol.send_message(ProtocolMessages.THROTTLE, payload) def on_enter(self) -> None: - raise NotImplementedError + pass def on_exit(self) -> None: pass @@ -74,3 +78,7 @@ def on_exit(self) -> None: def prepare_to_disconnect(self) -> None: """Called when we will disconnect with the peer.""" pass + + @staticmethod + def next_state_type() -> type[BaseState] | None: + return None diff --git a/hathor/p2p/states/hello.py b/hathor/p2p/states/hello.py index 20e45736f..5530f6458 100644 --- a/hathor/p2p/states/hello.py +++ b/hathor/p2p/states/hello.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Any from structlog import get_logger +from typing_extensions import override import hathor from hathor.conf.get_settings import get_global_settings @@ -27,14 +30,16 @@ from hathor.util import json_dumps, json_loads if TYPE_CHECKING: - from hathor.p2p.protocol import HathorProtocol # noqa: F401 + from hathor.p2p.protocol import HathorProtocol logger = get_logger() class HelloState(BaseState): - def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies) -> None: - super().__init__(protocol, dependencies=dependencies) + name: str = 'HELLO' + + def __init__(self, dependencies: P2PDependencies, protocol: HathorProtocol): + super().__init__(dependencies=dependencies, protocol=protocol) self.log = logger.new(**protocol.get_logger_context()) self.cmd_map.update({ ProtocolMessages.HELLO: self.handle_hello, @@ -167,7 +172,13 @@ def handle_hello(self, payload: str) -> None: self.protocol.disconnect('rejected by netfilter: filter post_hello', force=True) return - protocol.change_state(protocol.PeerState.PEER_ID) + protocol.advance_state() + + @staticmethod + @override + def next_state_type() -> type[BaseState] | None: + from hathor.p2p.states import PeerIdState + return PeerIdState def _parse_sync_versions(hello_data: dict[str, Any]) -> set[SyncVersion]: diff --git a/hathor/p2p/states/initial.py b/hathor/p2p/states/initial.py new file mode 100644 index 000000000..d93a39a61 --- /dev/null +++ b/hathor/p2p/states/initial.py @@ -0,0 +1,27 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing_extensions import override + +from hathor.p2p.states import BaseState + + +class InitialState(BaseState): + name: str = 'INITIAL' + + @staticmethod + @override + def next_state_type() -> type[BaseState] | None: + from hathor.p2p.states import HelloState + return HelloState diff --git a/hathor/p2p/states/peer_id.py b/hathor/p2p/states/peer_id.py index f305db386..56ed6bbbb 100644 --- a/hathor/p2p/states/peer_id.py +++ b/hathor/p2p/states/peer_id.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Any from structlog import get_logger +from typing_extensions import override from hathor.p2p import P2PDependencies from hathor.p2p.messages import ProtocolMessages @@ -24,14 +27,16 @@ from hathor.util import json_dumps, json_loads if TYPE_CHECKING: - from hathor.p2p.protocol import HathorProtocol # noqa: F401 + from hathor.p2p.protocol import HathorProtocol logger = get_logger() class PeerIdState(BaseState): - def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies) -> None: - super().__init__(protocol, dependencies=dependencies) + name: str = 'PEER_ID' + + def __init__(self, dependencies: P2PDependencies, protocol: HathorProtocol): + super().__init__(dependencies=dependencies, protocol=protocol) self.log = logger.new(remote=protocol.get_short_remote()) self.cmd_map.update({ ProtocolMessages.PEER_ID: self.handle_peer_id, @@ -52,7 +57,7 @@ def send_ready(self) -> None: self.send_message(ProtocolMessages.READY) if self.other_peer_ready: # In case both peers are already ready, we change the state to READY - self.protocol.change_state(self.protocol.PeerState.READY) + self.protocol.advance_state() def handle_ready(self, payload: str) -> None: """ Handles a received READY message @@ -61,7 +66,7 @@ def handle_ready(self, payload: str) -> None: if self.my_peer_ready: # In this case this peer already completed the peer-id validation # So it was just waiting for the ready message from the other peer to change the state to READY - self.protocol.change_state(self.protocol.PeerState.READY) + self.protocol.advance_state() def _get_peer_id_data(self) -> dict[str, Any]: my_peer = self.protocol.my_peer @@ -164,3 +169,9 @@ def _should_block_peer(self, peer_id: PeerId) -> bool: # default is not blocking, this will be sync-v2 peers not on whitelist when not on whitelist-only mode return False + + @staticmethod + @override + def next_state_type() -> type[BaseState] | None: + from hathor.p2p.states import ReadyState + return ReadyState diff --git a/hathor/p2p/states/ready.py b/hathor/p2p/states/ready.py index f258cfada..f7b561e12 100644 --- a/hathor/p2p/states/ready.py +++ b/hathor/p2p/states/ready.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections import deque from typing import TYPE_CHECKING, Iterable, Optional from structlog import get_logger from twisted.internet.task import LoopingCall +from typing_extensions import override from hathor.indexes.height_index import HeightInfo from hathor.p2p import P2PDependencies @@ -29,14 +32,16 @@ from hathor.util import json_dumps, json_loads if TYPE_CHECKING: - from hathor.p2p.protocol import HathorProtocol # noqa: F401 + from hathor.p2p.protocol import HathorProtocol logger = get_logger() class ReadyState(BaseState): - def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies) -> None: - super().__init__(protocol, dependencies=dependencies) + name: str = 'READY' + + def __init__(self, dependencies: P2PDependencies, protocol: HathorProtocol): + super().__init__(dependencies=dependencies, protocol=protocol) self.log = logger.new(**self.protocol.get_logger_context()) @@ -271,3 +276,8 @@ def handle_best_blockchain(self, payload: str) -> None: ) return self.peer_best_blockchain = best_blockchain + + @staticmethod + @override + def next_state_type() -> type[BaseState] | None: + return None diff --git a/hathor/simulator/fake_connection.py b/hathor/simulator/fake_connection.py index aefd804b4..bf1826ec4 100644 --- a/hathor/simulator/fake_connection.py +++ b/hathor/simulator/fake_connection.py @@ -285,13 +285,13 @@ def reconnect(self) -> None: self._buf1.clear() self._buf2.clear() - self._proto1 = self.manager1.connections.server_factory.buildProtocol(self.addr2) - self._proto2 = self.manager2.connections.client_factory.buildProtocol(self.addr1) - # When _fake_bootstrap_id is set we don't pass the peer because that's how bootstrap calls connect_to() - peer = self._proto1.my_peer.to_unverified_peer() if self._fake_bootstrap_id is False else None + peer = self.manager1.my_peer.to_unverified_peer() if self._fake_bootstrap_id is False else None deferred = self.manager2.connections.connect_to(self.entrypoint, peer) assert deferred is not None + + self._proto1 = self.manager1.connections.server_factory.buildProtocol(self.addr2) + self._proto2 = self.manager2.connections.client_factory.buildProtocol(self.addr1) deferred.callback(self._proto2) self.tr1 = HathorStringTransport(self._proto2.my_peer, peer_address=self.addr2) diff --git a/tests/others/test_metrics.py b/tests/others/test_metrics.py index 8a8dc546e..dac804846 100644 --- a/tests/others/test_metrics.py +++ b/tests/others/test_metrics.py @@ -67,17 +67,17 @@ def test_connections_manager_integration(self): "2": PrivatePeer.auto_generated(), "3": PrivatePeer.auto_generated(), }) - peer1 = Mock() - peer1.addr = PeerAddress.parse('tcp://localhost:40403') - peer2 = Mock() - peer2.addr = PeerAddress.parse('tcp://localhost:40404') - peer3 = Mock() - peer3.addr = PeerAddress.parse('tcp://localhost:40405') - p2p_manager._connections.on_connected(protocol=peer1) - p2p_manager._connections.on_connected(protocol=peer2) - p2p_manager._connections.on_connected(protocol=peer3) - p2p_manager._connections.on_ready(addr=peer1.addr, peer_id=Mock()) - p2p_manager._connections.on_ready(addr=peer2.addr, peer_id=Mock()) + addr1 = PeerAddress.parse('tcp://localhost:40403') + addr2 = PeerAddress.parse('tcp://localhost:40404') + addr3 = PeerAddress.parse('tcp://localhost:40405') + p2p_manager._connections.on_built_protocol(addr=addr1, protocol=Mock()) + p2p_manager._connections.on_built_protocol(addr=addr2, protocol=Mock()) + p2p_manager._connections.on_built_protocol(addr=addr3, protocol=Mock()) + p2p_manager._connections.on_connected(addr=addr1, inbound=True) + p2p_manager._connections.on_connected(addr=addr2, inbound=True) + p2p_manager._connections.on_connected(addr=addr3, inbound=True) + p2p_manager._connections.on_ready(addr=addr1, peer_id=Mock()) + p2p_manager._connections.on_ready(addr=addr2, peer_id=Mock()) # Execution endpoint = PeerEndpoint.parse('tcp://127.0.0.1:8005') diff --git a/tests/p2p/netfilter/test_match.py b/tests/p2p/netfilter/test_match.py index ad3929d85..1c865fc92 100644 --- a/tests/p2p/netfilter/test_match.py +++ b/tests/p2p/netfilter/test_match.py @@ -12,6 +12,7 @@ NetfilterMatchPeerId, ) from hathor.p2p.peer import PrivatePeer +from hathor.p2p.states import HelloState, PeerIdState, ReadyState from hathor.simulator import FakeConnection from tests import unittest @@ -208,20 +209,20 @@ def test_match_peer_id(self) -> None: manager2 = self.create_peer(network, peer=peer2) conn = FakeConnection(manager1, manager2) - self.assertTrue(conn.proto2.is_state(conn.proto2.PeerState.HELLO)) + self.assertTrue(isinstance(conn.proto2.state, HelloState)) matcher = NetfilterMatchPeerId(str(peer1.id)) context = NetfilterContext(protocol=conn.proto2) self.assertFalse(matcher.match(context)) conn.run_one_step() - self.assertTrue(conn.proto2.is_state(conn.proto2.PeerState.PEER_ID)) + self.assertTrue(isinstance(conn.proto2.state, PeerIdState)) self.assertFalse(matcher.match(context)) # Success because the connection is ready and proto2 is connected to proto1. conn.run_one_step() conn.run_one_step() - self.assertTrue(conn.proto2.is_state(conn.proto2.PeerState.READY)) + self.assertTrue(isinstance(conn.proto2.state, ReadyState)) self.assertTrue(matcher.match(context)) # Fail because proto1 is connected to proto2, and the peer id cannot match. diff --git a/tests/p2p/test_capabilities.py b/tests/p2p/test_capabilities.py index 022fb1fc6..cb2c32414 100644 --- a/tests/p2p/test_capabilities.py +++ b/tests/p2p/test_capabilities.py @@ -21,8 +21,8 @@ def test_capabilities(self) -> None: # Even if we don't have the capability we must connect because the whitelist url conf is None assert isinstance(conn._proto1.state, ReadyState) assert isinstance(conn._proto2.state, ReadyState) - self.assertEqual(conn._proto1.state.state_name, 'READY') - self.assertEqual(conn._proto2.state.state_name, 'READY') + self.assertEqual(conn._proto1.state.name, 'READY') + self.assertEqual(conn._proto2.state.name, 'READY') self.assertIsInstance(conn._proto1.state.sync_agent, NodeSyncTimestamp) self.assertIsInstance(conn._proto2.state.sync_agent, NodeSyncTimestamp) @@ -38,8 +38,8 @@ def test_capabilities(self) -> None: assert isinstance(conn2._proto1.state, ReadyState) assert isinstance(conn2._proto2.state, ReadyState) - self.assertEqual(conn2._proto1.state.state_name, 'READY') - self.assertEqual(conn2._proto2.state.state_name, 'READY') + self.assertEqual(conn2._proto1.state.name, 'READY') + self.assertEqual(conn2._proto2.state.name, 'READY') self.assertIsInstance(conn2._proto1.state.sync_agent, NodeSyncTimestamp) self.assertIsInstance(conn2._proto2.state.sync_agent, NodeSyncTimestamp) @@ -61,8 +61,8 @@ def test_capabilities(self) -> None: # Even if we don't have the capability we must connect because the whitelist url conf is None assert isinstance(conn._proto1.state, ReadyState) assert isinstance(conn._proto2.state, ReadyState) - self.assertEqual(conn._proto1.state.state_name, 'READY') - self.assertEqual(conn._proto2.state.state_name, 'READY') + self.assertEqual(conn._proto1.state.name, 'READY') + self.assertEqual(conn._proto2.state.name, 'READY') self.assertIsInstance(conn._proto1.state.sync_agent, NodeBlockSync) self.assertIsInstance(conn._proto2.state.sync_agent, NodeBlockSync) @@ -80,8 +80,8 @@ def test_capabilities(self) -> None: assert isinstance(conn2._proto1.state, ReadyState) assert isinstance(conn2._proto2.state, ReadyState) - self.assertEqual(conn2._proto1.state.state_name, 'READY') - self.assertEqual(conn2._proto2.state.state_name, 'READY') + self.assertEqual(conn2._proto1.state.name, 'READY') + self.assertEqual(conn2._proto2.state.name, 'READY') self.assertIsInstance(conn2._proto1.state.sync_agent, NodeBlockSync) self.assertIsInstance(conn2._proto2.state.sync_agent, NodeBlockSync) diff --git a/tests/p2p/test_protocol.py b/tests/p2p/test_protocol.py index 2ae01fc35..d55c48b7d 100644 --- a/tests/p2p/test_protocol.py +++ b/tests/p2p/test_protocol.py @@ -2,8 +2,8 @@ from typing import Optional from unittest.mock import Mock, patch -import pytest from twisted.internet import defer +from twisted.internet.address import IPv4Address from twisted.internet.protocol import Protocol from twisted.python.failure import Failure @@ -13,7 +13,9 @@ from hathor.p2p.peer import PrivatePeer from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint from hathor.p2p.protocol import HathorLineReceiver, HathorProtocol +from hathor.p2p.states import InitialState from hathor.simulator import FakeConnection +from hathor.simulator.fake_connection import HathorStringTransport from hathor.util import json_dumps, json_loadb from tests import unittest @@ -116,14 +118,6 @@ def test_rate_limit(self) -> None: self.conn.proto1.state.handle_throttle(b'') - # Test empty disconnect - self.conn.proto1.state = None - with pytest.raises(AssertionError): - # TODO: This raises because we are trying to disconnect a protocol with no state, but it's not possible - # for a protocol to have no state after it's handshaking. We have to update this when we introduce the - # new non-None initial state for protocols. - self.conn.proto1.on_disconnect(Failure(Exception())) - def test_invalid_size(self) -> None: self.conn.tr1.clear() cmd = b'HELLO ' @@ -389,7 +383,24 @@ def test_send_invalid_unicode(self) -> None: self.conn.proto1.dataReceived(b'\xff\r\n') self.assertTrue(self.conn.tr1.disconnecting) - def test_on_disconnect(self) -> None: + def test_on_disconnect_initial(self) -> None: + manager3 = self.create_peer(self.network) + addr = IPv4Address('TCP', '127.0.0.1', 40403) + peer_addr = PeerAddress.from_address(addr) + peer_endpoint = peer_addr.with_id(self.manager1.my_peer.id) + deferred = manager3.connections.connect_to(peer_endpoint, self.manager1.my_peer) + assert deferred is not None + + proto = manager3.connections.client_factory.buildProtocol(addr) + proto.transport = HathorStringTransport(self.manager1.my_peer, peer_address=addr) + deferred.callback(proto) + assert peer_addr in manager3.connections.iter_not_ready_endpoints() + assert isinstance(proto.state, InitialState) + + proto.connectionLost(Failure(Exception())) + assert peer_addr not in manager3.connections._connections.all_peers() + + def test_on_disconnect_before_hello(self) -> None: self.assertIn(self.conn.proto1, self.manager1.connections.iter_handshaking_peers()) self.conn.disconnect(Failure(Exception('testing'))) self.assertNotIn(self.conn.proto1, self.manager1.connections.iter_handshaking_peers()) diff --git a/tests/p2p/test_sync_enabled.py b/tests/p2p/test_sync_enabled.py index f681f90a0..50ae56dae 100644 --- a/tests/p2p/test_sync_enabled.py +++ b/tests/p2p/test_sync_enabled.py @@ -1,3 +1,4 @@ +from hathor.p2p.states import ReadyState from hathor.simulator import FakeConnection from hathor.simulator.trigger import StopAfterNMinedBlocks from tests import unittest @@ -52,7 +53,7 @@ def test_sync_rotate(self) -> None: self.simulator.run(600) - ready = set(conn for conn in connections if conn.proto1.is_state(conn.proto1.PeerState.READY)) + ready = set(conn for conn in connections if isinstance(conn.proto1.state, ReadyState)) self.assertEqual(len(ready), len(other_managers)) enabled = set(conn for conn in connections if conn.proto1.is_sync_enabled()) diff --git a/tests/sysctl/test_p2p.py b/tests/sysctl/test_p2p.py index a3703676a..bcadc3bbd 100644 --- a/tests/sysctl/test_p2p.py +++ b/tests/sysctl/test_p2p.py @@ -174,7 +174,8 @@ def test_kill_one_connection(self): peer_id = '0e2bd0d8cd1fb6d040801c32ec27e8986ce85eb8810b6c878dcad15bce3b5b1e' conn = MagicMock() conn.addr = PeerAddress.parse('tcp://localhost:40403') - p2p_manager._connections.on_connected(protocol=conn) + p2p_manager._connections.on_built_protocol(addr=conn.addr, protocol=conn) + p2p_manager._connections.on_connected(addr=conn.addr, inbound=True) p2p_manager._connections.on_ready(addr=conn.addr, peer_id=PeerId(peer_id)) self.assertEqual(conn.disconnect.call_count, 0) sysctl.unsafe_set('kill_connection', peer_id)