diff --git a/src/asphalt/core/_concurrent.py b/src/asphalt/core/_concurrent.py index 96495692..d08f17b9 100644 --- a/src/asphalt/core/_concurrent.py +++ b/src/asphalt/core/_concurrent.py @@ -24,6 +24,9 @@ else: from typing_extensions import TypeAlias +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + T_Retval = TypeVar("T_Retval") TeardownAction: TypeAlias = Union[Callable[[], Any], Literal["cancel"], None] ExceptionHandler: TypeAlias = Callable[[Exception], bool] @@ -99,6 +102,7 @@ class TaskFactory: exception_handler: ExceptionHandler | None = None _finished_event: Event = field(init=False, default_factory=Event) _task_group: TaskGroup = field(init=False) + _sub_task_group: TaskGroup = field(init=False) async def start_task( self, @@ -128,7 +132,7 @@ async def start_task( """ task_handle = TaskHandle(name=name or callable_name(func)) - task_handle.start_value = await self._task_group.start( + task_handle.start_value = await self._sub_task_group.start( _run_background_task, func, task_handle, @@ -155,7 +159,7 @@ def start_task_soon( """ task_handle = TaskHandle(name=name or callable_name(func)) - self._task_group.start_soon( + self._sub_task_group.start_soon( _run_background_task, func, task_handle, @@ -164,14 +168,38 @@ def start_task_soon( ) return task_handle + def cancel_all_tasks(self) -> None: + """Schedule all currently running tasks to be cancelled.""" + self._sub_task_group.cancel_scope.cancel() + self._task_group.start_soon(self._start) + + async def wait_all_tasks_finished(self) -> None: + """Wait until all currently running tasks are finished.""" + self._finished_event.set() + self._finished_event = Event() + await self._task_group.start(self._start) + async def _run( self, ctx: Context, resource_name: str, *, task_status: TaskStatus[None] ) -> None: - async with create_task_group() as self._task_group: - ctx.add_resource(self, resource_name) + try: + async with create_task_group() as self._task_group: + ctx.add_resource(self, resource_name) + task_status.started() + await self._start() + except ExceptionGroup as excgrp: + raise excgrp.exceptions[0] + + async def _start( + self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED + ) -> None: + async with create_task_group() as self._sub_task_group: task_status.started() await self._finished_event.wait() + def _stop(self) -> None: + self._finished_event.set() + async def start_service_task( func: Callable[..., Coroutine[Any, Any, T_Retval]], @@ -285,6 +313,6 @@ async def start_background_task_factory( await start_service_task( partial(factory._run, current_context(), resource_name), f"Background task factory ({resource_name})", - teardown_action=factory._finished_event.set, + teardown_action=factory._stop, ) return factory diff --git a/tests/test_concurrent.py b/tests/test_concurrent.py index 30e5ceac..b9c5a37d 100644 --- a/tests/test_concurrent.py +++ b/tests/test_concurrent.py @@ -149,6 +149,36 @@ async def taskfunc() -> str: assert handle.name == expected_name + async def test_cancel_all_tasks(self) -> None: + async def taskfunc(task_status: TaskStatus[None]) -> None: + task_status.started() + await sleep(1) + raise RuntimeError("this exception should not be raised") + + async with Context(): + factory = await start_background_task_factory() + await factory.start_task(taskfunc) + await factory.start_task(taskfunc) + factory.cancel_all_tasks() + + async def test_wait_all_tasks_finished(self) -> None: + return_values = [] + + async def taskfunc(task_status: TaskStatus[None]) -> None: + task_status.started() + return_values.append("returnvalue") + + async with Context(): + factory = await start_background_task_factory() + await factory.start_task(taskfunc) + await factory.start_task(taskfunc) + await factory.wait_all_tasks_finished() + assert return_values == 2 * ["returnvalue"] + return_values.clear() + await factory.start_task(taskfunc) + await factory.wait_all_tasks_finished() + assert return_values == ["returnvalue"] + class TestServiceTask: async def test_bad_teardown_action(self, caplog: LogCaptureFixture) -> None: