From b2d16d3d28b8dee6a6b1a3f4a159e93e15f82034 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 19 Apr 2024 09:25:01 +0200 Subject: [PATCH 1/2] Added TaskFactory cancel_all_tasks() and wait_all_tasks_finished() --- src/asphalt/core/_concurrent.py | 28 ++++++++++++++++++++--- tests/test_concurrent.py | 40 ++++++++++++++++++++++++++++++--- 2 files changed, 62 insertions(+), 6 deletions(-) diff --git a/src/asphalt/core/_concurrent.py b/src/asphalt/core/_concurrent.py index 96495692..d4172704 100644 --- a/src/asphalt/core/_concurrent.py +++ b/src/asphalt/core/_concurrent.py @@ -99,6 +99,7 @@ class TaskFactory: exception_handler: ExceptionHandler | None = None _finished_event: Event = field(init=False, default_factory=Event) _task_group: TaskGroup = field(init=False) + _sub_task_group: TaskGroup = field(init=False) async def start_task( self, @@ -128,7 +129,7 @@ async def start_task( """ task_handle = TaskHandle(name=name or callable_name(func)) - task_handle.start_value = await self._task_group.start( + task_handle.start_value = await self._sub_task_group.start( _run_background_task, func, task_handle, @@ -155,7 +156,7 @@ def start_task_soon( """ task_handle = TaskHandle(name=name or callable_name(func)) - self._task_group.start_soon( + self._sub_task_group.start_soon( _run_background_task, func, task_handle, @@ -164,14 +165,35 @@ def start_task_soon( ) return task_handle + def cancel_all_tasks(self) -> None: + """Schedule all currently running tasks to be cancelled.""" + self._sub_task_group.cancel_scope.cancel() + self._task_group.start_soon(self._start) + + async def wait_all_tasks_finished(self) -> None: + """Wait until all currently running tasks are finished.""" + self._finished_event.set() + self._finished_event = Event() + await self._task_group.start(self._start) + async def _run( self, ctx: Context, resource_name: str, *, task_status: TaskStatus[None] ) -> None: async with create_task_group() as self._task_group: ctx.add_resource(self, resource_name) + task_status.started() + await self._start() + + async def _start( + self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED + ) -> None: + async with create_task_group() as self._sub_task_group: task_status.started() await self._finished_event.wait() + def _stop(self) -> None: + self._finished_event.set() + async def start_service_task( func: Callable[..., Coroutine[Any, Any, T_Retval]], @@ -285,6 +307,6 @@ async def start_background_task_factory( await start_service_task( partial(factory._run, current_context(), resource_name), f"Background task factory ({resource_name})", - teardown_action=factory._finished_event.set, + teardown_action=factory._stop, ) return factory diff --git a/tests/test_concurrent.py b/tests/test_concurrent.py index 30e5ceac..d24902db 100644 --- a/tests/test_concurrent.py +++ b/tests/test_concurrent.py @@ -108,9 +108,13 @@ async def taskfunc() -> NoReturn: assert len(excinfo.value.exceptions) == 1 assert isinstance(excinfo.value.exceptions[0], ExceptionGroup) - excgrp = excinfo.value.exceptions[0] - assert len(excgrp.exceptions) == 1 - assert str(excgrp.exceptions[0]) == "foo" + excgrp0 = excinfo.value.exceptions[0] + assert len(excgrp0.exceptions) == 1 + assert isinstance(excgrp0, ExceptionGroup) + excgrp1 = excgrp0.exceptions[0] + assert isinstance(excgrp1, ExceptionGroup) + assert len(excgrp1.exceptions) == 1 + assert str(excgrp1.exceptions[0]) == "foo" async def test_start_exception_handled(self) -> None: handled_exception: Exception | None = None @@ -149,6 +153,36 @@ async def taskfunc() -> str: assert handle.name == expected_name + async def test_cancel_all_tasks(self) -> None: + async def taskfunc(task_status: TaskStatus[None]) -> None: + task_status.started() + await sleep(1) + raise RuntimeError("this exception should not be raised") + + async with Context(): + factory = await start_background_task_factory() + await factory.start_task(taskfunc) + await factory.start_task(taskfunc) + factory.cancel_all_tasks() + + async def test_wait_all_tasks_finished(self) -> None: + return_values = [] + + async def taskfunc(task_status: TaskStatus[None]) -> None: + task_status.started() + return_values.append("returnvalue") + + async with Context(): + factory = await start_background_task_factory() + await factory.start_task(taskfunc) + await factory.start_task(taskfunc) + await factory.wait_all_tasks_finished() + assert return_values == 2 * ["returnvalue"] + return_values.clear() + await factory.start_task(taskfunc) + await factory.wait_all_tasks_finished() + assert return_values == ["returnvalue"] + class TestServiceTask: async def test_bad_teardown_action(self, caplog: LogCaptureFixture) -> None: From b02717f302ca18813aad726150440be1244e60f8 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 19 Apr 2024 11:26:32 +0200 Subject: [PATCH 2/2] Raise exception from sub-task group --- src/asphalt/core/_concurrent.py | 14 ++++++++++---- tests/test_concurrent.py | 10 +++------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/asphalt/core/_concurrent.py b/src/asphalt/core/_concurrent.py index d4172704..d08f17b9 100644 --- a/src/asphalt/core/_concurrent.py +++ b/src/asphalt/core/_concurrent.py @@ -24,6 +24,9 @@ else: from typing_extensions import TypeAlias +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + T_Retval = TypeVar("T_Retval") TeardownAction: TypeAlias = Union[Callable[[], Any], Literal["cancel"], None] ExceptionHandler: TypeAlias = Callable[[Exception], bool] @@ -179,10 +182,13 @@ async def wait_all_tasks_finished(self) -> None: async def _run( self, ctx: Context, resource_name: str, *, task_status: TaskStatus[None] ) -> None: - async with create_task_group() as self._task_group: - ctx.add_resource(self, resource_name) - task_status.started() - await self._start() + try: + async with create_task_group() as self._task_group: + ctx.add_resource(self, resource_name) + task_status.started() + await self._start() + except ExceptionGroup as excgrp: + raise excgrp.exceptions[0] async def _start( self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED diff --git a/tests/test_concurrent.py b/tests/test_concurrent.py index d24902db..b9c5a37d 100644 --- a/tests/test_concurrent.py +++ b/tests/test_concurrent.py @@ -108,13 +108,9 @@ async def taskfunc() -> NoReturn: assert len(excinfo.value.exceptions) == 1 assert isinstance(excinfo.value.exceptions[0], ExceptionGroup) - excgrp0 = excinfo.value.exceptions[0] - assert len(excgrp0.exceptions) == 1 - assert isinstance(excgrp0, ExceptionGroup) - excgrp1 = excgrp0.exceptions[0] - assert isinstance(excgrp1, ExceptionGroup) - assert len(excgrp1.exceptions) == 1 - assert str(excgrp1.exceptions[0]) == "foo" + excgrp = excinfo.value.exceptions[0] + assert len(excgrp.exceptions) == 1 + assert str(excgrp.exceptions[0]) == "foo" async def test_start_exception_handled(self) -> None: handled_exception: Exception | None = None