Skip to content

Commit

Permalink
Bugfix ensure ExceptionGroup lifespan failures crash the server
Browse files Browse the repository at this point in the history
A lifespan.startup.failure should crash the server, however if these
became wrapped in an Exception group in the ASGI app the server
wouldn't crash, now fixed.
  • Loading branch information
pgjones committed May 28, 2024
1 parent edd0aac commit bfb0877
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 32 deletions.
12 changes: 11 additions & 1 deletion src/hypercorn/asyncio/lifespan.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

import asyncio
import sys
from functools import partial
from typing import Any, Callable

from ..config import Config
from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState
from ..utils import LifespanFailureError, LifespanTimeoutError

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup


class UnexpectedMessageError(Exception):
pass
Expand Down Expand Up @@ -58,7 +62,13 @@ def _call_soon(func: Callable, *args: Any) -> Any:
except LifespanFailureError:
# Lifespan failures should crash the server
raise
except Exception:
except (BaseExceptionGroup, Exception) as error:
if isinstance(error, BaseExceptionGroup):
failure_error = error.subgroup(LifespanFailureError)
if failure_error is not None:
# Lifespan failures should crash the server
raise failure_error

self.supported = False
if not self.startup.is_set():
await self.config.log.warning(
Expand Down
13 changes: 12 additions & 1 deletion src/hypercorn/trio/lifespan.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import annotations

import sys

import trio

from ..config import Config
from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState
from ..utils import LifespanFailureError, LifespanTimeoutError

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup


class UnexpectedMessageError(Exception):
pass
Expand Down Expand Up @@ -43,7 +48,13 @@ async def handle_lifespan(
except LifespanFailureError:
# Lifespan failures should crash the server
raise
except Exception:
except (BaseExceptionGroup, Exception) as error:
if isinstance(error, BaseExceptionGroup):
failure_error = error.subgroup(LifespanFailureError)
if failure_error is not None:
# Lifespan failures should crash the server
raise failure_error

self.supported = False
if not self.startup.is_set():
await self.config.log.warning(
Expand Down
25 changes: 20 additions & 5 deletions tests/asyncio/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@
from hypercorn.app_wrappers import ASGIWrapper
from hypercorn.asyncio.lifespan import Lifespan
from hypercorn.config import Config
from hypercorn.typing import Scope
from hypercorn.typing import ASGIReceiveCallable, ASGISendCallable, Scope
from hypercorn.utils import LifespanFailureError, LifespanTimeoutError
from ..helpers import lifespan_failure, SlowLifespanFramework
from ..helpers import SlowLifespanFramework

try:
from asyncio import TaskGroup
except ImportError:
from taskgroup import TaskGroup # type: ignore


async def no_lifespan_app(scope: Scope, receive: Callable, send: Callable) -> None:
Expand Down Expand Up @@ -47,17 +52,27 @@ async def test_startup_timeout_error() -> None:
await task


async def _lifespan_failure(
scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
async with TaskGroup():
while True:
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.failed", "message": "Failure"})
break


@pytest.mark.asyncio
async def test_startup_failure() -> None:
event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()

lifespan = Lifespan(ASGIWrapper(lifespan_failure), Config(), event_loop, {})
lifespan = Lifespan(ASGIWrapper(_lifespan_failure), Config(), event_loop, {})
lifespan_task = event_loop.create_task(lifespan.handle_lifespan())
await lifespan.wait_for_startup()
assert lifespan_task.done()
exception = lifespan_task.exception()
assert isinstance(exception, LifespanFailureError)
assert str(exception) == "Lifespan failure in startup. 'Failure'"
assert exception.subgroup(LifespanFailureError) is not None # type: ignore


async def return_app(scope: Scope, receive: Callable, send: Callable) -> None:
Expand Down
10 changes: 0 additions & 10 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,6 @@ async def echo_framework(
await send({"type": "websocket.send", "text": event["text"], "bytes": event["bytes"]})


async def lifespan_failure(
scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
while True:
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.failed", "message": "Failure"})
break


async def sanity_framework(
scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
Expand Down
35 changes: 20 additions & 15 deletions tests/trio/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from hypercorn.app_wrappers import ASGIWrapper
from hypercorn.config import Config
from hypercorn.trio.lifespan import Lifespan
from hypercorn.typing import ASGIReceiveCallable, ASGISendCallable, Scope
from hypercorn.utils import LifespanFailureError, LifespanTimeoutError
from ..helpers import lifespan_failure, SlowLifespanFramework
from ..helpers import SlowLifespanFramework


@pytest.mark.trio
Expand All @@ -26,19 +27,23 @@ async def test_startup_timeout_error(nursery: trio._core._run.Nursery) -> None:
assert str(exc_info.value).startswith("Timeout whilst awaiting startup")


async def _lifespan_failure(
scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
async with trio.open_nursery():
while True:
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.failed", "message": "Failure"})
break


@pytest.mark.trio
async def test_startup_failure() -> None:
lifespan = Lifespan(ASGIWrapper(lifespan_failure), Config(), {})
with pytest.raises(LifespanFailureError) as exc_info:
try:
async with trio.open_nursery() as lifespan_nursery:
await lifespan_nursery.start(lifespan.handle_lifespan)
await lifespan.wait_for_startup()
except ExceptionGroup as exception:
target_exception = exception
if len(exception.exceptions) == 1:
target_exception = exception.exceptions[0]

raise target_exception.with_traceback(target_exception.__traceback__)

assert str(exc_info.value) == "Lifespan failure in startup. 'Failure'"
lifespan = Lifespan(ASGIWrapper(_lifespan_failure), Config(), {})
try:
async with trio.open_nursery() as lifespan_nursery:
await lifespan_nursery.start(lifespan.handle_lifespan)
await lifespan.wait_for_startup()
except ExceptionGroup as error:
assert error.subgroup(LifespanFailureError) is not None

0 comments on commit bfb0877

Please sign in to comment.