diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 3d5bdfd8..1883facb 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning 2.0 `_.
* Asphalt now runs via AnyIO, rather than asyncio, although the asyncio backend is
used by default
+ * The runner now outputs an elaborate tree of component startup tasks if the
+ application fails to start within the allotted time
* Dropped the ``--unsafe`` switch for ``asphalt run`` – configuration files are now
always parsed in unsafe mode
* Changed configuration parsing to no longer treat dotted keys in configuration
diff --git a/src/asphalt/core/_component.py b/src/asphalt/core/_component.py
index cdfdf488..8091fd73 100644
--- a/src/asphalt/core/_component.py
+++ b/src/asphalt/core/_component.py
@@ -6,7 +6,7 @@
from anyio import create_task_group
-from ._utils import PluginContainer, merge_config
+from ._utils import PluginContainer, merge_config, qualified_name
class Component(metaclass=ABCMeta):
@@ -105,7 +105,10 @@ async def start(self) -> None:
async with create_task_group() as tg:
for alias, component in self.child_components.items():
- tg.start_soon(component.start)
+ tg.start_soon(
+ component.start,
+ name=f"Starting {qualified_name(component)} ({alias})",
+ )
class CLIApplicationComponent(ContainerComponent):
diff --git a/src/asphalt/core/_runner.py b/src/asphalt/core/_runner.py
index d3d1318c..008ad0bf 100644
--- a/src/asphalt/core/_runner.py
+++ b/src/asphalt/core/_runner.py
@@ -1,20 +1,30 @@
from __future__ import annotations
+import gc
import platform
+import re
import signal
import sys
+import textwrap
+from collections.abc import Coroutine
from contextlib import AsyncExitStack
+from dataclasses import dataclass, field
from functools import partial
from logging import INFO, Logger, basicConfig, getLogger
from logging.config import dictConfig
+from traceback import StackSummary
+from types import FrameType
from typing import Any, cast
from warnings import warn
import anyio
from anyio import (
+ CancelScope,
Event,
- fail_after,
get_cancelled_exc_class,
+ get_current_task,
+ get_running_tasks,
+ sleep,
to_thread,
)
from anyio.abc import TaskStatus
@@ -24,8 +34,12 @@
from ._context import Context
from ._utils import qualified_name
+# Matches the task names given to child component startup tasks
+# ("Starting <qualified name> (<alias>)"); group 1 is the qualified name and
+# group 2 is the alias
+component_task_re = re.compile(r"^Starting (\S+) \((.+)\)$")
-async def handle_signals(event: Event, *, task_status: TaskStatus[None]) -> None:
+
+async def handle_signals(
+ startup_scope: CancelScope, event: Event, *, task_status: TaskStatus[None]
+) -> None:
logger = getLogger(__name__)
with anyio.open_signal_receiver(signal.SIGTERM, signal.SIGINT) as signals:
task_status.started()
@@ -35,9 +49,112 @@ async def handle_signals(event: Event, *, task_status: TaskStatus[None]) -> None
"Received signal (%s) – terminating application",
signal_name.split(":", 1)[0], # macOS has ": " after the name
)
+ startup_scope.cancel()
event.set()
+def get_coro_stack_summary(coro: Any) -> StackSummary:
+ """
+ Build a stack summary from the chain of coroutines that ``coro`` is awaiting.
+
+ Starting from ``coro``, the ``cr_await`` attribute is followed for as long as the
+ current object is a coroutine, and the ``cr_frame`` of each object along the way
+ (when it has one) is collected in that order.
+
+ :param coro: the outermost awaitable to inspect
+ :return: a stack summary extracted from the collected frames and their current
+     line numbers
+ """
+ frames: list[FrameType] = []
+ while isinstance(coro, Coroutine):
+ while coro.__class__.__name__ == "async_generator_asend":
+ # Hack to get past asend() objects
+ # NOTE(review): presumably the first referent of the asend() object is the
+ # async generator it drives, hence the use of ag_await -- confirm
+ coro = gc.get_referents(coro)[0].ag_await
+
+ # getattr() with a default is used because coro may no longer be a regular
+ # coroutine (or may be None) after the unwrapping above
+ if frame := getattr(coro, "cr_frame", None):
+ frames.append(frame)
+
+ coro = getattr(coro, "cr_await", None)
+
+ # StackSummary.extract() takes (frame, line number) pairs
+ frame_tuples = [(f, f.f_lineno) for f in frames]
+ return StackSummary.extract(frame_tuples)
+
+
+async def startup_watcher(
+ startup_cancel_scope: CancelScope,
+ root_component: Component,
+ start_timeout: float,
+ logger: Logger,
+ *,
+ task_status: TaskStatus[CancelScope],
+) -> None:
+ """
+ Report the components still starting up and cancel startup on timeout.
+
+ A new cancel scope is handed to the caller through ``task_status``; the watcher
+ then sleeps for ``start_timeout`` seconds inside that scope. If the scope was
+ cancelled in the meantime, the watcher returns without doing anything else.
+ Otherwise it logs an error containing a tree of the component startup tasks
+ that are still running and cancels ``startup_cancel_scope``.
+
+ :param startup_cancel_scope: the cancel scope to cancel when the timeout expires
+ :param root_component: the root component (only used for its qualified name)
+ :param start_timeout: seconds to wait before reporting and cancelling
+ :param logger: the logger used for the timeout report
+ :param task_status: receives the watcher's own cancel scope once it is running
+ """
+ # The task that spawned this watcher is treated as the root component's task
+ current_task = get_current_task()
+ parent_task = next(
+ task_info
+ for task_info in get_running_tasks()
+ if task_info.id == current_task.parent_id
+ )
+
+ with CancelScope() as cancel_scope:
+ task_status.started(cancel_scope)
+ await sleep(start_timeout)
+
+ # The sleep was interrupted by a cancellation of our scope, so there is nothing
+ # to report
+ if cancel_scope.cancel_called:
+ return
+
+ @dataclass
+ class ComponentStatus:
+ # Qualified name of the component
+ name: str
+ # Alias of the component (None for the root); extended to a dotted path below
+ alias: str | None
+ # ID of the task that spawned this component's task (None for the root)
+ parent_task_id: int | None
+ # Formatted stack summary lines of the component's startup task
+ traceback: list[str] = field(init=False, default_factory=list)
+ # Statuses of the child components whose tasks were spawned by this one
+ children: list[ComponentStatus] = field(init=False, default_factory=list)
+
+ # Collect a status for the root task and for every task whose name matches the
+ # component startup task naming pattern, keyed by task ID
+ component_statuses: dict[int, ComponentStatus] = {}
+ for task in get_running_tasks():
+ if task.id == parent_task.id:
+ status = ComponentStatus(qualified_name(root_component), None, None)
+ elif task.name and (match := component_task_re.match(task.name)):
+ name: str
+ alias: str
+ name, alias = match.groups()
+ status = ComponentStatus(name, alias, task.parent_id)
+ else:
+ continue
+
+ status.traceback = get_coro_stack_summary(task.coro).format()
+ component_statuses[task.id] = status
+
+ # Link the statuses into a tree through the parent task IDs
+ # NOTE(review): the dotted alias is built from the parent's alias as it is at the
+ # time the child is visited, so this relies on parents being listed before their
+ # children -- confirm the ordering of get_running_tasks()
+ root_status: ComponentStatus
+ for task_id, component_status in component_statuses.items():
+ if component_status.parent_task_id is None:
+ root_status = component_status
+ elif parent_status := component_statuses.get(component_status.parent_task_id):
+ parent_status.children.append(component_status)
+ if parent_status.alias:
+ component_status.alias = (
+ f"{parent_status.alias}.{component_status.alias}"
+ )
+
+ def format_status(status_: ComponentStatus, level: int) -> str:
+ # Render a status as its title followed by either its children (recursively)
+ # or, for a component without children, its own traceback
+ title = f"{status_.alias or 'root'} ({status_.name})"
+ if status_.children:
+ children_output = ""
+ for i, child in enumerate(status_.children):
+ # Every child except the last one gets a vertical bar prefix
+ prefix = "| " if i < (len(status_.children) - 1) else " "
+ children_output += "+-" + textwrap.indent(
+ format_status(child, level + 1),
+ prefix,
+ # Only prefix lines that start with a space, "+" or "|", which
+ # leaves the child's own title line (joined to "+-") untouched
+ lambda line: line[0] in " +|",
+ )
+
+ output = title + "\n" + children_output
+ else:
+ formatted_traceback = "".join(status_.traceback)
+ if level == 0:
+ formatted_traceback = textwrap.indent(formatted_traceback, "| ")
+
+ output = title + "\n" + formatted_traceback
+
+ return output
+
+ logger.error(
+ "Timeout waiting for the root component to start\n"
+ "Components still waiting to finish startup:\n%s",
+ textwrap.indent(format_status(root_status, 0).rstrip(), " "),
+ )
+ startup_cancel_scope.cancel()
+
+
async def _run_application_async(
component: Component,
logger: Logger,
@@ -54,24 +171,34 @@ async def _run_application_async(
event = Event()
await exit_stack.enter_async_context(Context())
- if platform.system() != "Windows":
- await start_service_task(
- partial(handle_signals, event), "Asphalt signal handler"
- )
+ with CancelScope() as startup_scope:
+ if platform.system() != "Windows":
+ await start_service_task(
+ partial(handle_signals, startup_scope, event),
+ "Asphalt signal handler",
+ )
+
+ try:
+ if start_timeout is not None:
+ startup_watcher_scope = await start_service_task(
+ lambda task_status: startup_watcher(
+ startup_scope,
+ component,
+ start_timeout,
+ logger,
+ task_status=task_status,
+ ),
+ "Asphalt startup watcher task",
+ )
- try:
- with fail_after(start_timeout):
await component.start()
- except TimeoutError:
- logger.error("Timeout waiting for the root component to start")
- raise
- except get_cancelled_exc_class():
- logger.error("Application startup interrupted")
- return 1
- except BaseException:
- logger.exception("Error during application startup")
- raise
+ except get_cancelled_exc_class():
+ return 1
+ except BaseException:
+ logger.exception("Error during application startup")
+ return 1
+ # The startup watcher is only launched (and its scope only bound) when a
+ # start timeout was given; without this guard, a successful startup with
+ # start_timeout=None would raise UnboundLocalError here
+ if start_timeout is not None:
+     startup_watcher_scope.cancel()
+
logger.info("Application started")
if isinstance(component, CLIApplicationComponent):
diff --git a/tests/test_runner.py b/tests/test_runner.py
index ffb334da..11582c70 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -2,6 +2,7 @@
import logging
import platform
+import re
import signal
from typing import Any, Literal
from unittest.mock import patch
@@ -11,11 +12,11 @@
from _pytest.logging import LogCaptureFixture
from anyio import to_thread
from anyio.lowlevel import checkpoint
-from common import raises_in_exception_group
from asphalt.core import (
CLIApplicationComponent,
Component,
+ ContainerComponent,
add_teardown_callback,
get_resource,
run_application,
@@ -204,11 +205,10 @@ def test_start_exception(caplog: LogCaptureFixture, anyio_backend_name: str) ->
"""
caplog.set_level(logging.INFO)
component = CrashComponent(method="exception")
- with raises_in_exception_group(
- RuntimeError, match="this should crash the application"
- ):
+ with pytest.raises(SystemExit) as exc_info:
run_application(component, backend=anyio_backend_name)
+ assert exc_info.value.code == 1
records = [
record for record in caplog.records if record.name.startswith("asphalt.core.")
]
@@ -219,33 +219,78 @@ def test_start_exception(caplog: LogCaptureFixture, anyio_backend_name: str) ->
assert records[3].message == "Application stopped"
-def test_start_timeout(caplog: LogCaptureFixture, anyio_backend_name: str) -> None:
- """
- Test that when the root component takes too long to start up, the runner exits and
- logs the appropriate error message.
- """
+@pytest.mark.parametrize("levels", [1, 2, 3])
+def test_start_timeout(
+ caplog: LogCaptureFixture, anyio_backend_name: str, levels: int
+) -> None:
+ """
+ Test that when component startup takes too long, the runner exits with code 1
+ and logs a tree of the components that were still starting.
+
+ ``levels`` is the depth of the component tree: components above the deepest
+ level each add two children, and the components at the deepest level wait for a
+ resource.
+ """
+ class StallingComponent(ContainerComponent):
+ def __init__(self, level: int):
+ super().__init__()
+ self.level = level
- class StallingComponent(Component):
async def start(self) -> None:
- # Wait forever for a non-existent resource
- await get_resource(float, wait=True)
+ if self.level == levels:
+ # Wait forever for a non-existent resource
+ await get_resource(float, wait=True)
+ else:
+ self.add_component("child1", StallingComponent, level=self.level + 1)
+ self.add_component("child2", StallingComponent, level=self.level + 1)
+
+ await super().start()
caplog.set_level(logging.INFO)
- component = StallingComponent()
- with raises_in_exception_group(TimeoutError):
+ component = StallingComponent(1)
+ with pytest.raises(SystemExit) as exc_info:
run_application(component, start_timeout=0.1, backend=anyio_backend_name)
- records = [
- record for record in caplog.records if record.name == "asphalt.core._runner"
- ]
- assert len(records) == 4
- assert records[0].message == "Running in development mode"
- assert records[1].message == "Starting application"
- assert records[2].message.startswith(
- "Timeout waiting for the root component to start"
+ assert exc_info.value.code == 1
+ assert len(caplog.messages) == 4
+ assert caplog.messages[0] == "Running in development mode"
+ assert caplog.messages[1] == "Starting application"
+ assert caplog.messages[2].startswith(
+ "Timeout waiting for the root component to start\n"
+ "Components still waiting to finish startup:\n"
)
- # assert "-> await ctx.get_resource(float)" in records[2].message
- assert records[3].message == "Application stopped"
+ assert caplog.messages[3] == "Application stopped"
+
+ # Check the logged component tree line by line; the first two lines of the
+ # message are the fixed header asserted above
+ child_component_re = re.compile(r"([ |]+)\+-([a-z.12]+) \((.+)\)")
+ lines = caplog.messages[2].splitlines()
+ expected_test_name = f"{__name__}.test_start_timeout"
+ assert lines[2] == f" root ({expected_test_name}..StallingComponent)"
+ component_aliases: set[str] = set()
+ # depths[i] counts the entries seen so far at nesting level i + 1 below the
+ # entry currently being followed
+ depths: list[int] = [0] * levels
+ expected_indent = " | "
+ for line in lines[3:]:
+ if match := child_component_re.match(line):
+ indent, alias, component_name = match.groups()
+ # The nesting depth is the number of dotted parts in the alias
+ depth = len(alias.split("."))
+ depths[depth - 1] += 1
+ # A new entry at this depth starts over the counts of all deeper levels
+ depths[depth:] = [0] * (len(depths) - depth)
+ assert len(depths) == levels
+ # Every component has at most two children
+ assert all(d < 3 for d in depths)
+ # Lines that follow this entry must carry one prefix segment per level,
+ # chosen by whether the entry at that level was the first or a later child
+ expected_indent = " " + "".join(
+ (" " if d > 1 else "| ") for d in depths[:depth]
+ )
+ assert component_name == (
+ f"{expected_test_name}..StallingComponent"
+ )
+ component_aliases.add(alias)
+ else:
+ # Any other line must be indented to match the most recent entry
+ assert line.startswith(expected_indent)
+
+ if levels == 1:
+ assert not component_aliases
+ elif levels == 2:
+ assert component_aliases == {"child1", "child2"}
+ else:
+ assert component_aliases == {
+ "child1",
+ "child2",
+ "child1.child1",
+ "child1.child2",
+ "child2.child1",
+ "child2.child2",
+ }
def test_dict_config(caplog: LogCaptureFixture, anyio_backend_name: str) -> None: