Skip to content

Commit

Permalink
Added TaskFactory cancel_all_tasks() and wait_all_tasks_finished()
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Apr 19, 2024
1 parent 21652de commit a0e217b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
26 changes: 23 additions & 3 deletions src/asphalt/core/_concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class TaskFactory:
exception_handler: ExceptionHandler | None = None
_finished_event: Event = field(init=False, default_factory=Event)
_task_group: TaskGroup = field(init=False)
_sub_task_group: TaskGroup = field(init=False)

async def start_task(
self,
Expand Down Expand Up @@ -128,7 +129,7 @@ async def start_task(
"""
task_handle = TaskHandle(name=name or callable_name(func))
task_handle.start_value = await self._task_group.start(
task_handle.start_value = await self._sub_task_group.start(
_run_background_task,
func,
task_handle,
Expand All @@ -155,7 +156,7 @@ def start_task_soon(
"""
task_handle = TaskHandle(name=name or callable_name(func))
self._task_group.start_soon(
self._sub_task_group.start_soon(
_run_background_task,
func,
task_handle,
Expand All @@ -164,14 +165,33 @@ def start_task_soon(
)
return task_handle

def cancel_all_tasks(self) -> None:
"""Schedule all currently running tasks to be cancelled."""
self._sub_task_group.cancel_scope.cancel()
self._task_group.start_soon(self._start)

async def wait_all_tasks_finished(self) -> None:
"""Wait until all currently running tasks are finished."""
self._finished_event.set()
self._finished_event = Event()
await self._task_group.start(self._start)

async def _run(
self, ctx: Context, resource_name: str, *, task_status: TaskStatus[None]
) -> None:
async with create_task_group() as self._task_group:
ctx.add_resource(self, resource_name)
task_status.started()
await self._start()

async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
async with create_task_group() as self._sub_task_group:
task_status.started()
await self._finished_event.wait()

def _stop(self) -> None:
self._finished_event.set()


async def start_service_task(
func: Callable[..., Coroutine[Any, Any, T_Retval]],
Expand Down Expand Up @@ -285,6 +305,6 @@ async def start_background_task_factory(
await start_service_task(
partial(factory._run, current_context(), resource_name),
f"Background task factory ({resource_name})",
teardown_action=factory._finished_event.set,
teardown_action=factory._stop,
)
return factory
32 changes: 31 additions & 1 deletion tests/test_concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def taskfunc() -> NoReturn:
assert isinstance(excinfo.value.exceptions[0], ExceptionGroup)
excgrp = excinfo.value.exceptions[0]
assert len(excgrp.exceptions) == 1
assert str(excgrp.exceptions[0]) == "foo"
assert str(excgrp.exceptions[0].exceptions[0]) == "foo"

async def test_start_exception_handled(self) -> None:
handled_exception: Exception | None = None
Expand Down Expand Up @@ -149,6 +149,36 @@ async def taskfunc() -> str:

assert handle.name == expected_name

async def test_cancel_all_tasks(self) -> None:
async def taskfunc(task_status: TaskStatus[None]) -> None:
task_status.started()
await sleep(1)
raise RuntimeError("this exception should not be raised")

async with Context():
factory = await start_background_task_factory()
await factory.start_task(taskfunc)
await factory.start_task(taskfunc)
factory.cancel_all_tasks()

async def test_wait_all_tasks_finished(self) -> None:
return_values = []

async def taskfunc(task_status: TaskStatus[None]) -> None:
task_status.started()
return_values.append("returnvalue")

async with Context():
factory = await start_background_task_factory()
await factory.start_task(taskfunc)
await factory.start_task(taskfunc)
await factory.wait_all_tasks_finished()
assert return_values == 2 * ["returnvalue"]
return_values.clear()
await factory.start_task(taskfunc)
await factory.wait_all_tasks_finished()
assert return_values == ["returnvalue"]


class TestServiceTask:
async def test_bad_teardown_action(self, caplog: LogCaptureFixture) -> None:
Expand Down

0 comments on commit a0e217b

Please sign in to comment.