Skip to content

Commit

Permalink
Merge branch 'staging' into feat/roman/add_unstake_extrinsic
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-opentensor authored Nov 21, 2024
2 parents 5c72efe + 1164fb5 commit 7f373a3
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 24 deletions.
18 changes: 11 additions & 7 deletions bittensor/core/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import argparse
import copy
import socket
import ssl
from typing import Union, Optional, TypedDict, Any

Expand All @@ -18,6 +17,7 @@
from scalecodec.type_registry import load_type_registry_preset
from scalecodec.types import ScaleType
from substrateinterface.base import QueryMapResult, SubstrateInterface
from websockets.sync import client as ws_client

from bittensor.core import settings
from bittensor.core.axon import Axon
Expand Down Expand Up @@ -140,6 +140,7 @@ def __init__(
_mock: bool = False,
log_verbose: bool = False,
connection_timeout: int = 600,
websocket: Optional[ws_client.ClientConnection] = None,
) -> None:
"""
Initializes a Subtensor interface for interacting with the Bittensor blockchain.
Expand All @@ -155,6 +156,7 @@ def __init__(
_mock (bool): If set to ``True``, uses a mocked connection for testing purposes. Default is ``False``.
log_verbose (bool): Whether to enable verbose logging. If set to ``True``, detailed log information about the connection and network operations will be provided. Default is ``True``.
connection_timeout (int): The maximum time in seconds to keep the connection alive. Default is ``600``.
websocket (websockets.sync.client.ClientConnection): websockets sync (threading) client object connected to the network.
This initialization sets up the connection to the specified Bittensor network, allowing for various blockchain operations such as neuron registration, stake management, and setting weights.
"""
Expand Down Expand Up @@ -191,6 +193,7 @@ def __init__(
self.log_verbose = log_verbose
self._connection_timeout = connection_timeout
self.substrate: "SubstrateInterface" = None
self.websocket = websocket
self._get_substrate()

def __str__(self) -> str:
Expand All @@ -213,22 +216,23 @@ def _get_substrate(self):
"""Establishes a connection to the Substrate node using configured parameters."""
try:
# Set up params.
if not self.websocket:
self.websocket = ws_client.connect(
self.chain_endpoint,
open_timeout=self._connection_timeout,
max_size=2**32,
)
self.substrate = SubstrateInterface(
ss58_format=settings.SS58_FORMAT,
use_remote_preset=True,
url=self.chain_endpoint,
type_registry=settings.TYPE_REGISTRY,
websocket=self.websocket,
)
if self.log_verbose:
logging.debug(
f"Connected to {self.network} network and {self.chain_endpoint}."
)

try:
self.substrate.websocket.settimeout(self._connection_timeout)
except (AttributeError, TypeError, socket.error, OSError) as e:
logging.warning(f"Error setting timeout: {e}")

except (ConnectionRefusedError, ssl.SSLError) as error:
logging.error(
f"<red>Could not connect to</red> <blue>{self.network}</blue> <red>network with</red> <blue>{self.chain_endpoint}</blue> <red>chain endpoint.</red>",
Expand Down
25 changes: 12 additions & 13 deletions bittensor/utils/async_substrate_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from collections import defaultdict
from dataclasses import dataclass
from hashlib import blake2b
from typing import Optional, Any, Union, Callable, Awaitable, cast
from typing import Optional, Any, Union, Callable, Awaitable, cast, TYPE_CHECKING

import websockets
from async_property import async_property
from bittensor_wallet import Keypair
from bt_decode import PortableRegistry, decode as decode_by_type_string, MetadataV15
from packaging import version
from scalecodec import GenericExtrinsic
from scalecodec.base import ScaleBytes, ScaleType, RuntimeConfigurationObject
from scalecodec.type_registry import load_type_registry_preset
Expand All @@ -21,6 +19,11 @@
BlockNotFound,
)
from substrateinterface.storage import StorageKey
from websockets.asyncio.client import connect
from websockets.exceptions import ConnectionClosed

if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection

ResultHandler = Callable[[dict, Any], Awaitable[tuple[dict, bool]]]

Expand Down Expand Up @@ -622,7 +625,7 @@ def __init__(
# TODO allow setting max concurrent connections and rpc subscriptions per connection
# TODO reconnection logic
self.ws_url = ws_url
self.ws: Optional[websockets.WebSocketClientProtocol] = None
self.ws: Optional["ClientConnection"] = None
self.id = 0
self.max_subscriptions = max_subscriptions
self.max_connections = max_connections
Expand Down Expand Up @@ -650,7 +653,7 @@ async def __aenter__(self):

async def _connect(self):
self.ws = await asyncio.wait_for(
websockets.connect(self.ws_url, **self._options), timeout=10
connect(self.ws_url, **self._options), timeout=10
)

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down Expand Up @@ -693,9 +696,7 @@ async def shutdown(self):

async def _recv(self) -> None:
try:
response = json.loads(
await cast(websockets.WebSocketClientProtocol, self.ws).recv()
)
response = json.loads(await cast(ClientConnection, self.ws).recv())
async with self._lock:
self._open_subscriptions -= 1
if "id" in response:
Expand All @@ -704,7 +705,7 @@ async def _recv(self) -> None:
self._received[response["params"]["subscription"]] = response
else:
raise KeyError(response)
except websockets.ConnectionClosed:
except ConnectionClosed:
raise
except KeyError as e:
raise e
Expand All @@ -715,7 +716,7 @@ async def _start_receiving(self):
await self._recv()
except asyncio.CancelledError:
pass
except websockets.ConnectionClosed:
except ConnectionClosed:
# TODO try reconnect, but only if it's needed
raise

Expand All @@ -732,7 +733,7 @@ async def send(self, payload: dict) -> int:
try:
await self.ws.send(json.dumps({**payload, **{"id": original_id}}))
return original_id
except websockets.ConnectionClosed:
except ConnectionClosed:
raise

async def retrieve(self, item_id: int) -> Optional[dict]:
Expand Down Expand Up @@ -773,8 +774,6 @@ def __init__(
"max_size": 2**32,
"write_limit": 2**16,
}
if version.parse(websockets.__version__) < version.parse("14.0"):
options.update({"read_limit": 2**16})
self.ws = Websocket(chain_endpoint, options=options)
self._lock = asyncio.Lock()
self.last_block_hash: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion bittensor/utils/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def ensure_connected(func):

def is_connected(substrate) -> bool:
"""Check if the substrate connection is active."""
sock = substrate.websocket.sock
sock = substrate.websocket.socket
return (
sock is not None
and sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) == 0
Expand Down
2 changes: 1 addition & 1 deletion requirements/prod.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ python-Levenshtein
scalecodec==1.2.11
substrate-interface~=1.7.9
uvicorn
websockets>12.0
websockets>=14.1
bittensor-wallet>=2.1.0
2 changes: 1 addition & 1 deletion tests/integration_tests/test_subtensor_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_network_overrides(self):
config1.subtensor.chain_endpoint = None

# Mock network calls
with patch("substrateinterface.SubstrateInterface.connect_websocket"):
with patch("websockets.sync.client.connect"):
with patch("substrateinterface.SubstrateInterface.reload_type_registry"):
print(bittensor.Subtensor, type(bittensor.Subtensor))
# Choose network arg over config
Expand Down
3 changes: 2 additions & 1 deletion tests/unit_tests/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,7 +1904,7 @@ def test_connect_with_substrate(mocker):
"""Ensure re-connection is non called when using an alive substrate."""
# Prep
fake_substrate = mocker.MagicMock()
fake_substrate.websocket.sock.getsockopt.return_value = 0
fake_substrate.websocket.socket.getsockopt.return_value = 0
mocker.patch.object(
subtensor_module, "SubstrateInterface", return_value=fake_substrate
)
Expand Down Expand Up @@ -2145,6 +2145,7 @@ def test_networks_during_connection(mocker):
"""Test networks during_connection."""
# Preps
subtensor_module.SubstrateInterface = mocker.Mock()
mocker.patch("websockets.sync.client.connect")
# Call
for network in list(settings.NETWORK_MAP.keys()) + ["undefined"]:
sub = Subtensor(network)
Expand Down

0 comments on commit 7f373a3

Please sign in to comment.