Skip to content

Commit

Permalink
Start background tasks through Component.task_group, not context reso…
Browse files Browse the repository at this point in the history
…urce
  • Loading branch information
davidbrochart committed Jan 21, 2024
1 parent 5e3b3f0 commit ae82177
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 115 deletions.
1 change: 0 additions & 1 deletion src/asphalt/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
Component,
ContainerComponent,
)
from ._concurrent import start_background_task, start_service_task
from ._context import (
Context,
NoCurrentContext,
Expand Down
44 changes: 38 additions & 6 deletions src/asphalt/core/_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from contextlib import AsyncExitStack
from typing import Any
from warnings import warn

from anyio import create_task_group
from anyio.abc import TaskGroup
from anyio.lowlevel import cancel_shielded_checkpoint

from ._concurrent import start_service_task
from ._context import Context
from ._exceptions import ApplicationExit
from ._utils import PluginContainer, merge_config, qualified_name
Expand All @@ -17,7 +18,25 @@
class Component(metaclass=ABCMeta):
"""This is the base class for all Asphalt components."""

__slots__ = ()
_task_group: TaskGroup | None = None

async def __aenter__(self) -> Component:
if self._task_group is not None:
raise RuntimeError("Component already entered")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
if self._task_group is None:
raise RuntimeError("Component not entered")

self._task_group = None
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

@abstractmethod
async def start(self, ctx: Context) -> None:
Expand All @@ -38,6 +57,15 @@ async def start(self, ctx: Context) -> None:
:param ctx: the containing context for this component
"""

@property
def task_group(self) -> TaskGroup:
if self._task_group is None:
raise RuntimeError(
"Component has no task group, did you forget to use: async with component ?"
)
else:
return self._task_group


class ContainerComponent(Component):
"""
Expand Down Expand Up @@ -111,9 +139,12 @@ async def start(self, ctx: Context) -> None:
if alias not in self.child_components:
self.add_component(alias)

async with create_task_group() as tg:
for alias, component in self.child_components.items():
tg.start_soon(component.start, ctx)
async def start_child_components():
for component in self.child_components.values():
component._task_group = self._task_group
self._task_group.start_soon(component.start, ctx)

await start_child_components()


class CLIApplicationComponent(ContainerComponent):
Expand Down Expand Up @@ -156,7 +187,8 @@ async def run() -> None:
raise ApplicationExit

await super().start(ctx)
start_service_task(run, "Main task")
assert self._task_group is not None
self._task_group.start_soon(run)

@abstractmethod
async def run(self) -> int | None:
Expand Down
103 changes: 0 additions & 103 deletions src/asphalt/core/_concurrent.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/asphalt/core/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
sleep,
to_thread,
)
from anyio.abc import TaskGroup, TaskStatus
from anyio.abc import TaskStatus
from exceptiongroup import catch

from ._component import Component, component_types
Expand Down Expand Up @@ -112,7 +112,7 @@ async def run_application(
exit_stack.enter_context(catch(handlers)) # type: ignore[arg-type]
root_tg = await exit_stack.enter_async_context(create_task_group())
ctx = await exit_stack.enter_async_context(Context())
await ctx.add_resource(root_tg, "root_taskgroup", [TaskGroup])
component._task_group = root_tg
if platform.system() != "Windows":
await root_tg.start(handle_signals, name="Asphalt signal handler")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_add_duplicate_component(self, container) -> None:
assert str(exc.value) == 'there is already a child component named "dummy"'

async def test_start(self, container) -> None:
async with Context() as ctx:
async with Context() as ctx, container:
await container.start(ctx)

assert container.child_components["dummy"].started
Expand Down
3 changes: 1 addition & 2 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Component,
Context,
run_application,
start_service_task,
)

pytestmark = pytest.mark.anyio()
Expand Down Expand Up @@ -52,7 +51,7 @@ async def start(self, ctx: Context) -> None:
if self.method == "timeout":
await anyio.sleep(1)
else:
start_service_task(self.stop_app, name="Application terminator")
self.task_group.start_soon(self.stop_app)


class CrashComponent(Component):
Expand Down

0 comments on commit ae82177

Please sign in to comment.