Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 18, 2024
1 parent 287ef59 commit 427d0b4
Show file tree
Hide file tree
Showing 8 changed files with 313 additions and 203 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ test = [
"pytest >= 3.9",
"pytest-asyncio",
"uvloop; python_version < '3.12' and python_implementation == 'CPython' and platform_system != 'Windows'",
"trio >=0.24.0",
]
doc = [
"Sphinx >= 7.0",
Expand Down
4 changes: 3 additions & 1 deletion src/asphalt/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import os
import re
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import Any

import anyio
import click
from ruamel.yaml import YAML, ScalarNode
from ruamel.yaml.loader import Loader
Expand Down Expand Up @@ -140,4 +142,4 @@ def run(
config = merge_config(config, service_config)

# Start the application
run_application(**config)
anyio.run(partial(run_application, **config))
58 changes: 38 additions & 20 deletions src/asphalt/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

__all__ = ("Component", "ContainerComponent", "CLIApplicationComponent")

import sys
from abc import ABCMeta, abstractmethod
from asyncio import Future
from collections import OrderedDict
from contextlib import AsyncExitStack
from traceback import print_exception
from typing import Any
from warnings import warn

from anyio import create_task_group
from anyio import create_memory_object_stream, create_task_group
from anyio.abc import TaskGroup

from .context import Context
Expand All @@ -20,8 +19,25 @@
class Component(metaclass=ABCMeta):
"""This is the base class for all Asphalt components."""

__slots__ = ()
_task_group: TaskGroup
_task_group = None

async def __aenter__(self) -> Component:
if self._task_group is not None:
raise RuntimeError("Component already entered")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
if self._task_group is None:
raise RuntimeError("Component not entered")

self._task_group = None
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

@abstractmethod
async def start(self, ctx: Context) -> None:
Expand Down Expand Up @@ -139,35 +155,37 @@ class CLIApplicationComponent(ContainerComponent):
"""

async def start(self, ctx: Context) -> None:
def run_complete(f: Future[int | None]) -> None:
# If run() raised an exception, print it with a traceback and exit with code 1
exc = f.exception()
if exc is not None:
await super().start(ctx)

async def run(exit_code):
try:
retval = await self.run(ctx)
except Exception as exc:
print_exception(type(exc), exc, exc.__traceback__)
sys.exit(1)
exit_code.send_nowait(1)
return

retval = f.result()
if isinstance(retval, int):
if 0 <= retval <= 127:
sys.exit(retval)
exit_code.send_nowait(retval)
else:
warn("exit code out of range: %d" % retval)
sys.exit(1)
exit_code.send_nowait(1)
elif retval is not None:
warn(
"run() must return an integer or None, not %s"
% qualified_name(retval.__class__)
)
sys.exit(1)
exit_code.send_nowait(1)
else:
sys.exit(0)
exit_code.send_nowait(0)

def start_run_task() -> None:
task = ctx.loop.create_task(self.run(ctx))
task.add_done_callback(run_complete)
send_stream, receive_stream = create_memory_object_stream[int](max_buffer_size=1)
self._exit_code = receive_stream
self.task_group.start_soon(run, send_stream)

await super().start(ctx)
ctx.loop.call_later(0.1, start_run_task)
async def exit_code(self) -> int:
return await self._exit_code.receive()

@abstractmethod
async def run(self, ctx: Context) -> int | None:
Expand Down
32 changes: 26 additions & 6 deletions src/asphalt/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
__all__ = ("run_application",)

import asyncio
import sys
from asyncio.events import AbstractEventLoop
from logging import INFO, Logger, basicConfig, getLogger, shutdown
from logging.config import dictConfig
from traceback import print_exception
from typing import Any, cast

from anyio import create_task_group, fail_after

from .component import Component, component_types
from .component import CLIApplicationComponent, Component, component_types
from .context import Context, _current_context
from .utils import PluginContainer, qualified_name

Expand Down Expand Up @@ -81,17 +83,31 @@ async def run_application(
logger.info("Starting application")
context = Context()
exception: BaseException | None = None
exit_code = 0

# Start the root component
token = _current_context.set(context)
try:
async with create_task_group() as tg:
component._task_group = tg
with fail_after(start_timeout) as scope:
await component.start(context)
logger.info("Application started")
try:
with fail_after(start_timeout):
await component.start(context)
except TimeoutError as e:
exception = e
logger.error("Timeout waiting for the root component to start")
exit_code = 1
except Exception as e:
exception = e
logger.exception("Error during application startup")
exit_code = 1
else:
logger.info("Application started")
if isinstance(component, CLIApplicationComponent):
exit_code = await component._exit_code.receive()
except Exception as e:
exception = e
exit_code = 1
finally:
# Close the root context
logger.info("Stopping application")
Expand All @@ -102,5 +118,9 @@ async def run_application(
# Shut down the logging system
shutdown()

if exception:
raise exception
if exception is not None:
print_exception(type(exception), exception, exception.__traceback__)

print(exit_code)
if exit_code:
sys.exit(exit_code)
58 changes: 32 additions & 26 deletions tests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,70 +112,76 @@ async def test_start(self, container) -> None:


class TestCLIApplicationComponent:
def test_run_return_none(self, event_loop: AbstractEventLoop) -> None:
@pytest.mark.anyio
async def test_run_return_none(self) -> None:
class DummyCLIComponent(CLIApplicationComponent):
async def run(self, ctx: Context) -> None:
pass

component = DummyCLIComponent()
event_loop.run_until_complete(component.start(Context()))
exc = pytest.raises(SystemExit, event_loop.run_forever)
assert exc.value.code == 0
async with component:
await component.start(Context())
assert await component.exit_code() == 0

def test_run_return_5(self, event_loop: AbstractEventLoop) -> None:
@pytest.mark.anyio
async def test_run_return_5(self) -> None:
class DummyCLIComponent(CLIApplicationComponent):
async def run(self, ctx: Context) -> int:
return 5

component = DummyCLIComponent()
event_loop.run_until_complete(component.start(Context()))
exc = pytest.raises(SystemExit, event_loop.run_forever)
assert exc.value.code == 5
async with component:
await component.start(Context())
assert await component.exit_code() == 5

def test_run_return_invalid_value(self, event_loop: AbstractEventLoop) -> None:
@pytest.mark.anyio
async def test_run_return_invalid_value(self) -> None:
class DummyCLIComponent(CLIApplicationComponent):
async def run(self, ctx: Context) -> int:
return 128

component = DummyCLIComponent()
event_loop.run_until_complete(component.start(Context()))
with pytest.warns(UserWarning) as record:
exc = pytest.raises(SystemExit, event_loop.run_forever)
async with component:
with pytest.warns(UserWarning) as record:
await component.start(Context())
assert await component.exit_code() == 1

assert exc.value.code == 1
assert len(record) == 1
assert str(record[0].message) == "exit code out of range: 128"
assert len(record) >= 1
assert str(record[-1].message) == "exit code out of range: 128"

def test_run_return_invalid_type(self, event_loop: AbstractEventLoop) -> None:
@pytest.mark.anyio
async def test_run_return_invalid_type(self) -> None:
class DummyCLIComponent(CLIApplicationComponent):
async def run(self, ctx: Context) -> int:
return "foo" # type: ignore[return-value]

component = DummyCLIComponent()
event_loop.run_until_complete(component.start(Context()))
with pytest.warns(UserWarning) as record:
exc = pytest.raises(SystemExit, event_loop.run_forever)
async with component:
with pytest.warns(UserWarning) as record:
await component.start(Context())
assert await component.exit_code() == 1

assert exc.value.code == 1
assert len(record) == 1
assert str(record[0].message) == "run() must return an integer or None, not str"

def test_run_exception(self, event_loop: AbstractEventLoop) -> None:
@pytest.mark.anyio
async def test_run_exception(self, event_loop: AbstractEventLoop) -> None:
class DummyCLIComponent(CLIApplicationComponent):
async def run(self, ctx: Context) -> NoReturn:
raise Exception("blah")

component = DummyCLIComponent()
event_loop.run_until_complete(component.start(Context()))
exc = pytest.raises(SystemExit, event_loop.run_forever)
assert exc.value.code == 1
async with component:
await component.start(Context())
assert await component.exit_code() == 1

def test_add_teardown_callback(self) -> None:
@pytest.mark.anyio
async def test_add_teardown_callback(self) -> None:
async def callback() -> None:
current_context()

class DummyCLIComponent(CLIApplicationComponent):
async def run(self, ctx: Context) -> None:
ctx.add_teardown_callback(callback)

run_application(DummyCLIComponent())
await run_application(DummyCLIComponent())
12 changes: 8 additions & 4 deletions tests/test_concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ async def special_executor(context: Context) -> ThreadPoolExecutor:


@pytest.mark.parametrize("use_resource_name", [False, True], ids=["instance", "resource_name"])
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
async def test_executor_special(
context: Context, use_resource_name: bool, special_executor: ThreadPoolExecutor
) -> None:
Expand All @@ -38,7 +39,8 @@ def check_thread(ctx: Context) -> None:
await check_thread(context)


@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
async def test_executor_default(event_loop: AbstractEventLoop, context: Context) -> None:
@executor
def check_thread(ctx: Context) -> None:
Expand All @@ -49,7 +51,8 @@ def check_thread(ctx: Context) -> None:
await check_thread(context)


@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
async def test_executor_worker_thread(
event_loop: AbstractEventLoop,
context: Context,
Expand All @@ -73,7 +76,8 @@ def runs_in_default_worker(ctx: Context) -> str:
assert retval == "foo"


@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
async def test_executor_missing_context(event_loop: AbstractEventLoop, context: Context) -> None:
@executor("special")
def runs_in_default_worker() -> None:
Expand Down
Loading

0 comments on commit 427d0b4

Please sign in to comment.