Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bittensor/axon.py: thread and exception handling #2227

Open
wants to merge 9 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 94 additions & 7 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import inspect
import json
import os
import socket
import threading
import time
import traceback
Expand Down Expand Up @@ -62,6 +63,12 @@
from bittensor.utils import networking


"""
The quantum of time to sleep in waiting loops, in seconds.
"""
TIME_SLEEP_INTERVAL: float = 1e-3


class FastAPIThreadedServer(uvicorn.Server):
"""
The ``FastAPIThreadedServer`` class is a specialized server implementation for the Axon server in the Bittensor network.
Expand Down Expand Up @@ -100,26 +107,80 @@ class FastAPIThreadedServer(uvicorn.Server):
should_exit: bool = False
is_running: bool = False

"""
Provide a channel to signal exceptions from the thread to our caller.
"""
_exception: Optional[Exception] = None
_lock: threading.Lock = threading.Lock()
_thread: Optional[threading.Thread] = None
_started: bool = False

def set_exception(self, exception: Exception) -> None:
"""
Set self._exception in a thread safe manner, so the worker thread can communicate exceptions to the main thread.
"""
with self._lock:
self._exception = exception

def get_exception(self) -> Optional[Exception]:
with self._lock:
return self._exception

def set_thread(self, thread: threading.Thread):
"""
Set self._thread in a thread safe manner, so the main thread can get the worker thread object.
"""
with self._lock:
self._thread = thread

def get_thread(self) -> Optional[threading.Thread]:
with self._lock:
return self._thread

def set_started(self, started: bool) -> None:
"""
Set self._started in a thread safe manner, so the main thread can get the worker thread status.
"""
with self._lock:
self._started = started

def get_started(self) -> bool:
with self._lock:
return self._started

def install_signal_handlers(self):
"""
Overrides the default signal handlers provided by ``uvicorn.Server``. This method is essential to ensure that the signal handling in the threaded server does not interfere with the main application's flow, especially in a complex asynchronous environment like the Axon server.
"""
pass

async def startup(self, sockets: Optional[List[socket.socket]] = None) -> None:
"""
Adds a thread-safe call to set a 'started' flag on the object.
"""
await super().startup(sockets)
self.set_started(True)

@contextlib.contextmanager
def run_in_thread(self):
"""
Manages the execution of the server in a separate thread, allowing the FastAPI application to run asynchronously without blocking the main thread of the Axon server. This method is a key component in enabling concurrent request handling in the Axon server.

Yields:
None: This method yields control back to the caller while the server is running in the background thread.
thread: a running thread

Raises:
Exception: in case the server did not start (as signalled by self.get_started())
"""
thread = threading.Thread(target=self.run, daemon=True)
thread.start()
try:
while not self.started:
time.sleep(1e-3)
yield
time_start = time.time()
while not self.get_started() and time.time() - time_start < 1:
time.sleep(TIME_SLEEP_INTERVAL)
if not self.get_started():
raise Exception("failed to start server")
yield thread
finally:
self.should_exit = True
thread.join()
Expand All @@ -128,9 +189,15 @@ def _wrapper_run(self):
"""
A wrapper method for the :func:`run_in_thread` context manager. This method is used internally by the ``start`` method to initiate the server's execution in a separate thread.
"""
with self.run_in_thread():
while not self.should_exit:
time.sleep(1e-3)
try:
with self.run_in_thread() as thread:
self.set_thread(thread)
while not self.should_exit:
if not thread.is_alive():
raise Exception("worker thread died")
time.sleep(TIME_SLEEP_INTERVAL)
except Exception as e:
self.set_exception(e)

def start(self):
"""
Expand Down Expand Up @@ -405,6 +472,26 @@ def info(self) -> "bittensor.AxonInfo":
placeholder2=0,
)

@property
def exception(self) -> Optional[Exception]:
"""
Axon objects expose exceptions that occurred internally through the .exception property.
"""
# for future use: setting self._exception to signal an exception
exception = getattr(self, "_exception", None)
if exception:
return exception
return self.fast_server.get_exception()

def is_running(self) -> bool:
"""
Axon objects can be queried using .is_running() to test whether worker threads are running.
"""
thread = self.fast_server.get_thread()
if thread is None:
return False
return thread.is_alive()

def attach(
self,
forward_fn: Callable,
Expand Down
File renamed without changes.
59 changes: 39 additions & 20 deletions bittensor/commands/stake.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,25 @@


def get_netuid(
cli: "bittensor.cli", subtensor: "bittensor.subtensor"
cli: "bittensor.cli", subtensor: "bittensor.subtensor", prompt: bool = True
) -> Tuple[bool, int]:
"""Retrieve and validate the netuid from the user or configuration."""
console = Console()
if not cli.config.is_set("netuid"):
try:
cli.config.netuid = int(Prompt.ask("Enter netuid"))
except ValueError:
console.print(
"[red]Invalid input. Please enter a valid integer for netuid.[/red]"
)
return False, -1
if not cli.config.is_set("netuid") and prompt:
cli.config.netuid = Prompt.ask("Enter netuid")
try:
cli.config.netuid = int(cli.config.netuid)
except ValueError:
console.print(
"[red]Invalid input. Please enter a valid integer for netuid.[/red]"
)
return False, -1
netuid = cli.config.netuid
if netuid < 0 or netuid > 65535:
console.print(
"[red]Invalid input. Please enter a valid integer for netuid in subnet range.[/red]"
)
return False, -1
if not subtensor.subnet_exists(netuid=netuid):
console.print(
"[red]Network with netuid {} does not exist. Please try again.[/red]".format(
Expand Down Expand Up @@ -1136,10 +1142,27 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
wallet = bittensor.wallet(config=cli.config)

# check all
if not cli.config.is_set("all"):
exists, netuid = get_netuid(cli, subtensor)
if not exists:
return
if cli.config.is_set("all"):
cli.config.netuid = None
cli.config.all = True
elif cli.config.is_set("netuid"):
if cli.config.netuid == "all":
cli.config.all = True
else:
cli.config.netuid = int(cli.config.netuid)
exists, netuid = get_netuid(cli, subtensor)
if not exists:
return
else:
netuid_input = Prompt.ask("Enter netuid or 'all'", default="all")
if netuid_input == "all":
cli.config.netuid = None
cli.config.all = True
else:
cli.config.netuid = int(netuid_input)
exists, netuid = get_netuid(cli, subtensor, False)
if not exists:
return

# get parent hotkey
hotkey = get_hotkey(wallet, cli.config)
Expand All @@ -1148,11 +1171,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
return

try:
netuids = (
subtensor.get_all_subnet_netuids()
if cli.config.is_set("all")
else [netuid]
)
netuids = subtensor.get_all_subnet_netuids() if cli.config.all else [netuid]
hotkey_stake = GetChildrenCommand.get_parent_stake_info(
console, subtensor, hotkey
)
Expand Down Expand Up @@ -1236,7 +1255,7 @@ def add_args(parser: argparse.ArgumentParser):
parser = parser.add_parser(
"get_children", help="""Get child hotkeys on subnet."""
)
parser.add_argument("--netuid", dest="netuid", type=int, required=False)
parser.add_argument("--netuid", dest="netuid", type=str, required=False)
parser.add_argument("--hotkey", dest="hotkey", type=str, required=False)
parser.add_argument(
"--all",
Expand Down Expand Up @@ -1294,7 +1313,7 @@ def render_table(

# Add columns to the table with specific styles
table.add_column("Index", style="bold yellow", no_wrap=True, justify="center")
table.add_column("ChildHotkey", style="bold green")
table.add_column("Child Hotkey", style="bold green")
table.add_column("Proportion", style="bold cyan", no_wrap=True, justify="right")
table.add_column(
"Childkey Take", style="bold blue", no_wrap=True, justify="right"
Expand Down
2 changes: 1 addition & 1 deletion bittensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
unstake_extrinsic,
unstake_multiple_extrinsic,
)
from .types import AxonServeCallParams, PrometheusServeCallParams
from .bt_types import AxonServeCallParams, PrometheusServeCallParams
from .utils import (
U16_NORMALIZED_FLOAT,
ss58_to_vec_u8,
Expand Down
Loading