diff --git a/src/hypercorn/asyncio/lifespan.py b/src/hypercorn/asyncio/lifespan.py index eaef9068..bd22c8ff 100644 --- a/src/hypercorn/asyncio/lifespan.py +++ b/src/hypercorn/asyncio/lifespan.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import sys from functools import partial from typing import Any, Callable @@ -8,6 +9,9 @@ from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState from ..utils import LifespanFailureError, LifespanTimeoutError +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + class UnexpectedMessageError(Exception): pass @@ -58,7 +62,13 @@ def _call_soon(func: Callable, *args: Any) -> Any: except LifespanFailureError: # Lifespan failures should crash the server raise - except Exception: + except (BaseExceptionGroup, Exception) as error: + if isinstance(error, BaseExceptionGroup): + failure_error = error.subgroup(LifespanFailureError) + if failure_error is not None: + # Lifespan failures should crash the server + raise failure_error + self.supported = False if not self.startup.is_set(): await self.config.log.warning( diff --git a/src/hypercorn/trio/lifespan.py b/src/hypercorn/trio/lifespan.py index 21f4dd26..cd809845 100644 --- a/src/hypercorn/trio/lifespan.py +++ b/src/hypercorn/trio/lifespan.py @@ -1,11 +1,16 @@ from __future__ import annotations +import sys + import trio from ..config import Config from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState from ..utils import LifespanFailureError, LifespanTimeoutError +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + class UnexpectedMessageError(Exception): pass @@ -43,7 +48,13 @@ async def handle_lifespan( except LifespanFailureError: # Lifespan failures should crash the server raise - except Exception: + except (BaseExceptionGroup, Exception) as error: + if isinstance(error, BaseExceptionGroup): + failure_error = error.subgroup(LifespanFailureError) + if failure_error is not None: + # Lifespan failures should crash the server + raise failure_error + self.supported = False if not self.startup.is_set(): await self.config.log.warning( diff --git a/tests/asyncio/test_lifespan.py b/tests/asyncio/test_lifespan.py index bf2cfc6d..e79d173b 100644 --- a/tests/asyncio/test_lifespan.py +++ b/tests/asyncio/test_lifespan.py @@ -9,9 +9,14 @@ from hypercorn.app_wrappers import ASGIWrapper from hypercorn.asyncio.lifespan import Lifespan from hypercorn.config import Config -from hypercorn.typing import Scope +from hypercorn.typing import ASGIReceiveCallable, ASGISendCallable, Scope from hypercorn.utils import LifespanFailureError, LifespanTimeoutError -from ..helpers import lifespan_failure, SlowLifespanFramework +from ..helpers import SlowLifespanFramework + +try: + from asyncio import TaskGroup +except ImportError: + from taskgroup import TaskGroup # type: ignore async def no_lifespan_app(scope: Scope, receive: Callable, send: Callable) -> None: @@ -47,17 +52,27 @@ async def test_startup_timeout_error() -> None: await task +async def _lifespan_failure( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable +) -> None: + async with TaskGroup(): + while True: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.failed", "message": "Failure"}) + break + + @pytest.mark.asyncio async def test_startup_failure() -> None: event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() - lifespan = Lifespan(ASGIWrapper(lifespan_failure), Config(), event_loop, {}) + lifespan = Lifespan(ASGIWrapper(_lifespan_failure), Config(), event_loop, {}) lifespan_task = event_loop.create_task(lifespan.handle_lifespan()) await lifespan.wait_for_startup() assert lifespan_task.done() exception = lifespan_task.exception() - assert isinstance(exception, LifespanFailureError) - assert str(exception) == "Lifespan failure in startup. 'Failure'" + assert exception.subgroup(LifespanFailureError) is not None # type: ignore async def return_app(scope: Scope, receive: Callable, send: Callable) -> None: diff --git a/tests/helpers.py b/tests/helpers.py index b72b1794..0e2d4d8d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -71,16 +71,6 @@ async def echo_framework( await send({"type": "websocket.send", "text": event["text"], "bytes": event["bytes"]}) -async def lifespan_failure( - scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable -) -> None: - while True: - message = await receive() - if message["type"] == "lifespan.startup": - await send({"type": "lifespan.startup.failed", "message": "Failure"}) - break - - async def sanity_framework( scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable ) -> None: diff --git a/tests/trio/test_lifespan.py b/tests/trio/test_lifespan.py index bdccf45a..1dbc0086 100644 --- a/tests/trio/test_lifespan.py +++ b/tests/trio/test_lifespan.py @@ -11,8 +11,9 @@ from hypercorn.app_wrappers import ASGIWrapper from hypercorn.config import Config from hypercorn.trio.lifespan import Lifespan +from hypercorn.typing import ASGIReceiveCallable, ASGISendCallable, Scope from hypercorn.utils import LifespanFailureError, LifespanTimeoutError -from ..helpers import lifespan_failure, SlowLifespanFramework +from ..helpers import SlowLifespanFramework @pytest.mark.trio @@ -26,19 +27,23 @@ async def test_startup_timeout_error(nursery: trio._core._run.Nursery) -> None: assert str(exc_info.value).startswith("Timeout whilst awaiting startup") +async def _lifespan_failure( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable +) -> None: + async with trio.open_nursery(): + while True: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.failed", "message": "Failure"}) + break + + @pytest.mark.trio async def test_startup_failure() -> None: - lifespan = Lifespan(ASGIWrapper(lifespan_failure), Config(), {}) - with pytest.raises(LifespanFailureError) as exc_info: - try: - async with trio.open_nursery() as lifespan_nursery: - await lifespan_nursery.start(lifespan.handle_lifespan) - await lifespan.wait_for_startup() - except ExceptionGroup as exception: - target_exception = exception - if len(exception.exceptions) == 1: - target_exception = exception.exceptions[0] - - raise target_exception.with_traceback(target_exception.__traceback__) - - assert str(exc_info.value) == "Lifespan failure in startup. 'Failure'" + lifespan = Lifespan(ASGIWrapper(_lifespan_failure), Config(), {}) + try: + async with trio.open_nursery() as lifespan_nursery: + await lifespan_nursery.start(lifespan.handle_lifespan) + await lifespan.wait_for_startup() + except ExceptionGroup as error: + assert error.subgroup(LifespanFailureError) is not None