Skip to content

Commit

Permalink
fix resource handling
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Oct 20, 2024
1 parent b2adb0d commit fca8691
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 5 deletions.
1 change: 1 addition & 0 deletions litestar/testing/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ async def __aenter__(self) -> Self:
async with AsyncExitStack() as stack:
self.blocking_portal = portal = stack.enter_context(self.portal())
self.lifespan_handler = LifeSpanHandler(client=self)
stack.enter_context(self.lifespan_handler)

@stack.callback
def reset_portal() -> None:
Expand Down
1 change: 1 addition & 0 deletions litestar/testing/client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __enter__(self) -> Self:
with ExitStack() as stack:
self.blocking_portal = portal = stack.enter_context(self.portal())
self.lifespan_handler = LifeSpanHandler(client=self)
stack.enter_context(self.lifespan_handler)

@stack.callback
def reset_portal() -> None:
Expand Down
51 changes: 50 additions & 1 deletion litestar/testing/life_span_handler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

import contextlib
import threading
import warnings
from math import inf
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, cast

import anyio
from anyio import create_memory_object_stream
from anyio.streams.stapled import StapledObjectStream

from litestar.testing.client.base import BaseTestClient
from litestar.utils import warn_deprecation

if TYPE_CHECKING:
from litestar.types import (
Expand All @@ -20,24 +25,64 @@


class LifeSpanHandler(Generic[T]):
__slots__ = "stream_send", "stream_receive", "client", "task"
__slots__ = (
"stream_send",
"stream_receive",
"client",
"task",
"_startup_done",
)

def __init__(self, client: T) -> None:
self.client = client
self.stream_send = StapledObjectStream[Optional["LifeSpanSendMessage"]](*create_memory_object_stream(inf)) # type: ignore[arg-type]
self.stream_receive = StapledObjectStream["LifeSpanReceiveMessage"](*create_memory_object_stream(inf)) # type: ignore[arg-type]
self._startup_done = False

def _ensure_setup(self, is_safe: bool = False):
if self._startup_done:
return

if not is_safe:
warnings.warn(
"LifeSpanHandler used with implicit startup; Use LifeSpanHandler as a context manager instead. "
"Implicit startup will be deprecated in version 3.0.",
category=DeprecationWarning,
stacklevel=2,
)

self._startup_done = True
with self.client.portal() as portal:
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)

def _teardown(self):
with self.client.portal() as portal:
portal.call(self.stream_send.aclose)
portal.call(self.stream_receive.aclose)

def __enter__(self) -> LifeSpanHandler:
try:
self._ensure_setup()
except Exception as exc:
self._teardown()
raise exc
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self._teardown()

async def receive(self) -> LifeSpanSendMessage:
self._ensure_setup()

message = await self.stream_send.receive()
if message is None:
self.task.result()
return cast("LifeSpanSendMessage", message)

async def wait_startup(self) -> None:
self._ensure_setup()

event: LifeSpanStartupEvent = {"type": "lifespan.startup"}
await self.stream_receive.send(event)

Expand All @@ -54,6 +99,8 @@ async def wait_startup(self) -> None:
await self.receive()

async def wait_shutdown(self) -> None:
self._ensure_setup()

async with self.stream_send:
lifespan_shutdown_event: LifeSpanShutdownEvent = {"type": "lifespan.shutdown"}
await self.stream_receive.send(lifespan_shutdown_event)
Expand All @@ -71,6 +118,8 @@ async def wait_shutdown(self) -> None:
await self.receive()

async def lifespan(self) -> None:
self._ensure_setup()

scope = {"type": "lifespan"}
try:
await self.client.app(scope, self.stream_receive.receive, self.stream_send.send)
Expand Down
11 changes: 7 additions & 4 deletions tests/unit/test_testing/test_lifespan_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,24 @@
from litestar.testing.life_span_handler import LifeSpanHandler
from litestar.types import Receive, Scope, Send

pytestmark = pytest.mark.filterwarnings("default")


async def test_wait_startup_invalid_event() -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "lifespan.startup.something_unexpected"}) # type: ignore[typeddict-item]

with pytest.raises(RuntimeError, match="Received unexpected ASGI message type"):
LifeSpanHandler(TestClient(app))
with LifeSpanHandler(TestClient(app)):
pass


async def test_wait_shutdown_invalid_event() -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "lifespan.startup.complete"}) # type: ignore[typeddict-item]
await send({"type": "lifespan.shutdown.something_unexpected"}) # type: ignore[typeddict-item]

handler = LifeSpanHandler(TestClient(app))
with LifeSpanHandler(TestClient(app)) as handler:

with pytest.raises(RuntimeError, match="Received unexpected ASGI message type"):
await handler.wait_shutdown()
with pytest.raises(RuntimeError, match="Received unexpected ASGI message type"):
await handler.wait_shutdown()

0 comments on commit fca8691

Please sign in to comment.