diff --git a/docs/tutorials/snippets/echo1.py b/docs/tutorials/snippets/echo1.py index be2aea31..d0ef197d 100644 --- a/docs/tutorials/snippets/echo1.py +++ b/docs/tutorials/snippets/echo1.py @@ -1,11 +1,11 @@ # isort: off from __future__ import annotations -from asphalt.core import Component, run_application, ComponentContext +from asphalt.core import Component, run_application class ServerComponent(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: print("Hello, world!") diff --git a/docs/userguide/contexts.rst b/docs/userguide/contexts.rst index 3084e9f7..abdf3d31 100644 --- a/docs/userguide/contexts.rst +++ b/docs/userguide/contexts.rst @@ -171,7 +171,7 @@ For example:: class FooComponent(Component): async def start(): service = SomeService() - await service.start(ctx) + await service.start() add_teardown_callback(service.stop) add_resource(service) diff --git a/docs/userguide/snippets/components1.py b/docs/userguide/snippets/components1.py index 7259233b..129694df 100644 --- a/docs/userguide/snippets/components1.py +++ b/docs/userguide/snippets/components1.py @@ -1,6 +1,5 @@ from asphalt.core import ( Component, - ComponentContext, add_resource, get_resource, get_resource_nowait, @@ -13,11 +12,11 @@ def __init__(self) -> None: self.add_component("child1", ChildComponent, name="child1") self.add_component("child2", ChildComponent, name="child2") - async def prepare(self, ctx: ComponentContext) -> None: + async def prepare(self) -> None: print("ParentComponent.prepare()") add_resource("Hello") # adds a `str` type resource by the name `default` - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: print("ParentComponent.start()") print(get_resource_nowait(str, "child1_resource")) print(get_resource_nowait(str, "child2_resource")) @@ -30,11 +29,11 @@ class ChildComponent(Component): def __init__(self, name: str) -> None: self.name = name - async def prepare(self, ctx: ComponentContext) -> None: + async def prepare(self) -> None: self.parent_resource = get_resource_nowait(str) print(f"ChildComponent.prepare() [{self.name}]") - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: print(f"ChildComponent.start() [{self.name}]") # Add a `str` type resource, with a name like `childX_resource` diff --git a/examples/tutorial1/echo/server.py b/examples/tutorial1/echo/server.py index 5713bba9..1a96e2df 100644 --- a/examples/tutorial1/echo/server.py +++ b/examples/tutorial1/echo/server.py @@ -5,11 +5,7 @@ import anyio from anyio.abc import SocketStream, TaskStatus -from asphalt.core import ( - Component, - run_application, - ComponentContext, -) +from asphalt.core import Component, run_application, start_service_task async def handle(stream: SocketStream) -> None: @@ -27,8 +23,8 @@ async def serve_requests(*, task_status: TaskStatus[None]) -> None: class ServerComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - await ctx.start_service_task(serve_requests, "Echo server") + async def start(self) -> None: + await start_service_task(serve_requests, "Echo server") if __name__ == "__main__": diff --git a/examples/tutorial2/webnotifier/detector.py b/examples/tutorial2/webnotifier/detector.py index 1733d615..80b4e009 100644 --- a/examples/tutorial2/webnotifier/detector.py +++ b/examples/tutorial2/webnotifier/detector.py @@ -14,7 +14,7 @@ Event, Signal, add_resource, - ComponentContext, + start_service_task, ) logger = logging.getLogger(__name__) @@ -59,10 +59,10 @@ def __init__(self, url: str, delay: int = 10): self.url = url self.delay = delay - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: detector = Detector(self.url, self.delay) add_resource(detector) - await ctx.start_service_task(detector.run, "Web page change detector") + await start_service_task(detector.run, "Web page change detector") logging.info( 'Started web page change detector for url "%s" with a delay of %d seconds', self.url, diff --git a/pyproject.toml b/pyproject.toml index 06706728..f8939ab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,13 +72,14 @@ extend-select = [ "I", # isort "ISC", # flake8-implicit-str-concat "PGH", # pygrep-hooks - "RUF100", # unused noqa (yesqa) + "RUF", # Ruff-specific rules "UP", # pyupgrade "W", # pycodestyle warnings ] ignore = [ "ASYNC109", - "ASYNC115" + "ASYNC115", + "RUF001", ] [tool.ruff.lint.isort] diff --git a/src/asphalt/core/__init__.py b/src/asphalt/core/__init__.py index 8347446b..b8783046 100644 --- a/src/asphalt/core/__init__.py +++ b/src/asphalt/core/__init__.py @@ -2,7 +2,6 @@ from ._component import CLIApplicationComponent as CLIApplicationComponent from ._component import Component as Component -from ._component import ComponentContext as ComponentContext from ._component import start_component as start_component from ._concurrent import TaskFactory as TaskFactory from ._concurrent import TaskHandle as TaskHandle @@ -18,6 +17,8 @@ from ._context import get_resources as get_resources from ._context import inject as inject from ._context import resource as resource +from ._context import start_background_task_factory as start_background_task_factory +from ._context import start_service_task as start_service_task from ._event import Event as Event from ._event import Signal as Signal from ._event import SignalQueueFull as SignalQueueFull diff --git a/src/asphalt/core/_component.py b/src/asphalt/core/_component.py index fe91ee9a..bf4e1b90 100644 --- a/src/asphalt/core/_component.py +++ b/src/asphalt/core/_component.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import sys from abc import ABCMeta, abstractmethod from collections.abc import ( Awaitable, @@ -12,7 +11,7 @@ ) from contextlib import AsyncExitStack from enum import Enum, auto -from inspect import isawaitable, isclass +from inspect import isclass from traceback import StackSummary from types import FrameType from typing import ( @@ -21,19 +20,16 @@ ClassVar, Literal, TypeVar, - Union, get_type_hints, overload, ) -from anyio import ( - create_task_group, - sleep, -) +from anyio import create_task_group, sleep from anyio.abc import TaskGroup -from ._concurrent import ExceptionHandler, TaskFactory, TaskHandle, run_background_task +from ._concurrent import ExceptionHandler, TaskFactory, TeardownAction from ._context import ( + Context, FactoryCallback, T_Resource, TeardownCallback, @@ -42,23 +38,16 @@ from ._exceptions import ComponentStartError, NoCurrentContext, ResourceNotFound from ._utils import ( PluginContainer, - callable_name, coalesce_exceptions, format_component_name, merge_config, qualified_name, ) -if sys.version_info >= (3, 10): - from typing import TypeAlias -else: - from typing_extensions import TypeAlias - logger = logging.getLogger("asphalt.core") TComponent = TypeVar("TComponent", bound="Component") T_Retval = TypeVar("T_Retval") -TeardownAction: TypeAlias = Union[Callable[[], Any], Literal["cancel"], None] class Component(metaclass=ABCMeta): @@ -118,7 +107,7 @@ def add_component( self._child_components[alias] = {"type": type or alias, **config} - async def prepare(self, ctx: ComponentContext) -> None: + async def prepare(self) -> None: """ Perform any necessary initialization before starting the component. @@ -127,7 +116,7 @@ async def prepare(self, ctx: ComponentContext) -> None: by the child components. """ - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: """ Perform any necessary tasks to start the services provided by this component. @@ -185,24 +174,30 @@ class ComponentState(Enum): closed = auto() -class ComponentContext: +class ComponentContext(Context): def __init__( self, component: Component, path: str, default_resource_name: str, - child_contexts: dict[str, ComponentContext], + child_component_contexts: dict[str, ComponentContext], ): + super().__init__() self._path = path self._component = component self._default_resource_name = default_resource_name - self._context = current_context() - self._child_contexts = child_contexts - self._state: ComponentState = ComponentState.initialized + self._child_component_contexts = child_component_contexts + self._component_state: ComponentState = ComponentState.initialized self._coro: Coroutine[Any, Any, None] | None = None + context = current_context() + if isinstance(context, ComponentContext): + context = context._context - def __format_resource_description( - self, types: Any, name: str, description: str | None = None + self._context: Context = context + + @staticmethod + def _format_resource_description( + types: Any, name: str, description: str | None = None ) -> str: if isclass(types): formatted = f"type={qualified_name(types)}" @@ -216,17 +211,6 @@ def __format_resource_description( return formatted - def _ensure_state(self, *allowed_states: ComponentState) -> None: - if self._state not in allowed_states: - raise RuntimeError( - f"cannot perform this operation while the component is in the " - f"{self._state} state" - ) - - @property - def state(self) -> ComponentState: - return self._state - def add_resource( self, value: T_Resource, @@ -236,8 +220,7 @@ def add_resource( description: str | None = None, teardown_callback: Callable[[], Any] | None = None, ) -> None: - self._ensure_state(ComponentState.preparing, ComponentState.starting) - if name == "default" and self._state is ComponentState.starting: + if name == "default" and self._component_state is ComponentState.starting: name = self._default_resource_name self._context.add_resource( @@ -250,7 +233,7 @@ def add_resource( logger.debug( "%s added a resource (%s)", format_component_name(self._path, capitalize=True), - self.__format_resource_description(types or type(value), name, description), + self._format_resource_description(types or type(value), name, description), ) def add_resource_factory( @@ -261,8 +244,7 @@ def add_resource_factory( types: Sequence[type] | None = None, description: str | None = None, ) -> None: - self._ensure_state(ComponentState.preparing, ComponentState.starting) - if name == "default" and self._state is ComponentState.starting: + if name == "default" and self._component_state is ComponentState.starting: name = self._default_resource_name self._context.add_resource_factory( @@ -271,7 +253,7 @@ def add_resource_factory( logger.debug( "%s added a resource factory (%s)", format_component_name(self._path, capitalize=True), - self.__format_resource_description( + self._format_resource_description( types or get_type_hints(factory_callback)["return"], name, description ), ) @@ -306,7 +288,6 @@ async def get_resource( *, optional: Literal[False, True] = False, ) -> T_Resource | None: - self._ensure_state(ComponentState.preparing, ComponentState.starting) if optional: return await self._context.get_resource(type, name, optional=True) @@ -316,7 +297,7 @@ async def get_resource( logger.debug( "%s is waiting for another component to provide a resource (%s)", format_component_name(self._path, capitalize=True), - self.__format_resource_description(type, name), + self._format_resource_description(type, name), ) # Wait until a matching resource or resource factory is available @@ -328,7 +309,7 @@ async def get_resource( logger.debug( "%s got the resource it was waiting for (%s)", format_component_name(self._path, capitalize=True), - self.__format_resource_description(type, name), + self._format_resource_description(type, name), ) return res @@ -354,7 +335,6 @@ def get_resource_nowait( *, optional: Literal[False, True] = False, ) -> T_Resource | None: - self._ensure_state(ComponentState.preparing, ComponentState.starting) if optional: return self._context.get_resource_nowait(type, name, optional=True) @@ -366,42 +346,14 @@ def get_resources(self, type: type[T_Resource]) -> Mapping[str, T_Resource]: def add_teardown_callback( self, callback: TeardownCallback, pass_exception: bool = False ) -> None: - self._ensure_state(ComponentState.preparing, ComponentState.starting) self._context.add_teardown_callback(callback, pass_exception) async def start_background_task_factory( self, *, exception_handler: ExceptionHandler | None = None ) -> TaskFactory: - """ - Start a service task that hosts ad-hoc background tasks. - - Each of the tasks started by this factory is run in its own, separate Asphalt - context, inherited from this context. - - When the service task is torn down, it will wait for all the background tasks to - finish before returning. - - It is imperative to ensure that the task factory is set up after any of the - resources potentially needed by the ad-hoc tasks are set up first. Failing to do - so risks those resources being removed from the context before all the tasks - have finished. - - :param exception_handler: a callback called to handle an exception raised from the - task. Takes the exception (:exc:`Exception`) as the argument, and should return - ``True`` if it successfully handled the exception. - :return: the task factory - - .. seealso:: :func:`start_service_task` - - """ - self._ensure_state(ComponentState.preparing, ComponentState.starting) - factory = TaskFactory(exception_handler) - await self.start_service_task( - factory._run, - f"Background task factory ({id(factory):x})", - teardown_action=factory._finished_event.set, + return await self._context.start_background_task_factory( + exception_handler=exception_handler ) - return factory async def start_service_task( self, @@ -410,75 +362,9 @@ async def start_service_task( *, teardown_action: TeardownAction = "cancel", ) -> Any: - """ - Start a background task that gets shut down when the context shuts down. - - This method is meant to be used by components to run their tasks like network - services that should be shut down with the application, because each call to this - functions registers a context teardown callback that waits for the service task to - finish before allowing the context teardown to continue.. - - If you supply a teardown callback, and it raises an exception, then the task - will be cancelled instead. - - :param func: the coroutine function to run - :param name: descriptive name (e.g. "HTTP server") for the task, to which the - prefix "Service task: " will be added when the task is actually created - in the backing asynchronous event loop implementation (e.g. asyncio) - :param teardown_action: the action to take when the context is being shut down: - - * ``'cancel'``: cancel the task - * ``None``: no action (the task must finish by itself) - * (function, or any callable, can be asynchronous): run this callable to signal - the task to finish - :return: any value passed to ``task_status.started()`` by the target callable if - it supports that, otherwise ``None`` - """ - - async def finalize_service_task() -> None: - if teardown_action == "cancel": - logger.debug("Cancelling service task %r", name) - task_handle.cancel() - elif teardown_action is not None: - teardown_action_name = callable_name(teardown_action) - logger.debug( - "Calling teardown callback (%s) for service task %r", - teardown_action_name, - name, - ) - try: - retval = teardown_action() - if isawaitable(retval): - await retval - except BaseException as exc: - task_handle.cancel() - if isinstance(exc, Exception): - logger.exception( - "Error calling teardown callback (%s) for service task %r", - teardown_action_name, - name, - ) - - logger.debug("Waiting for service task %r to finish", name) - await task_handle.wait_finished() - logger.debug("Service task %r finished", name) - - self._ensure_state(ComponentState.preparing, ComponentState.starting) - if ( - teardown_action != "cancel" - and teardown_action is not None - and not callable(teardown_action) - ): - raise ValueError( - "teardown_action must be a callable, None, or the string 'cancel'" - ) - - task_handle = TaskHandle(f"Service task: {name}") - task_handle.start_value = await self._context._task_group.start( - run_background_task, func, task_handle, name=task_handle.name + return await self._context.start_service_task( + func, name, teardown_action=teardown_action ) - self._context.add_teardown_callback(finalize_service_task) - return task_handle.start_value @overload @@ -619,55 +505,58 @@ async def _start_component(context: ComponentContext, path: str) -> None: # Prevent add_component() from being called beyond this point component = context._component component._component_started = True - component_class = type(component) - # Call prepare() on the component itself, if it's implemented on the component - # class - if component_class.prepare is not Component.prepare: - logger.debug("Calling prepare() of %s", format_component_name(path)) - context._state = ComponentState.preparing - coro = context._coro = component.prepare(context) - try: - await coro - except Exception as exc: - raise ComponentStartError("preparing", path, component_class) from exc - - logger.debug("Returned from prepare() of %s", format_component_name(path)) - context._coro = None - - # Start the child components, if there are any - if context._child_contexts: - logger.debug("Starting the child components of %s", format_component_name(path)) - context._state = ComponentState.starting_children - async with coalesce_exceptions(), create_task_group() as tg: - for alias, child_context in context._child_contexts.items(): - child_path = f"{path}.{alias}" if path else alias - tg.start_soon( - _start_component, - child_context, - child_path, - name=( - f"Starting component {child_path} " - f"({qualified_name(child_context._component)})" - ), - ) - - # Call start() on the component itself, if it's implemented on the component - # class - if component_class.start is not Component.start: - context._state = ComponentState.starting - logger.debug("Calling start() of %s", format_component_name(path)) - coro = context._coro = component.start(context) - context._state = ComponentState.starting - try: - await coro - except Exception as exc: - raise ComponentStartError("starting", path, component_class) from exc - logger.debug("Returned from start() of %s", format_component_name(path)) - context._coro = None - - context._state = ComponentState.started + async with context: + # Call prepare() on the component itself, if it's implemented on the component + # class + if component_class.prepare is not Component.prepare: + logger.debug("Calling prepare() of %s", format_component_name(path)) + context._component_state = ComponentState.preparing + coro = context._coro = component.prepare() + try: + await coro + except Exception as exc: + raise ComponentStartError("preparing", path, component_class) from exc + + logger.debug("Returned from prepare() of %s", format_component_name(path)) + context._coro = None + + # Start the child components, if there are any + if context._child_component_contexts: + logger.debug( + "Starting the child components of %s", format_component_name(path) + ) + context._component_state = ComponentState.starting_children + async with coalesce_exceptions(), create_task_group() as tg: + for alias, child_context in context._child_component_contexts.items(): + child_path = f"{path}.{alias}" if path else alias + tg.start_soon( + _start_component, + child_context, + child_path, + name=( + f"Starting component {child_path} " + f"({qualified_name(child_context._component)})" + ), + ) + + # Call start() on the component itself, if it's implemented on the component + # class + if component_class.start is not Component.start: + context._component_state = ComponentState.starting + logger.debug("Calling start() of %s", format_component_name(path)) + coro = context._coro = component.start() + context._component_state = ComponentState.starting + try: + await coro + except Exception as exc: + raise ComponentStartError("starting", path, component_class) from exc + + logger.debug("Returned from start() of %s", format_component_name(path)) + context._coro = None + + context._component_state = ComponentState.started async def _watch_component_tree_startup( @@ -677,10 +566,10 @@ async def _watch_component_tree_startup( def create_status_summaries(subcontext: ComponentContext) -> list[str]: parts = (subcontext._path or "(root)").split(".") indent = " " * (len(parts) if subcontext._path else 0) - state = subcontext._state.name.replace("_", " ") + state = subcontext._component_state.name.replace("_", " ") summaries = [f"{indent}{parts[-1]}: {state}"] - for child_context in subcontext._child_contexts.values(): - if child_context._state is not ComponentState.started: + for child_context in subcontext._child_component_contexts.values(): + if child_context._component_state is not ComponentState.started: summaries.extend(create_status_summaries(child_context)) return summaries @@ -693,7 +582,7 @@ def create_stack_summaries(subcontext: ComponentContext) -> list[str]: title = f"{subcontext._path} ({qualified_name(subcontext._component)})" summaries.append(f"{title}:\n{formatted_summary.rstrip()}") - for child_context in subcontext._child_contexts.values(): + for child_context in subcontext._child_component_contexts.values(): summaries.extend(create_stack_summaries(child_context)) return summaries diff --git a/src/asphalt/core/_concurrent.py b/src/asphalt/core/_concurrent.py index 64852bd4..19205f01 100644 --- a/src/asphalt/core/_concurrent.py +++ b/src/asphalt/core/_concurrent.py @@ -5,7 +5,7 @@ from collections.abc import Coroutine from dataclasses import dataclass, field from inspect import Parameter, signature -from typing import Any, Callable, TypeVar +from typing import Any, Callable, Literal, TypeVar, Union from anyio import ( TASK_STATUS_IGNORED, @@ -15,7 +15,6 @@ ) from anyio.abc import TaskGroup, TaskStatus -from ._context import Context from ._utils import callable_name if sys.version_info >= (3, 10): @@ -25,6 +24,7 @@ T_Retval = TypeVar("T_Retval") ExceptionHandler: TypeAlias = Callable[[Exception], bool] +TeardownAction: TypeAlias = Union[Callable[[], Any], Literal["cancel"], None] logger = logging.getLogger("asphalt.core") @@ -64,6 +64,8 @@ async def run_background_task( *, task_status: TaskStatus[Any] = TASK_STATUS_IGNORED, ) -> None: + from ._context import Context + __tracebackhide__ = True # trick supported by certain debugger frameworks # Check if the function has a parameter named "task_status" @@ -81,11 +83,10 @@ async def run_background_task( else: task_status.started() await func() - except BaseException as exc: - if isinstance(exc, Exception): - logger.exception("Background task (%s) crashed", task_handle.name) - if exception_handler is not None and exception_handler(exc): - return + except Exception as exc: + logger.exception("Background task (%s) crashed", task_handle.name) + if exception_handler is not None and exception_handler(exc): + return raise else: diff --git a/src/asphalt/core/_context.py b/src/asphalt/core/_context.py index 55f58c85..64467858 100644 --- a/src/asphalt/core/_context.py +++ b/src/asphalt/core/_context.py @@ -43,10 +43,16 @@ from anyio import ( create_task_group, - get_current_task, ) from anyio.abc import TaskGroup +from ._concurrent import ( + ExceptionHandler, + TaskFactory, + TaskHandle, + TeardownAction, + run_background_task, +) from ._event import Event, Signal from ._exceptions import ( AsyncResourceError, @@ -172,14 +178,20 @@ def __init__(self) -> None: self._teardown_callbacks: list[tuple[TeardownCallback, bool]] = [] self._child_contexts = set[Context]() self._parent = _current_context.get(None) + if self._parent is not None: + from ._component import ComponentContext + + # Don't set a ComponentContext as parent as they exit sooner + while isinstance(self._parent, ComponentContext): + self._parent = self._parent._context + self._resources = { key: res for key, res in self._parent._resources.items() if not res.is_generated } self._resource_factories = self._parent._resource_factories.copy() - self._task_group = self._parent._task_group else: self._resources = {} self._resource_factories = {} @@ -271,13 +283,11 @@ async def __aenter__(self) -> Self: self._parent._child_contexts.add(self) exit_stack.callback(self._parent._child_contexts.remove, self) - self._host_task = get_current_task() - exit_stack.callback(delattr, self, "_host_task") _reset_token = _current_context.set(self) exit_stack.callback(_current_context.reset, _reset_token) # If this is the root context, create and enter a task group - if not hasattr(self, "_task_group"): + if self._parent is None: await exit_stack.enter_async_context(coalesce_exceptions()) self._task_group = await exit_stack.enter_async_context( create_task_group() @@ -645,6 +655,119 @@ def get_resources(self, type: type[T_Resource]) -> Mapping[str, T_Resource]: if type in container.types } + async def start_background_task_factory( + self, *, exception_handler: ExceptionHandler | None = None + ) -> TaskFactory: + """ + Start a service task that hosts ad-hoc background tasks. + + Each of the tasks started by this factory is run in its own, separate Asphalt + context, inherited from this context. + + When the service task is torn down, it will wait for all the background tasks to + finish before returning. + + It is imperative to ensure that the task factory is set up after any of the + resources potentially needed by the ad-hoc tasks are set up first. Failing to do + so risks those resources being removed from the context before all the tasks + have finished. + + :param exception_handler: a callback called to handle an exception raised from the + task. Takes the exception (:exc:`Exception`) as the argument, and should return + ``True`` if it successfully handled the exception. + :return: the task factory + + .. seealso:: :func:`start_service_task` + + """ + factory = TaskFactory(exception_handler) + await self.start_service_task( + factory._run, + f"Background task factory ({id(factory):x})", + teardown_action=factory._finished_event.set, + ) + return factory + + async def start_service_task( + self, + func: Callable[..., Coroutine[Any, Any, T_Retval]], + name: str, + *, + teardown_action: TeardownAction = "cancel", + ) -> Any: + """ + Start a background task that gets shut down when the context shuts down. + + This function is meant to be used by components to run their tasks like network + services that should be shut down with the application, because each call to this + functions registers a context teardown callback that waits for the service task to + finish before allowing the context teardown to continue.. + + If you supply a teardown callback, and it raises an exception, then the task + will be cancelled instead. + + :param func: the coroutine function to run + :param name: descriptive name (e.g. "HTTP server") for the task, to which the + prefix "Service task: " will be added when the task is actually created + in the backing asynchronous event loop implementation (e.g. asyncio) + :param teardown_action: the action to take when the context is being shut down: + + * ``'cancel'``: cancel the task + * ``None``: no action (the task must finish by itself) + * (function, or any callable, can be asynchronous): run this callable to signal + the task to finish + :return: any value passed to ``task_status.started()`` by the target callable if + it supports that, otherwise ``None`` + """ + + async def finalize_service_task() -> None: + if teardown_action == "cancel": + logger.debug("Cancelling service task %r", name) + task_handle.cancel() + elif teardown_action is not None: + teardown_action_name = callable_name(teardown_action) + logger.debug( + "Calling teardown callback (%s) for service task %r", + teardown_action_name, + name, + ) + try: + retval = teardown_action() + if isawaitable(retval): + await retval + except BaseException as exc: + task_handle.cancel() + if isinstance(exc, Exception): + logger.exception( + "Error calling teardown callback (%s) for service task %r", + teardown_action_name, + name, + ) + + logger.debug("Waiting for service task %r to finish", name) + await task_handle.wait_finished() + logger.debug("Service task %r finished", name) + + if ( + teardown_action != "cancel" + and teardown_action is not None + and not callable(teardown_action) + ): + raise ValueError( + "teardown_action must be a callable, None, or the string 'cancel'" + ) + + root_context = current_context() + while root_context.parent: + root_context = root_context.parent + + task_handle = TaskHandle(f"Service task: {name}") + task_handle.start_value = await root_context._task_group.start( + run_background_task, func, task_handle, name=task_handle.name + ) + root_context.add_teardown_callback(finalize_service_task) + return task_handle.start_value + def context_teardown( func: Callable[P, AsyncGenerator[None, BaseException | None]], @@ -849,6 +972,72 @@ def get_resource_nowait( return current_context().get_resource_nowait(type, name, optional=optional) +async def start_background_task_factory( + *, exception_handler: ExceptionHandler | None = None +) -> TaskFactory: + """ + Start a service task that hosts ad-hoc background tasks. + + Each of the tasks started by this factory is run in its own, separate Asphalt + context, inherited from this context. + + When the service task is torn down, it will wait for all the background tasks to + finish before returning. + + It is imperative to ensure that the task factory is set up after any of the + resources potentially needed by the ad-hoc tasks are set up first. Failing to do + so risks those resources being removed from the context before all the tasks + have finished. + + :param exception_handler: a callback called to handle an exception raised from the + task. Takes the exception (:exc:`Exception`) as the argument, and should return + ``True`` if it successfully handled the exception. + :return: the task factory + + .. seealso:: :func:`start_service_task` + + """ + return await current_context().start_background_task_factory( + exception_handler=exception_handler + ) + + +async def start_service_task( + func: Callable[..., Coroutine[Any, Any, T_Retval]], + name: str, + *, + teardown_action: TeardownAction = "cancel", +) -> Any: + """ + Start a background task that gets shut down when the context shuts down. + + This function is meant to be used by components to run their tasks like network + services that should be shut down with the application, because each call to this + functions registers a context teardown callback that waits for the service task to + finish before allowing the context teardown to continue.. + + If you supply a teardown callback, and it raises an exception, then the task + will be cancelled instead. + + :param func: the coroutine function to run + :param name: descriptive name (e.g. "HTTP server") for the task, to which the + prefix "Service task: " will be added when the task is actually created + in the backing asynchronous event loop implementation (e.g. asyncio) + :param teardown_action: the action to take when the context is being shut down: + + * ``'cancel'``: cancel the task + * ``None``: no action (the task must finish by itself) + * (function, or any callable, can be asynchronous): run this callable to signal + the task to finish + :return: any value passed to ``task_status.started()`` by the target callable if + it supports that, otherwise ``None`` + + """ + return await current_context().start_service_task( + func, name, teardown_action=teardown_action + ) + + @dataclass class _Dependency: name: str = "default" diff --git a/src/asphalt/core/_event.py b/src/asphalt/core/_event.py index 5d835e78..7f679aca 100644 --- a/src/asphalt/core/_event.py +++ b/src/asphalt/core/_event.py @@ -37,7 +37,7 @@ class Event: :ivar float time: event creation time as seconds from the UNIX epoch """ - __slots__ = "source", "topic", "time" + __slots__ = "source", "time", "topic" source: Any topic: str diff --git a/src/asphalt/core/_runner.py b/src/asphalt/core/_runner.py index 1686ca46..e89c5c2e 100644 --- a/src/asphalt/core/_runner.py +++ b/src/asphalt/core/_runner.py @@ -3,7 +3,7 @@ import platform import signal import sys -from contextlib import AsyncExitStack +from functools import partial from logging import INFO, basicConfig, getLogger from logging.config import dictConfig from typing import Any @@ -13,19 +13,19 @@ from anyio import ( CancelScope, Event, - create_task_group, get_cancelled_exc_class, to_thread, ) from anyio.abc import TaskStatus +from . import start_service_task from ._component import ( CLIApplicationComponent, Component, start_component, ) from ._context import Context -from ._utils import coalesce_exceptions, qualified_name +from ._utils import qualified_name logger = getLogger("asphalt.core") @@ -58,30 +58,26 @@ async def _run_application_async( logger.info("Starting application") try: - async with AsyncExitStack() as exit_stack: - event = Event() - - await exit_stack.enter_async_context(Context()) - if platform.system() != "Windows": - await exit_stack.enter_async_context(coalesce_exceptions()) - startup_tg = await exit_stack.enter_async_context(create_task_group()) - startup_scope = exit_stack.enter_context(CancelScope()) - exit_stack.callback(startup_tg.cancel_scope.cancel) - await startup_tg.start( - handle_signals, startup_scope, event, name="Asphalt signal handler" - ) - - try: - component = await start_component( - component_class, config, timeout=start_timeout - ) - except (get_cancelled_exc_class(), TimeoutError): - # This happens when a signal handler cancels the startup or - # start_component() times out - return 1 - except BaseException: - logger.exception("Error during application startup") - return 1 + event = Event() + async with Context(): + with CancelScope() as startup_scope: + if platform.system() != "Windows": + await start_service_task( + partial(handle_signals, startup_scope, event), + "Asphalt signal handler", + ) + + try: + component = await start_component( + component_class, config, timeout=start_timeout + ) + except (get_cancelled_exc_class(), TimeoutError): + # This happens when a signal handler cancels the startup or + # start_component() times out + return 1 + except BaseException: + logger.exception("Error during application startup") + return 1 logger.info("Application started") diff --git a/src/asphalt/core/_utils.py b/src/asphalt/core/_utils.py index 8cb85448..d5e715e5 100644 --- a/src/asphalt/core/_utils.py +++ b/src/asphalt/core/_utils.py @@ -154,7 +154,7 @@ class PluginContainer: entry points don't point to classes) """ - __slots__ = "namespace", "base_class", "_entrypoints", "_resolved" + __slots__ = "_entrypoints", "_resolved", "base_class", "namespace" def __init__(self, namespace: str, base_class: type | None = None) -> None: self.namespace: str = namespace diff --git a/tests/test_component.py b/tests/test_component.py index aa71196a..2d934c57 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -14,10 +14,13 @@ from asphalt.core import ( CLIApplicationComponent, Component, - ComponentContext, ComponentStartError, Context, + add_resource, + add_resource_factory, + get_resource, get_resource_nowait, + get_resources, run_application, start_component, ) @@ -42,7 +45,7 @@ def __init__( self.alias = alias self.container = container - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: await anyio.sleep(0.1) if self.alias and self.container is not None: self.container[self.alias] = self @@ -140,7 +143,7 @@ def test_add_duplicate_component(self) -> None: async def test_add_component_during_start(self) -> None: class BadContainerComponent(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: self.add_component("foo", DummyComponent) async with Context(): @@ -263,7 +266,7 @@ def __init__(self) -> None: async def test_start_component_error_during_prepare() -> None: class BadComponent(Component): - async def prepare(self, ctx: ComponentContext) -> None: + async def prepare(self) -> None: raise RuntimeError("component fail") async with Context(): @@ -288,7 +291,7 @@ async def test_start_component_no_context() -> None: async def test_start_component_timeout() -> None: class StallingComponent(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: await sleep(3) pytest.fail("Shouldn't reach this point") @@ -302,16 +305,16 @@ class ParentComponent(Component): def __init__(self) -> None: self.add_component("child", ChildComponent) - async def prepare(self, ctx: ComponentContext) -> None: - ctx.add_resource("foo") + async def prepare(self) -> None: + add_resource("foo") - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: get_resource_nowait(str, "bar") class ChildComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - foo = ctx.get_resource_nowait(str) - ctx.add_resource(foo + "bar", "bar") + async def start(self) -> None: + foo = get_resource_nowait(str) + add_resource(foo + "bar", "bar") caplog.set_level(logging.DEBUG, "asphalt.core") async with Context(): @@ -341,17 +344,17 @@ async def start(self, ctx: ComponentContext) -> None: async def test_resource_descriptions(caplog: LogCaptureFixture) -> None: class CustomComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - ctx.add_resource("foo", "bar", description="sample string") - ctx.add_resource_factory( + async def start(self) -> None: + add_resource("foo", "bar", description="sample string") + add_resource_factory( lambda: 3, "bar", types=[int, float], description="sample integer factory", ) - assert await ctx.get_resource(float, optional=True) is None - assert ctx.get_resource_nowait(float, optional=True) is None - assert ctx.get_resources(str) == {"bar": "foo"} + assert await get_resource(float, optional=True) is None + assert get_resource_nowait(float, optional=True) is None + assert get_resources(str) == {"bar": "foo"} caplog.set_level(logging.DEBUG, "asphalt.core") async with Context(): @@ -381,17 +384,17 @@ def __init__(self) -> None: self.add_component("child2", Child2Component) class Child1Component(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: child1_ready_event.set() await child2_ready_event.wait() - ctx.add_resource("from_child1", "special") + add_resource("from_child1", "special") class Child2Component(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: await child1_ready_event.wait() child2_ready_event.set() with fail_after(3): - assert await ctx.get_resource(str, "special") == "from_child1" + assert await get_resource(str, "special") == "from_child1" caplog.set_level(logging.DEBUG, "asphalt.core") child1_ready_event = Event() @@ -440,10 +443,10 @@ def __init__(self) -> None: class ChildComponent(Component): name: str - async def start(self, ctx: ComponentContext) -> None: - ctx.add_resource("default_resource") - ctx.add_resource(f"special_resource_{self.name}", self.name) - ctx.add_resource_factory(lambda: 7, types=[int]) + async def start(self) -> None: + add_resource("default_resource") + add_resource(f"special_resource_{self.name}", self.name) + add_resource_factory(lambda: 7, types=[int]) caplog.set_level(logging.DEBUG, "asphalt.core") async with Context(): diff --git a/tests/test_concurrent.py b/tests/test_concurrent.py index 97223ef5..663c31df 100644 --- a/tests/test_concurrent.py +++ b/tests/test_concurrent.py @@ -11,12 +11,13 @@ from asphalt.core import ( Component, - ComponentContext, ComponentStartError, Context, TaskFactory, TaskHandle, + start_background_task_factory, start_component, + start_service_task, ) if sys.version_info < (3, 11): @@ -34,9 +35,9 @@ async def taskfunc() -> str: return "returnvalue" class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: nonlocal handle - factory = await ctx.start_background_task_factory() + factory = await start_background_task_factory() handle = await factory.start_task(taskfunc, "taskfunc") async with Context(): @@ -52,9 +53,9 @@ async def taskfunc() -> None: assert get_current_task().name == expected_name class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: nonlocal handle - factory = await ctx.start_background_task_factory() + factory = await start_background_task_factory() handle = await factory.start_task(taskfunc) expected_name = ( @@ -75,9 +76,9 @@ async def taskfunc(task_status: TaskStatus[str]) -> str: return "returnvalue" class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: nonlocal handle - factory = await ctx.start_background_task_factory() + factory = await start_background_task_factory() handle = await factory.start_task(taskfunc, "taskfunc") async with Context(): @@ -99,9 +100,9 @@ async def taskfunc() -> None: finished = True class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: nonlocal handle - factory = await ctx.start_background_task_factory() + factory = await start_background_task_factory() handle = await factory.start_task(taskfunc, "taskfunc") async with Context(): @@ -118,8 +119,8 @@ async def taskfunc() -> NoReturn: raise Exception("foo") class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - factory = await ctx.start_background_task_factory() + async def start(self) -> None: + factory = await start_background_task_factory() await factory.start_task(taskfunc, "taskfunc") with pytest.raises(ExceptionGroup) as excinfo: @@ -144,8 +145,8 @@ async def taskfunc() -> NoReturn: raise Exception("foo") class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - factory = await ctx.start_background_task_factory( + async def start(self) -> None: + factory = await start_background_task_factory( exception_handler=handle_exception ) await factory.start_task(taskfunc, "taskfunc") @@ -168,9 +169,9 @@ async def taskfunc() -> str: return "returnvalue" class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: nonlocal handle - factory = await ctx.start_background_task_factory() + factory = await start_background_task_factory() handle = factory.start_task_soon(taskfunc, name) async with Context(): @@ -190,9 +191,9 @@ async def taskfunc() -> None: await event.wait() class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: nonlocal factory, handle1, handle2 - factory = await ctx.start_background_task_factory() + factory = await start_background_task_factory() handle1 = await factory.start_task(taskfunc) handle2 = factory.start_task_soon(taskfunc) @@ -212,8 +213,8 @@ async def start(self, ctx: ComponentContext) -> None: class TestServiceTask: async def test_bad_teardown_action(self, caplog: LogCaptureFixture) -> None: class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - await ctx.start_service_task( + async def start(self) -> None: + await start_service_task( lambda: sleep(1), "Dummy", teardown_action="fail", # type: ignore[arg-type] @@ -233,8 +234,8 @@ async def service_func() -> None: await event.wait() class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - await ctx.start_service_task( + async def start(self) -> None: + await start_service_task( service_func, "Dummy", teardown_action=teardown_callback ) @@ -251,8 +252,8 @@ async def service_func() -> None: await event.wait() class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - await ctx.start_service_task( + async def start(self) -> None: + await start_service_task( service_func, "Dummy", teardown_action=teardown_callback ) @@ -278,10 +279,8 @@ async def taskfunc() -> None: finished = True class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - await ctx.start_service_task( - taskfunc, "taskfunc", teardown_action="cancel" - ) + async def start(self) -> None: + await start_service_task(taskfunc, "taskfunc", teardown_action="cancel") async with Context(): await start_component(TaskComponent) @@ -302,8 +301,8 @@ async def taskfunc(task_status: TaskStatus[str]) -> None: finished = True class TaskComponent(Component): - async def start(self, ctx: ComponentContext) -> None: - startval = await ctx.start_service_task( + async def start(self) -> None: + startval = await start_service_task( taskfunc, "taskfunc", teardown_action="cancel" ) assert startval == "startval" diff --git a/tests/test_context.py b/tests/test_context.py index 7a83d5e1..dfcf62f3 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -22,10 +22,14 @@ ResourceConflict, ResourceEvent, ResourceNotFound, + add_resource, + add_resource_factory, + add_teardown_callback, context_teardown, current_context, get_resource, get_resource_nowait, + get_resources, inject, resource, ) @@ -327,9 +331,9 @@ def factory() -> Optional[str]: # noqa: UP007 async def test_get_resources(self, context: Context) -> None: context.add_resource(9, "foo") - async with Context() as subctx: - subctx.add_resource(1, "bar") - assert subctx.get_resources(int) == {"bar": 1, "foo": 9} + async with Context(): + add_resource(1, "bar") + assert get_resources(int) == {"bar": 1, "foo": 9} async def test_get_resource_nowait(self, context: Context) -> None: context.add_resource(1) @@ -425,9 +429,9 @@ async def teardown_callback() -> None: nonlocal resource resource = get_resource_nowait(str) - async with Context() as ctx: - ctx.add_resource("blah") - ctx.add_teardown_callback(teardown_callback) + async with Context(): + add_resource("blah") + add_teardown_callback(teardown_callback) assert resource == "blah" @@ -438,9 +442,9 @@ async def teardown_callback() -> None: nonlocal resource resource = get_resource_nowait(str) - async with Context() as ctx: - ctx.add_resource_factory(lambda: "blah", types=[str]) - ctx.add_teardown_callback(teardown_callback) + async with Context(): + add_resource_factory(lambda: "blah", types=[str]) + add_teardown_callback(teardown_callback) assert resource == "blah" @@ -486,15 +490,15 @@ async def test_current_context() -> None: async def test_get_resource() -> None: - async with Context() as ctx: - ctx.add_resource("foo") + async with Context(): + add_resource("foo") assert await get_resource(str) == "foo" assert await get_resource(int, optional=True) is None async def test_get_resource_nowait() -> None: - async with Context() as ctx: - ctx.add_resource("foo") + async with Context(): + add_resource("foo") assert get_resource_nowait(str) == "foo" pytest.raises(ResourceNotFound, get_resource_nowait, int) @@ -507,9 +511,9 @@ async def injected( ) -> tuple[int, str, str]: return foo, bar, baz - async with Context() as ctx: - ctx.add_resource("bar_test") - ctx.add_resource("baz_test", "alt") + async with Context(): + add_resource("bar_test") + add_resource("baz_test", "alt") foo, bar, baz = await injected(2) assert foo == 2 @@ -523,9 +527,9 @@ def injected( ) -> tuple[int, str, str]: return foo, bar, baz - async with Context() as ctx: - ctx.add_resource("bar_test") - ctx.add_resource("baz_test", "alt") + async with Context(): + add_resource("bar_test") + add_resource("baz_test", "alt") foo, bar, baz = injected(2) assert foo == 2 @@ -597,10 +601,10 @@ async def injected( ) -> annotation: # type: ignore[valid-type] return res - async with Context() as ctx: + async with Context(): retval: Any = injected() if sync else (await injected()) assert retval is None - ctx.add_resource("hello") + add_resource("hello") retval = injected() if sync else (await injected()) assert retval == "hello" diff --git a/tests/test_runner.py b/tests/test_runner.py index ca344757..a9548c52 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -16,9 +16,10 @@ from asphalt.core import ( CLIApplicationComponent, Component, - ComponentContext, add_teardown_callback, + get_resource, run_application, + start_service_task, ) pytestmark = pytest.mark.anyio() @@ -41,15 +42,15 @@ async def stop_app(self) -> None: elif self.method == "exception": raise RuntimeError("this should crash the application") - async def start(self, ctx: ComponentContext) -> None: - await ctx.start_service_task(self.stop_app, "Application terminator") + async def start(self) -> None: + await start_service_task(self.stop_app, "Application terminator") class CrashComponent(Component): def __init__(self, method: str = "exit"): self.method = method - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: if self.method == "keyboard": signal.raise_signal(signal.SIGINT) await sleep(3) @@ -70,7 +71,7 @@ def teardown_callback(self, exception: BaseException | None) -> None: logging.getLogger(__name__).info("Teardown callback called") self.exception = exception - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: add_teardown_callback(self.teardown_callback, pass_exception=True) async def run(self) -> int | None: @@ -247,10 +248,10 @@ def __init__(self, level: int = 1): self.add_component("child2", StallingComponent, level=level + 1) self.add_component("child3", Component) - async def start(self, ctx: ComponentContext) -> None: + async def start(self) -> None: if self.is_leaf: # Wait forever for a non-existent resource - await ctx.get_resource(float) + await get_resource(float) caplog.set_level(logging.INFO) with pytest.raises(SystemExit) as exc_info: