diff --git a/bittensor/core/subtensor.py b/bittensor/core/subtensor.py index 51385d8d00..01fd262e13 100644 --- a/bittensor/core/subtensor.py +++ b/bittensor/core/subtensor.py @@ -5,7 +5,6 @@ import argparse import copy -import socket import ssl from typing import Union, Optional, TypedDict, Any @@ -18,6 +17,7 @@ from scalecodec.type_registry import load_type_registry_preset from scalecodec.types import ScaleType from substrateinterface.base import QueryMapResult, SubstrateInterface +from websockets.sync import client as ws_client from bittensor.core import settings from bittensor.core.axon import Axon @@ -140,6 +140,7 @@ def __init__( _mock: bool = False, log_verbose: bool = False, connection_timeout: int = 600, + websocket: Optional[ws_client.ClientConnection] = None, ) -> None: """ Initializes a Subtensor interface for interacting with the Bittensor blockchain. @@ -155,6 +156,7 @@ def __init__( _mock (bool): If set to ``True``, uses a mocked connection for testing purposes. Default is ``False``. log_verbose (bool): Whether to enable verbose logging. If set to ``True``, detailed log information about the connection and network operations will be provided. Default is ``True``. connection_timeout (int): The maximum time in seconds to keep the connection alive. Default is ``600``. + websocket (websockets.sync.client.ClientConnection): websockets sync (threading) client object connected to the network. This initialization sets up the connection to the specified Bittensor network, allowing for various blockchain operations such as neuron registration, stake management, and setting weights. """ @@ -191,6 +193,7 @@ def __init__( self.log_verbose = log_verbose self._connection_timeout = connection_timeout self.substrate: "SubstrateInterface" = None + self.websocket = websocket self._get_substrate() def __str__(self) -> str: @@ -213,22 +216,23 @@ def _get_substrate(self): """Establishes a connection to the Substrate node using configured parameters.""" try: # Set up params. + if not self.websocket: + self.websocket = ws_client.connect( + self.chain_endpoint, + open_timeout=self._connection_timeout, + max_size=2**32, + ) self.substrate = SubstrateInterface( ss58_format=settings.SS58_FORMAT, use_remote_preset=True, - url=self.chain_endpoint, type_registry=settings.TYPE_REGISTRY, + websocket=self.websocket, ) if self.log_verbose: logging.debug( f"Connected to {self.network} network and {self.chain_endpoint}." ) - try: - self.substrate.websocket.settimeout(self._connection_timeout) - except (AttributeError, TypeError, socket.error, OSError) as e: - logging.warning(f"Error setting timeout: {e}") - except (ConnectionRefusedError, ssl.SSLError) as error: logging.error( f"Could not connect to {self.network} network with {self.chain_endpoint} chain endpoint.", diff --git a/bittensor/utils/async_substrate_interface.py b/bittensor/utils/async_substrate_interface.py index c3af691952..982e8dfc96 100644 --- a/bittensor/utils/async_substrate_interface.py +++ b/bittensor/utils/async_substrate_interface.py @@ -4,13 +4,11 @@ from collections import defaultdict from dataclasses import dataclass from hashlib import blake2b -from typing import Optional, Any, Union, Callable, Awaitable, cast +from typing import Optional, Any, Union, Callable, Awaitable, cast, TYPE_CHECKING -import websockets from async_property import async_property from bittensor_wallet import Keypair from bt_decode import PortableRegistry, decode as decode_by_type_string, MetadataV15 -from packaging import version from scalecodec import GenericExtrinsic from scalecodec.base import ScaleBytes, ScaleType, RuntimeConfigurationObject from scalecodec.type_registry import load_type_registry_preset @@ -21,6 +19,11 @@ BlockNotFound, ) from substrateinterface.storage import StorageKey +from websockets.asyncio.client import connect +from websockets.exceptions import ConnectionClosed + +if TYPE_CHECKING: + from websockets.asyncio.client import ClientConnection ResultHandler = Callable[[dict, Any], Awaitable[tuple[dict, bool]]] @@ -622,7 +625,7 @@ def __init__( # TODO allow setting max concurrent connections and rpc subscriptions per connection # TODO reconnection logic self.ws_url = ws_url - self.ws: Optional[websockets.WebSocketClientProtocol] = None + self.ws: Optional["ClientConnection"] = None self.id = 0 self.max_subscriptions = max_subscriptions self.max_connections = max_connections @@ -650,7 +653,7 @@ async def __aenter__(self): async def _connect(self): self.ws = await asyncio.wait_for( - websockets.connect(self.ws_url, **self._options), timeout=10 + connect(self.ws_url, **self._options), timeout=10 ) async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -693,9 +696,7 @@ async def shutdown(self): async def _recv(self) -> None: try: - response = json.loads( - await cast(websockets.WebSocketClientProtocol, self.ws).recv() - ) + response = json.loads(await cast(ClientConnection, self.ws).recv()) async with self._lock: self._open_subscriptions -= 1 if "id" in response: @@ -704,7 +705,7 @@ async def _recv(self) -> None: self._received[response["params"]["subscription"]] = response else: raise KeyError(response) - except websockets.ConnectionClosed: + except ConnectionClosed: raise except KeyError as e: raise e @@ -715,7 +716,7 @@ async def _start_receiving(self): await self._recv() except asyncio.CancelledError: pass - except websockets.ConnectionClosed: + except ConnectionClosed: # TODO try reconnect, but only if it's needed raise @@ -732,7 +733,7 @@ async def send(self, payload: dict) -> int: try: await self.ws.send(json.dumps({**payload, **{"id": original_id}})) return original_id - except websockets.ConnectionClosed: + except ConnectionClosed: raise async def retrieve(self, item_id: int) -> Optional[dict]: @@ -773,8 +774,6 @@ def __init__( "max_size": 2**32, "write_limit": 2**16, } - if version.parse(websockets.__version__) < version.parse("14.0"): - options.update({"read_limit": 2**16}) self.ws = Websocket(chain_endpoint, options=options) self._lock = asyncio.Lock() self.last_block_hash: Optional[str] = None diff --git a/bittensor/utils/networking.py b/bittensor/utils/networking.py index 7524b353f5..e0ce8c2bce 100644 --- a/bittensor/utils/networking.py +++ b/bittensor/utils/networking.py @@ -165,7 +165,7 @@ def ensure_connected(func): def is_connected(substrate) -> bool: """Check if the substrate connection is active.""" - sock = substrate.websocket.sock + sock = substrate.websocket.socket return ( sock is not None and sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) == 0 diff --git a/requirements/prod.txt b/requirements/prod.txt index d084b5e37a..b50ae9c2af 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -23,5 +23,5 @@ python-Levenshtein scalecodec==1.2.11 substrate-interface~=1.7.9 uvicorn -websockets>12.0 +websockets>=14.1 bittensor-wallet>=2.1.0 diff --git a/tests/integration_tests/test_subtensor_integration.py b/tests/integration_tests/test_subtensor_integration.py index bacb340f2c..ebe07fdf56 100644 --- a/tests/integration_tests/test_subtensor_integration.py +++ b/tests/integration_tests/test_subtensor_integration.py @@ -75,7 +75,7 @@ def test_network_overrides(self): config1.subtensor.chain_endpoint = None # Mock network calls - with patch("substrateinterface.SubstrateInterface.connect_websocket"): + with patch("websockets.sync.client.connect"): with patch("substrateinterface.SubstrateInterface.reload_type_registry"): print(bittensor.Subtensor, type(bittensor.Subtensor)) # Choose network arg over config diff --git a/tests/unit_tests/test_subtensor.py b/tests/unit_tests/test_subtensor.py index fa7e190dc5..dd31173083 100644 --- a/tests/unit_tests/test_subtensor.py +++ b/tests/unit_tests/test_subtensor.py @@ -1904,7 +1904,7 @@ def test_connect_with_substrate(mocker): """Ensure re-connection is non called when using an alive substrate.""" # Prep fake_substrate = mocker.MagicMock() - fake_substrate.websocket.sock.getsockopt.return_value = 0 + fake_substrate.websocket.socket.getsockopt.return_value = 0 mocker.patch.object( subtensor_module, "SubstrateInterface", return_value=fake_substrate ) @@ -2145,6 +2145,7 @@ def test_networks_during_connection(mocker): """Test networks during_connection.""" # Preps subtensor_module.SubstrateInterface = mocker.Mock() + mocker.patch("websockets.sync.client.connect") # Call for network in list(settings.NETWORK_MAP.keys()) + ["undefined"]: sub = Subtensor(network)