Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support AnyIO #101

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"asyncio_extras ~= 1.3",
"async_timeout >= 2.0",
"click >= 6.6",
"anyio >=4.1.0,<5",
]
dynamic = ["version"]

Expand All @@ -53,6 +54,7 @@ test = [
"pytest >= 3.9",
"pytest-asyncio",
"uvloop; python_version < '3.12' and python_implementation == 'CPython' and platform_system != 'Windows'",
"trio >=0.24.0",
]
doc = [
"Sphinx >= 7.0",
Expand Down
4 changes: 3 additions & 1 deletion src/asphalt/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import os
import re
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import Any

import anyio
import click
from ruamel.yaml import YAML, ScalarNode
from ruamel.yaml.loader import Loader
Expand Down Expand Up @@ -140,4 +142,4 @@ def run(
config = merge_config(config, service_config)

# Start the application
run_application(**config)
anyio.run(partial(run_application, **config))
73 changes: 51 additions & 22 deletions src/asphalt/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,42 @@

__all__ = ("Component", "ContainerComponent", "CLIApplicationComponent")

import asyncio
import sys
from abc import ABCMeta, abstractmethod
from asyncio import Future
from collections import OrderedDict
from contextlib import AsyncExitStack
from traceback import print_exception
from typing import Any
from warnings import warn

from anyio import create_memory_object_stream, create_task_group
from anyio.abc import TaskGroup

from .context import Context
from .utils import PluginContainer, merge_config, qualified_name


class Component(metaclass=ABCMeta):
"""This is the base class for all Asphalt components."""

__slots__ = ()
_task_group = None

async def __aenter__(self) -> Component:
if self._task_group is not None:
raise RuntimeError("Component already entered")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
if self._task_group is None:
raise RuntimeError("Component not entered")

self._task_group = None
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

@abstractmethod
async def start(self, ctx: Context) -> None:
Expand All @@ -39,6 +58,10 @@ async def start(self, ctx: Context) -> None:
:param ctx: the containing context for this component
"""

@property
def task_group(self) -> TaskGroup:
return self._task_group


class ContainerComponent(Component):
"""
Expand Down Expand Up @@ -107,9 +130,13 @@ async def start(self, ctx: Context) -> None:
if alias not in self.child_components:
self.add_component(alias)

tasks = [component.start(ctx) for component in self.child_components.values()]
if tasks:
await asyncio.gather(*tasks)
async def start_child_components():
async with create_task_group() as tg:
for component in self.child_components.values():
component._task_group = self._task_group
tg.start_soon(component.start, ctx)

await start_child_components()


class CLIApplicationComponent(ContainerComponent):
Expand All @@ -128,35 +155,37 @@ class CLIApplicationComponent(ContainerComponent):
"""

async def start(self, ctx: Context) -> None:
def run_complete(f: Future[int | None]) -> None:
# If run() raised an exception, print it with a traceback and exit with code 1
exc = f.exception()
if exc is not None:
await super().start(ctx)

async def run(exit_code):
try:
retval = await self.run(ctx)
except Exception as exc:
print_exception(type(exc), exc, exc.__traceback__)
sys.exit(1)
exit_code.send_nowait(1)
return

retval = f.result()
if isinstance(retval, int):
if 0 <= retval <= 127:
sys.exit(retval)
exit_code.send_nowait(retval)
else:
warn("exit code out of range: %d" % retval)
sys.exit(1)
exit_code.send_nowait(1)
elif retval is not None:
warn(
"run() must return an integer or None, not %s"
% qualified_name(retval.__class__)
)
sys.exit(1)
exit_code.send_nowait(1)
else:
sys.exit(0)
exit_code.send_nowait(0)

def start_run_task() -> None:
task = ctx.loop.create_task(self.run(ctx))
task.add_done_callback(run_complete)
send_stream, receive_stream = create_memory_object_stream[int](max_buffer_size=1)
self._exit_code = receive_stream
self.task_group.start_soon(run, send_stream)

await super().start(ctx)
ctx.loop.call_later(0.1, start_run_task)
async def exit_code(self) -> int:
return await self._exit_code.receive()

@abstractmethod
async def run(self, ctx: Context) -> int | None:
Expand Down
96 changes: 33 additions & 63 deletions src/asphalt/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
__all__ = ("run_application",)

import asyncio
import signal
import sys
from asyncio.events import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from logging import INFO, Logger, basicConfig, getLogger, shutdown
from logging.config import dictConfig
from traceback import print_exception
from typing import Any, cast

from .component import Component, component_types
from anyio import create_task_group, fail_after

from .component import CLIApplicationComponent, Component, component_types
from .context import Context, _current_context
from .utils import PluginContainer, qualified_name

Expand All @@ -24,11 +25,10 @@ def sigterm_handler(logger: Logger, event_loop: AbstractEventLoop) -> None:
event_loop.stop()


def run_application(
async def run_application(
component: Component | dict[str, Any],
*,
event_loop_policy: str | None = None,
max_threads: int | None = None,
logging: dict[str, Any] | int | None = INFO,
start_timeout: int | float | None = 10,
) -> None:
Expand All @@ -48,18 +48,11 @@ def run_application(
By default, the logging system is initialized using :func:`~logging.basicConfig` using the
``INFO`` logging level.

The default executor in the event loop is replaced with a new
:class:`~concurrent.futures.ThreadPoolExecutor` where the maximum number of threads is set to
the value of ``max_threads`` or, if omitted, the default value of
:class:`~concurrent.futures.ThreadPoolExecutor`.

:param component: the root component (either a component instance or a configuration dictionary
where the special ``type`` key is either a component class or a ``module:varname``
reference to one)
:param event_loop_policy: entry point name (from the ``asphalt.core.event_loop_policies``
namespace) of an alternate event loop policy (or a module:varname reference to one)
:param max_threads: the maximum number of worker threads in the default thread pool executor
(the default value depends on the event loop implementation)
:param logging: a logging configuration dictionary, :ref:`logging level <python:levels>` or
``None``
:param start_timeout: seconds to wait for the root component (and its subcomponents) to start
Expand All @@ -83,30 +76,24 @@ def run_application(
asyncio.set_event_loop_policy(policy)
logger.info("Switched event loop policy to %s", qualified_name(policy))

# Assign a new default executor with the given max worker thread limit if one was provided
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
# Instantiate the root component if a dict was given
if isinstance(component, dict):
component = cast(Component, component_types.create_object(**component))

logger.info("Starting application")
context = Context()
exception: BaseException | None = None
exit_code = 0

# Start the root component
token = _current_context.set(context)
try:
if max_threads is not None:
event_loop.set_default_executor(ThreadPoolExecutor(max_threads))
logger.info("Installed a new thread pool executor with max_workers=%d", max_threads)

# Instantiate the root component if a dict was given
if isinstance(component, dict):
component = cast(Component, component_types.create_object(**component))

logger.info("Starting application")
context = Context()
exception: BaseException | None = None
exit_code: str | int = 0

# Start the root component
token = _current_context.set(context)
try:
async with create_task_group() as tg:
component._task_group = tg
try:
coro = asyncio.wait_for(component.start(context), start_timeout)
event_loop.run_until_complete(coro)
except asyncio.TimeoutError as e:
with fail_after(start_timeout):
await component.start(context)
except TimeoutError as e:
exception = e
logger.error("Timeout waiting for the root component to start")
exit_code = 1
Expand All @@ -116,40 +103,23 @@ def run_application(
exit_code = 1
else:
logger.info("Application started")

# Add a signal handler to gracefully deal with SIGTERM
try:
event_loop.add_signal_handler(
signal.SIGTERM, sigterm_handler, logger, event_loop
)
except NotImplementedError:
pass # Windows does not support signals very well

# Finally, run the event loop until the process is terminated or Ctrl+C
# is pressed
try:
event_loop.run_forever()
except KeyboardInterrupt:
pass
except SystemExit as e:
exit_code = e.code or 0

# Close the root context
logger.info("Stopping application")
event_loop.run_until_complete(context.close(exception))
finally:
_current_context.reset(token)

# Shut down leftover async generators
event_loop.run_until_complete(event_loop.shutdown_asyncgens())
if isinstance(component, CLIApplicationComponent):
exit_code = await component._exit_code.receive()
except Exception as e:
exception = e
exit_code = 1
finally:
# Finally, close the event loop itself
event_loop.close()
asyncio.set_event_loop(None)
# Close the root context
logger.info("Stopping application")
await context.close(exception)
_current_context.reset(token)
logger.info("Application stopped")

# Shut down the logging system
shutdown()

if exception is not None:
print_exception(type(exception), exception, exception.__traceback__)

if exit_code:
sys.exit(exit_code)
Loading