diff --git a/litestar/testing/client/async_client.py b/litestar/testing/client/async_client.py index 4e28bef4ac..4bf9eec087 100644 --- a/litestar/testing/client/async_client.py +++ b/litestar/testing/client/async_client.py @@ -86,6 +86,7 @@ async def __aenter__(self) -> Self: async with AsyncExitStack() as stack: self.blocking_portal = portal = stack.enter_context(self.portal()) self.lifespan_handler = LifeSpanHandler(client=self) + stack.enter_context(self.lifespan_handler) @stack.callback def reset_portal() -> None: diff --git a/litestar/testing/client/sync_client.py b/litestar/testing/client/sync_client.py index 9cbfcb3d94..9c58d139d2 100644 --- a/litestar/testing/client/sync_client.py +++ b/litestar/testing/client/sync_client.py @@ -87,6 +87,7 @@ def __enter__(self) -> Self: with ExitStack() as stack: self.blocking_portal = portal = stack.enter_context(self.portal()) self.lifespan_handler = LifeSpanHandler(client=self) + stack.enter_context(self.lifespan_handler) @stack.callback def reset_portal() -> None: diff --git a/litestar/testing/life_span_handler.py b/litestar/testing/life_span_handler.py index 8ee7d22c3c..8c2ee5f2dd 100644 --- a/litestar/testing/life_span_handler.py +++ b/litestar/testing/life_span_handler.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from math import inf from typing import TYPE_CHECKING, Generic, Optional, TypeVar, cast @@ -9,6 +10,8 @@ from litestar.testing.client.base import BaseTestClient if TYPE_CHECKING: + from types import TracebackType + from litestar.types import ( LifeSpanReceiveMessage, # noqa: F401 LifeSpanSendMessage, @@ -20,24 +23,69 @@ class LifeSpanHandler(Generic[T]): - __slots__ = "stream_send", "stream_receive", "client", "task" + __slots__ = ( + "stream_send", + "stream_receive", + "client", + "task", + "_startup_done", + ) def __init__(self, client: T) -> None: self.client = client self.stream_send = StapledObjectStream[Optional["LifeSpanSendMessage"]](*create_memory_object_stream(inf)) # type: ignore[arg-type] self.stream_receive = StapledObjectStream["LifeSpanReceiveMessage"](*create_memory_object_stream(inf)) # type: ignore[arg-type] + self._startup_done = False + def _ensure_setup(self, is_safe: bool = False) -> None: + if self._startup_done: + return + + if not is_safe: + warnings.warn( + "LifeSpanHandler used with implicit startup; Use LifeSpanHandler as a context manager instead. " + "Implicit startup will be deprecated in version 3.0.", + category=DeprecationWarning, + stacklevel=2, + ) + + self._startup_done = True with self.client.portal() as portal: self.task = portal.start_task_soon(self.lifespan) portal.call(self.wait_startup) + def close(self) -> None: + with self.client.portal() as portal: + portal.call(self.stream_send.aclose) + portal.call(self.stream_receive.aclose) + + def __enter__(self) -> LifeSpanHandler: + try: + self._ensure_setup(is_safe=True) + except Exception as exc: + self.close() + raise exc + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + async def receive(self) -> LifeSpanSendMessage: + self._ensure_setup() + message = await self.stream_send.receive() if message is None: self.task.result() return cast("LifeSpanSendMessage", message) async def wait_startup(self) -> None: + self._ensure_setup() + event: LifeSpanStartupEvent = {"type": "lifespan.startup"} await self.stream_receive.send(event) @@ -54,6 +102,8 @@ async def wait_startup(self) -> None: await self.receive() async def wait_shutdown(self) -> None: + self._ensure_setup() + async with self.stream_send: lifespan_shutdown_event: LifeSpanShutdownEvent = {"type": "lifespan.shutdown"} await self.stream_receive.send(lifespan_shutdown_event) @@ -71,6 +121,8 @@ async def wait_shutdown(self) -> None: await self.receive() async def lifespan(self) -> None: + self._ensure_setup() + scope = {"type": "lifespan"} try: await self.client.app(scope, self.stream_receive.receive, self.stream_send.send) diff --git a/tests/unit/test_testing/test_lifespan_handler.py b/tests/unit/test_testing/test_lifespan_handler.py index 132b642710..f04f91a959 100644 --- a/tests/unit/test_testing/test_lifespan_handler.py +++ b/tests/unit/test_testing/test_lifespan_handler.py @@ -4,13 +4,16 @@ from litestar.testing.life_span_handler import LifeSpanHandler from litestar.types import Receive, Scope, Send +pytestmark = pytest.mark.filterwarnings("default") + async def test_wait_startup_invalid_event() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "lifespan.startup.something_unexpected"}) # type: ignore[typeddict-item] with pytest.raises(RuntimeError, match="Received unexpected ASGI message type"): - LifeSpanHandler(TestClient(app)) + with LifeSpanHandler(TestClient(app)): + pass async def test_wait_shutdown_invalid_event() -> None: @@ -18,7 +21,17 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "lifespan.startup.complete"}) # type: ignore[typeddict-item] await send({"type": "lifespan.shutdown.something_unexpected"}) # type: ignore[typeddict-item] - handler = LifeSpanHandler(TestClient(app)) + with LifeSpanHandler(TestClient(app)) as handler: + with pytest.raises(RuntimeError, match="Received unexpected ASGI message type"): + await handler.wait_shutdown() - with pytest.raises(RuntimeError, match="Received unexpected ASGI message type"): + +async def test_implicit_startup() -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "lifespan.startup.complete"}) # type: ignore[typeddict-item] + await send({"type": "lifespan.shutdown.complete"}) # type: ignore[typeddict-item] + + with pytest.warns(DeprecationWarning): + handler = LifeSpanHandler(TestClient(app)) await handler.wait_shutdown() + handler.close()