diff --git a/docs/api.rst b/docs/api.rst index ee8cca4e..6fecd74c 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -10,12 +10,6 @@ Components .. autoclass:: ContainerComponent .. autoclass:: CLIApplicationComponent -Concurrency ------------ - -.. autofunction:: start_background_task -.. autofunction:: start_service_task - Contexts and resources ---------------------- diff --git a/docs/tutorials/echo.rst b/docs/tutorials/echo.rst index c486a08a..72aadbc4 100644 --- a/docs/tutorials/echo.rst +++ b/docs/tutorials/echo.rst @@ -112,7 +112,6 @@ For this purpose, we will use AnyIO's :func:`~anyio.create_tcp_listener` functio Context, context_teardown, run_application, - start_service_task, ) @@ -128,7 +127,7 @@ For this purpose, we will use AnyIO's :func:`~anyio.create_tcp_listener` functio async with await anyio.create_tcp_listener( local_host="localhost", local_port=64100 ) as listener: - start_service_task(lambda: listener.serve(handle), "Echo server") + self.task_group.start_soon(lambda: listener.serve(handle)) yield if __name__ == '__main__': diff --git a/docs/tutorials/webnotifier.rst b/docs/tutorials/webnotifier.rst index 91dc7343..291c023c 100644 --- a/docs/tutorials/webnotifier.rst +++ b/docs/tutorials/webnotifier.rst @@ -275,7 +275,7 @@ Asphalt application:: async def start(self, ctx: Context) -> None: detector = Detector(self.url, self.delay) await ctx.add_resource(detector) - start_service_task(detector.run, "Web page change detector") + self.task_group.start_soon(detector.run) logging.info( 'Started web page change detector for url "%s" with a delay of %d seconds', self.url, diff --git a/examples/tutorial1/echo/server.py b/examples/tutorial1/echo/server.py index 6f9a8adc..382dd41f 100644 --- a/examples/tutorial1/echo/server.py +++ b/examples/tutorial1/echo/server.py @@ -11,7 +11,6 @@ Context, context_teardown, run_application, - start_service_task, ) @@ -27,7 +26,7 @@ async def start(self, ctx: Context) -> AsyncGenerator[None, Exception | None]: async with await anyio.create_tcp_listener( local_host="localhost", local_port=64100 ) as listener: - start_service_task(lambda: listener.serve(handle), "Echo server") + self.task_group.start_soon(lambda: listener.serve(handle)) yield diff --git a/examples/tutorial2/webnotifier/detector.py b/examples/tutorial2/webnotifier/detector.py index 59ba3889..d8778add 100644 --- a/examples/tutorial2/webnotifier/detector.py +++ b/examples/tutorial2/webnotifier/detector.py @@ -16,7 +16,6 @@ Event, Signal, context_teardown, - start_service_task, ) logger = logging.getLogger(__name__) @@ -67,7 +66,7 @@ def __init__(self, url: str, delay: int = 10): async def start(self, ctx: Context) -> AsyncGenerator[None, Exception | None]: detector = Detector(self.url, self.delay) await ctx.add_resource(detector) - start_service_task(detector.run, "Web page change detector") + self.task_group.start_soon(detector.run) logging.info( 'Started web page change detector for url "%s" with a delay of %d seconds', self.url, diff --git a/src/asphalt/core/__init__.py b/src/asphalt/core/__init__.py index b1fdda94..f92d31af 100644 --- a/src/asphalt/core/__init__.py +++ b/src/asphalt/core/__init__.py @@ -4,7 +4,6 @@ "Component", "ContainerComponent", "start_background_task", - "start_service_task", "Context", "ResourceConflict", "ResourceEvent", @@ -35,7 +34,6 @@ Component, ContainerComponent, ) -from ._concurrent import start_background_task, start_service_task from ._context import ( Context, NoCurrentContext, diff --git a/src/asphalt/core/_component.py b/src/asphalt/core/_component.py index 642eb010..1cdf35af 100644 --- a/src/asphalt/core/_component.py +++ b/src/asphalt/core/_component.py @@ -2,13 +2,14 @@ from abc import ABCMeta, abstractmethod from collections import OrderedDict +from contextlib import AsyncExitStack from typing import Any from warnings import warn from anyio import create_task_group +from anyio.abc import TaskGroup from anyio.lowlevel import cancel_shielded_checkpoint -from ._concurrent import start_service_task from ._context import Context from ._exceptions import ApplicationExit from ._utils import PluginContainer, merge_config, qualified_name @@ -17,7 +18,25 @@ class Component(metaclass=ABCMeta): """This is the base class for all Asphalt components.""" - __slots__ = () + _task_group: TaskGroup | None = None + + async def __aenter__(self) -> Component: + if self._task_group is not None: + raise RuntimeError("Component already entered") + + async with AsyncExitStack() as exit_stack: + tg = create_task_group() + self._task_group = await exit_stack.enter_async_context(tg) + self._exit_stack = exit_stack.pop_all() + + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + if self._task_group is None: + raise RuntimeError("Component not entered") + + self._task_group = None + return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) @abstractmethod async def start(self, ctx: Context) -> None: @@ -38,6 +57,16 @@ async def start(self, ctx: Context) -> None: :param ctx: the containing context for this component """ + @property + def task_group(self) -> TaskGroup: + if self._task_group is None: + raise RuntimeError( + "Component has no task group, did you forget to use: " + "async with component ?" + ) + else: + return self._task_group + class ContainerComponent(Component): """ @@ -111,9 +140,9 @@ async def start(self, ctx: Context) -> None: if alias not in self.child_components: self.add_component(alias) - async with create_task_group() as tg: - for alias, component in self.child_components.items(): - tg.start_soon(component.start, ctx) + for component in self.child_components.values(): + component._task_group = self._task_group + self.task_group.start_soon(component.start, ctx) class CLIApplicationComponent(ContainerComponent): @@ -156,7 +185,8 @@ async def run() -> None: raise ApplicationExit await super().start(ctx) - start_service_task(run, "Main task") + assert self._task_group is not None + self._task_group.start_soon(run) @abstractmethod async def run(self) -> int | None: diff --git a/src/asphalt/core/_concurrent.py b/src/asphalt/core/_concurrent.py deleted file mode 100644 index 29fc8044..00000000 --- a/src/asphalt/core/_concurrent.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import logging -from collections.abc import Callable, Coroutine -from typing import Any - -from anyio.abc import TaskGroup - -from ._context import Context, current_context, get_resource_nowait -from ._exceptions import ApplicationExit - -logger = logging.getLogger(__name__) - - -def start_background_task( - func: Callable[..., Coroutine[Any, Any, Any]], name: str -) -> None: - """ - Start a task that runs independently on the background. - - The task runs in its own context, inherited from the root context. - If the task raises an exception, the exception is logged with a descriptive message - containing the task's name. - - To pass arguments to the target callable, pass them via lambda (e.g. - ``lambda: yourfunc(arg1, arg2, kw=val)``) - - :param func: the coroutine function to run - :param name: descriptive name for the task - - """ - - async def run_background_task() -> None: - logger.debug("Background task (%s) starting", name) - try: - async with Context(): - await func() - except Exception: - logger.exception("Background task (%s) crashed", name) - else: - logger.debug("Background task (%s) finished", name) - - ctx = current_context() - while ctx.parent: - ctx = ctx.parent - - root_taskgroup = get_resource_nowait( - TaskGroup, # type: ignore[type-abstract] - "root_taskgroup", - ) - root_taskgroup.start_soon(run_background_task, name=name) - - -def start_service_task( - func: Callable[..., Coroutine[Any, Any, Any]], name: str -) -> None: - """ - Start a task that runs independently on the background. - - The task runs in its own context, inherited from the root context. - If the task raises an exception, it is propagated to the application runner, - triggering the termination of the application. - - To pass arguments to the target callable, pass them via lambda (e.g. - ``lambda: yourfunc(arg1, arg2, kw=val)``) - - :param func: the coroutine function to run - :param name: descriptive name for the task - - """ - - async def run_service_task() -> None: - logger.debug("Service task (%s) starting", name) - try: - async with Context(): - await func() - except ApplicationExit: - logger.info( - "Service task (%s) requested the application to be shut down", name - ) - raise - except SystemExit as exc: - # asyncio stops the loop prematurely if a base exception like SystemExit - # is raised, so we work around that with Asphalt's SoftSystemExit which - # inherits from Exception instead - raise ApplicationExit(exc.code).with_traceback(exc.__traceback__) from None - except Exception: - logger.exception( - "Service task (%s) crashed – terminating application", name - ) - raise - else: - logger.info("Service task (%s) finished", name) - - ctx = current_context() - while ctx.parent: - ctx = ctx.parent - - root_taskgroup = get_resource_nowait( - TaskGroup, # type: ignore[type-abstract] - "root_taskgroup", - ) - root_taskgroup.start_soon(run_service_task, name=name) diff --git a/src/asphalt/core/_runner.py b/src/asphalt/core/_runner.py index 3da6a5ff..0268e1cd 100644 --- a/src/asphalt/core/_runner.py +++ b/src/asphalt/core/_runner.py @@ -16,7 +16,7 @@ sleep, to_thread, ) -from anyio.abc import TaskGroup, TaskStatus +from anyio.abc import TaskStatus from exceptiongroup import catch from ._component import Component, component_types @@ -112,7 +112,7 @@ async def run_application( exit_stack.enter_context(catch(handlers)) # type: ignore[arg-type] root_tg = await exit_stack.enter_async_context(create_task_group()) ctx = await exit_stack.enter_async_context(Context()) - await ctx.add_resource(root_tg, "root_taskgroup", [TaskGroup]) + component._task_group = root_tg if platform.system() != "Windows": await root_tg.start(handle_signals, name="Asphalt signal handler") diff --git a/tests/test_component.py b/tests/test_component.py index 672e6a5b..857b8d65 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -111,7 +111,7 @@ def test_add_duplicate_component(self, container) -> None: assert str(exc.value) == 'there is already a child component named "dummy"' async def test_start(self, container) -> None: - async with Context() as ctx: + async with Context() as ctx, container: await container.start(ctx) assert container.child_components["dummy"].started diff --git a/tests/test_runner.py b/tests/test_runner.py index b2e6fc83..eda11726 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -17,7 +17,6 @@ Component, Context, run_application, - start_service_task, ) pytestmark = pytest.mark.anyio() @@ -52,7 +51,7 @@ async def start(self, ctx: Context) -> None: if self.method == "timeout": await anyio.sleep(1) else: - start_service_task(self.stop_app, name="Application terminator") + self.task_group.start_soon(self.stop_app) class CrashComponent(Component):