diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 3d5bdfd8..1883facb 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning 2.0 `_. * Asphalt now runs via AnyIO, rather than asyncio, although the asyncio backend is used by default + * The runner now outputs an elaborate tree of component startup tasks if the + application fails to start within the allotted time * Dropped the ``--unsafe`` switch for ``asphalt run`` – configuration files are now always parsed in unsafe mode * Changed configuration parsing to no longer treat dotted keys in configuration diff --git a/src/asphalt/core/_component.py b/src/asphalt/core/_component.py index cdfdf488..8091fd73 100644 --- a/src/asphalt/core/_component.py +++ b/src/asphalt/core/_component.py @@ -6,7 +6,7 @@ from anyio import create_task_group -from ._utils import PluginContainer, merge_config +from ._utils import PluginContainer, merge_config, qualified_name class Component(metaclass=ABCMeta): @@ -105,7 +105,10 @@ async def start(self) -> None: async with create_task_group() as tg: for alias, component in self.child_components.items(): - tg.start_soon(component.start) + tg.start_soon( + component.start, + name=f"Starting {qualified_name(component)} ({alias})", + ) class CLIApplicationComponent(ContainerComponent): diff --git a/src/asphalt/core/_runner.py b/src/asphalt/core/_runner.py index d3d1318c..008ad0bf 100644 --- a/src/asphalt/core/_runner.py +++ b/src/asphalt/core/_runner.py @@ -1,20 +1,30 @@ from __future__ import annotations +import gc import platform +import re import signal import sys +import textwrap +from collections.abc import Coroutine from contextlib import AsyncExitStack +from dataclasses import dataclass, field from functools import partial from logging import INFO, Logger, basicConfig, getLogger from logging.config import dictConfig +from traceback import StackSummary +from types import FrameType from typing import Any, cast from warnings import warn import anyio from anyio import ( + CancelScope, Event, - fail_after, get_cancelled_exc_class, + get_current_task, + get_running_tasks, + sleep, to_thread, ) from anyio.abc import TaskStatus @@ -24,8 +34,12 @@ from ._context import Context from ._utils import qualified_name +component_task_re = re.compile(r"^Starting (\S+) \((.+)\)$") -async def handle_signals(event: Event, *, task_status: TaskStatus[None]) -> None: + +async def handle_signals( + startup_scope: CancelScope, event: Event, *, task_status: TaskStatus[None] +) -> None: logger = getLogger(__name__) with anyio.open_signal_receiver(signal.SIGTERM, signal.SIGINT) as signals: task_status.started() @@ -35,9 +49,112 @@ async def handle_signals(event: Event, *, task_status: TaskStatus[None]) -> None "Received signal (%s) – terminating application", signal_name.split(":", 1)[0], # macOS has ": " after the name ) + startup_scope.cancel() event.set() +def get_coro_stack_summary(coro: Any) -> StackSummary: + frames: list[FrameType] = [] + while isinstance(coro, Coroutine): + while coro.__class__.__name__ == "async_generator_asend": + # Hack to get past asend() objects + coro = gc.get_referents(coro)[0].ag_await + + if frame := getattr(coro, "cr_frame", None): + frames.append(frame) + + coro = getattr(coro, "cr_await", None) + + frame_tuples = [(f, f.f_lineno) for f in frames] + return StackSummary.extract(frame_tuples) + + +async def startup_watcher( + startup_cancel_scope: CancelScope, + root_component: Component, + start_timeout: float, + logger: Logger, + *, + task_status: TaskStatus[CancelScope], +) -> None: + current_task = get_current_task() + parent_task = next( + task_info + for task_info in get_running_tasks() + if task_info.id == current_task.parent_id + ) + + with CancelScope() as cancel_scope: + task_status.started(cancel_scope) + await sleep(start_timeout) + + if cancel_scope.cancel_called: + return + + @dataclass + class ComponentStatus: + name: str + alias: str | None + parent_task_id: int | None + traceback: list[str] = field(init=False, default_factory=list) + children: list[ComponentStatus] = field(init=False, default_factory=list) + + component_statuses: dict[int, ComponentStatus] = {} + for task in get_running_tasks(): + if task.id == parent_task.id: + status = ComponentStatus(qualified_name(root_component), None, None) + elif task.name and (match := component_task_re.match(task.name)): + name: str + alias: str + name, alias = match.groups() + status = ComponentStatus(name, alias, task.parent_id) + else: + continue + + status.traceback = get_coro_stack_summary(task.coro).format() + component_statuses[task.id] = status + + root_status: ComponentStatus + for task_id, component_status in component_statuses.items(): + if component_status.parent_task_id is None: + root_status = component_status + elif parent_status := component_statuses.get(component_status.parent_task_id): + parent_status.children.append(component_status) + if parent_status.alias: + component_status.alias = ( + f"{parent_status.alias}.{component_status.alias}" + ) + + def format_status(status_: ComponentStatus, level: int) -> str: + title = f"{status_.alias or 'root'} ({status_.name})" + if status_.children: + children_output = "" + for i, child in enumerate(status_.children): + prefix = "| " if i < (len(status_.children) - 1) else " " + children_output += "+-" + textwrap.indent( + format_status(child, level + 1), + prefix, + lambda line: line[0] in " +|", + ) + + output = title + "\n" + children_output + else: + formatted_traceback = "".join(status_.traceback) + if level == 0: + formatted_traceback = textwrap.indent(formatted_traceback, "| ") + + output = title + "\n" + formatted_traceback + + return output + + logger.error( + "Timeout waiting for the root component to start\n" + "Components still waiting to finish startup:\n%s", + textwrap.indent(format_status(root_status, 0).rstrip(), " "), + ) + startup_cancel_scope.cancel() + + async def _run_application_async( component: Component, logger: Logger, @@ -54,24 +171,34 @@ async def _run_application_async( event = Event() await exit_stack.enter_async_context(Context()) - if platform.system() != "Windows": - await start_service_task( - partial(handle_signals, event), "Asphalt signal handler" - ) + with CancelScope() as startup_scope: + if platform.system() != "Windows": + await start_service_task( + partial(handle_signals, startup_scope, event), + "Asphalt signal handler", + ) + + try: + if start_timeout is not None: + startup_watcher_scope = await start_service_task( + lambda task_status: startup_watcher( + startup_scope, + component, + start_timeout, + logger, + task_status=task_status, + ), + "Asphalt startup watcher task", + ) - try: - with fail_after(start_timeout): await component.start() - except TimeoutError: - logger.error("Timeout waiting for the root component to start") - raise - except get_cancelled_exc_class(): - logger.error("Application startup interrupted") - return 1 - except BaseException: - logger.exception("Error during application startup") - raise + except get_cancelled_exc_class(): + return 1 + except BaseException: + logger.exception("Error during application startup") + return 1 + startup_watcher_scope.cancel() logger.info("Application started") if isinstance(component, CLIApplicationComponent): diff --git a/tests/test_runner.py b/tests/test_runner.py index ffb334da..11582c70 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -2,6 +2,7 @@ import logging import platform +import re import signal from typing import Any, Literal from unittest.mock import patch @@ -11,11 +12,11 @@ from _pytest.logging import LogCaptureFixture from anyio import to_thread from anyio.lowlevel import checkpoint -from common import raises_in_exception_group from asphalt.core import ( CLIApplicationComponent, Component, + ContainerComponent, add_teardown_callback, get_resource, run_application, @@ -204,11 +205,10 @@ def test_start_exception(caplog: LogCaptureFixture, anyio_backend_name: str) -> """ caplog.set_level(logging.INFO) component = CrashComponent(method="exception") - with raises_in_exception_group( - RuntimeError, match="this should crash the application" - ): + with pytest.raises(SystemExit) as exc_info: run_application(component, backend=anyio_backend_name) + assert exc_info.value.code == 1 records = [ record for record in caplog.records if record.name.startswith("asphalt.core.") ] @@ -219,33 +219,78 @@ def test_start_exception(caplog: LogCaptureFixture, anyio_backend_name: str) -> assert records[3].message == "Application stopped" -def test_start_timeout(caplog: LogCaptureFixture, anyio_backend_name: str) -> None: - """ - Test that when the root component takes too long to start up, the runner exits and - logs the appropriate error message. - """ +@pytest.mark.parametrize("levels", [1, 2, 3]) +def test_start_timeout( + caplog: LogCaptureFixture, anyio_backend_name: str, levels: int +) -> None: + class StallingComponent(ContainerComponent): + def __init__(self, level: int): + super().__init__() + self.level = level - class StallingComponent(Component): async def start(self) -> None: - # Wait forever for a non-existent resource - await get_resource(float, wait=True) + if self.level == levels: + # Wait forever for a non-existent resource + await get_resource(float, wait=True) + else: + self.add_component("child1", StallingComponent, level=self.level + 1) + self.add_component("child2", StallingComponent, level=self.level + 1) + + await super().start() caplog.set_level(logging.INFO) - component = StallingComponent() - with raises_in_exception_group(TimeoutError): + component = StallingComponent(1) + with pytest.raises(SystemExit) as exc_info: run_application(component, start_timeout=0.1, backend=anyio_backend_name) - records = [ - record for record in caplog.records if record.name == "asphalt.core._runner" - ] - assert len(records) == 4 - assert records[0].message == "Running in development mode" - assert records[1].message == "Starting application" - assert records[2].message.startswith( - "Timeout waiting for the root component to start" + assert exc_info.value.code == 1 + assert len(caplog.messages) == 4 + assert caplog.messages[0] == "Running in development mode" + assert caplog.messages[1] == "Starting application" + assert caplog.messages[2].startswith( + "Timeout waiting for the root component to start\n" + "Components still waiting to finish startup:\n" ) - # assert "-> await ctx.get_resource(float)" in records[2].message - assert records[3].message == "Application stopped" + assert caplog.messages[3] == "Application stopped" + + child_component_re = re.compile(r"([ |]+)\+-([a-z.12]+) \((.+)\)") + lines = caplog.messages[2].splitlines() + expected_test_name = f"{__name__}.test_start_timeout" + assert lines[2] == f" root ({expected_test_name}..StallingComponent)" + component_aliases: set[str] = set() + depths: list[int] = [0] * levels + expected_indent = " | " + for line in lines[3:]: + if match := child_component_re.match(line): + indent, alias, component_name = match.groups() + depth = len(alias.split(".")) + depths[depth - 1] += 1 + depths[depth:] = [0] * (len(depths) - depth) + assert len(depths) == levels + assert all(d < 3 for d in depths) + expected_indent = " " + "".join( + (" " if d > 1 else "| ") for d in depths[:depth] + ) + assert component_name == ( + f"{expected_test_name}..StallingComponent" + ) + component_aliases.add(alias) + else: + assert line.startswith(expected_indent) + + if levels == 1: + assert not component_aliases + elif levels == 2: + assert component_aliases == {"child1", "child2"} + else: + assert component_aliases == { + "child1", + "child2", + "child1.child1", + "child1.child2", + "child2.child1", + "child2.child2", + } def test_dict_config(caplog: LogCaptureFixture, anyio_backend_name: str) -> None: