From 70dbc1c141f8e1452b465936a3770cca2f78dfe8 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Thu, 23 Jan 2025 17:34:23 -0800 Subject: [PATCH 1/2] Suppress some client warnings for async generators Python 3.13 adds a new warning when the `aclose` method of an async generator is not called explicitly. Async generators should ideally be explicitly shut down with `aclose` when they are no longer in use. Fix typing in some cases where async generators were returned but typed as async iterators (which don't require `aclose`), and close more generators explicitly. Include some async iterators from third party libraries that are actually generators, which requires working around the typing and some future-proofing for API changes. Ignore the remaining warning that's internal to httpx-sse and can't be fixed at the client level. --- client/pyproject.toml | 12 +-- .../src/rubin/nublado/client/nubladoclient.py | 101 +++++++++++------- client/tests/client/client_test.py | 14 +-- client/tests/mock/mock_test.py | 40 +++---- docs/client/index.rst | 14 +-- 5 files changed, 108 insertions(+), 73 deletions(-) diff --git a/client/pyproject.toml b/client/pyproject.toml index 83d535204..8687f5197 100644 --- a/client/pyproject.toml +++ b/client/pyproject.toml @@ -52,10 +52,7 @@ Source = "https://github.com/lsst-sqre/nublado" "Issue tracker" = "https://github.com/lsst-sqre/nublado/issues" [build-system] -requires = [ - "setuptools>=61", - "setuptools_scm[toml]>=6.2" -] +requires = ["setuptools>=61", "setuptools_scm[toml]>=6.2"] build-backend = "setuptools.build_meta" [tool.coverage.run] @@ -77,7 +74,7 @@ exclude_lines = [ "raise NotImplementedError", "if 0:", "if __name__ == .__main__.:", - "if TYPE_CHECKING:" + "if TYPE_CHECKING:", ] [tool.pytest.ini_options] @@ -85,7 +82,10 @@ asyncio_mode = "strict" asyncio_default_fixture_loop_scope = "function" filterwarnings = [ # Bug in aiojobs - "ignore:with timeout\\(\\) is deprecated:DeprecationWarning" + "ignore:with timeout\\(\\) is deprecated:DeprecationWarning", + # Arguably a bug in Python 3.13 with async iterators implemented using + # generators + "ignore:.*method 'aclose' of 'Response.aiter_lines':RuntimeWarning", ] # The python_files setting is not for test detection (pytest will pick up any # test files named *_test.py without this setting) but to enable special diff --git a/client/src/rubin/nublado/client/nubladoclient.py b/client/src/rubin/nublado/client/nubladoclient.py index 159e8392b..4508e3c2f 100644 --- a/client/src/rubin/nublado/client/nubladoclient.py +++ b/client/src/rubin/nublado/client/nubladoclient.py @@ -8,7 +8,8 @@ import asyncio import json -from collections.abc import AsyncIterator, Callable, Coroutine +from collections.abc import AsyncGenerator, Callable, Coroutine +from contextlib import AbstractAsyncContextManager, aclosing, suppress from datetime import UTC, datetime, timedelta from functools import wraps from pathlib import Path @@ -68,7 +69,7 @@ def __init__(self, event_source: EventSource, logger: BoundLogger) -> None: self._logger = logger self._start = datetime.now(tz=UTC) - async def __aiter__(self) -> AsyncIterator[SpawnProgressMessage]: + async def __aiter__(self) -> AsyncGenerator[SpawnProgressMessage, None]: """Iterate over spawn progress events. Yields @@ -82,27 +83,36 @@ async def __aiter__(self) -> AsyncIterator[SpawnProgressMessage]: Raised if a protocol error occurred while connecting to the EventStream API or reading or parsing a message from it. 
""" - async for sse in self._source.aiter_sse(): - try: - event_dict = sse.json() - event = SpawnProgressMessage( - progress=event_dict["progress"], - message=event_dict["message"], - ready=event_dict.get("ready", False), - ) - except Exception as e: - err = f"{type(e).__name__}: {e!s}" - msg = f"Error parsing progress event, ignoring: {err}" - self._logger.warning(msg, type=sse.event, data=sse.data) - continue + sse_events = self._source.aiter_sse() + try: + async for sse in sse_events: + try: + event_dict = sse.json() + event = SpawnProgressMessage( + progress=event_dict["progress"], + message=event_dict["message"], + ready=event_dict.get("ready", False), + ) + except Exception as e: + err = f"{type(e).__name__}: {e!s}" + msg = f"Error parsing progress event, ignoring: {err}" + self._logger.warning(msg, type=sse.event, data=sse.data) + continue - # Log the event and yield it. - now = datetime.now(tz=UTC) - elapsed = int((now - self._start).total_seconds()) - status = "complete" if event.ready else "in progress" - msg = f"Spawn {status} ({elapsed}s elapsed): {event.message}" - self._logger.info(msg, elapsed=elapsed, status=status) - yield event + # Log the event and yield it. + now = datetime.now(tz=UTC) + elapsed = int((now - self._start).total_seconds()) + status = "complete" if event.ready else "in progress" + msg = f"Spawn {status} ({elapsed}s elapsed): {event.message}" + self._logger.info(msg, elapsed=elapsed, status=status) + yield event + finally: + # aiter_sse actually returns an asynchronous generator, which + # therefore is supposed to be explicitly closed with alcose() + # after break, unlike a true iterator. Handle this case to suppress + # warnings in Python 3.13. + with suppress(AttributeError): + await sse_events.aclose() # type: ignore[attr-defined] class JupyterLabSession: @@ -173,6 +183,9 @@ def __init__( self._logger = logger self._session_id: str | None = None + self._socket_manager: ( + AbstractAsyncContextManager[ClientConnection] | None + ) = None self._socket: ClientConnection | None = None async def __aenter__(self) -> Self: @@ -238,12 +251,13 @@ async def __aenter__(self) -> Self: self._logger.debug("Opening WebSocket connection") start = datetime.now(tz=UTC) try: - self._socket = await websockets.connect( + self._socket_manager = websockets.connect( self._url_for_websocket(url), extra_headers=headers, open_timeout=WEBSOCKET_OPEN_TIMEOUT, max_size=self._max_websocket_size, - ).__aenter__() + ) + self._socket = await self._socket_manager.__aenter__() except WebSocketException as e: user = self._username raise JupyterWebSocketError.from_exception( @@ -266,14 +280,15 @@ async def __aexit__( session_id = self._session_id # Close the WebSocket. - if self._socket: + if self._socket_manager: start = datetime.now(tz=UTC) try: - await self._socket.close() + await self._socket_manager.__aexit__(exc_type, exc_val, exc_tb) except WebSocketException as e: raise JupyterWebSocketError.from_exception( e, username, started_at=start ) from e + self._socket_manager = None self._socket = None # Delete the lab session. 
@@ -357,7 +372,8 @@ async def run_python(
         result = ""
         try:
             await self._socket.send(json.dumps(request))
-            async for message in self._socket:
+            messages = aiter(self._socket)
+            async for message in messages:
                 try:
                     output = self._parse_message(message, message_id)
                 except CodeExecutionError as e:
@@ -383,6 +399,13 @@
             new_exc = JupyterWebSocketError.from_exception(e, user)
             new_exc.started_at = start
             raise new_exc from e
+        finally:
+            # websocket.__aiter__ actually returns an asynchronous generator,
+            # which therefore is supposed to be explicitly closed with aclose()
+            # after break, unlike a true iterator. Handle this case to suppress
+            # warnings in Python 3.13.
+            with suppress(AttributeError):
+                await messages.aclose()  # type: ignore[attr-defined]
 
         # Return the accumulated output.
         return result
@@ -734,9 +757,9 @@ async def wrapper(
     return wrapper
 
 
-def _convert_iterator_exception[**P, T](
-    f: Callable[Concatenate[NubladoClient, P], AsyncIterator[T]],
-) -> Callable[Concatenate[NubladoClient, P], AsyncIterator[T]]:
+def _convert_generator_exception[**P, T](
+    f: Callable[Concatenate[NubladoClient, P], AsyncGenerator[T, None]],
+) -> Callable[Concatenate[NubladoClient, P], AsyncGenerator[T, None]]:
     """Convert web errors to a `~rubin.nublado.client.JupyterWebError`.
 
     This can only be used as a decorator on `JupyterClientSession` or another
@@ -747,11 +770,13 @@ def _convert_generator_exception[**P, T](
     @wraps(f)
     async def wrapper(
         client: NubladoClient, *args: P.args, **kwargs: P.kwargs
-    ) -> AsyncIterator[T]:
+    ) -> AsyncGenerator[T, None]:
         start = datetime.now(tz=UTC)
+        generator = f(client, *args, **kwargs)
         try:
-            async for result in f(client, *args, **kwargs):
-                yield result
+            async with aclosing(generator):
+                async for result in generator:
+                    yield result
         except HTTPError as e:
             username = client.user.username
             raise JupyterWebError.raise_from_exception_with_timestamps(
@@ -1086,10 +1111,10 @@ async def stop_lab(self) -> None:
         r = await self._client.delete(url, headers=headers)
         r.raise_for_status()
 
-    @_convert_iterator_exception
+    @_convert_generator_exception
     async def watch_spawn_progress(
         self,
-    ) -> AsyncIterator[SpawnProgressMessage]:
+    ) -> AsyncGenerator[SpawnProgressMessage, None]:
         """Monitor lab spawn progress.
 
         This is an EventStream API, which provides a stream of events until
@@ -1108,8 +1133,10 @@ async def watch_spawn_progress(
             headers["X-XSRFToken"] = self._hub_xsrf
         while True:
            async with aconnect_sse(client, "GET", url, headers=headers) as s:
-                async for message in JupyterSpawnProgress(s, self._logger):
-                    yield message
+                progress = aiter(JupyterSpawnProgress(s, self._logger))
+                async with aclosing(progress):
+                    async for message in progress:
+                        yield message
 
             # Sometimes we get only the initial request message and then the
             # progress API immediately closes the connection. 
If that happens, diff --git a/client/tests/client/client_test.py b/client/tests/client/client_test.py index 31eb26064..6e99b7ec3 100644 --- a/client/tests/client/client_test.py +++ b/client/tests/client/client_test.py @@ -1,6 +1,7 @@ """Tests for the NubladoClient object.""" import asyncio +from contextlib import aclosing from pathlib import Path import pytest @@ -37,12 +38,13 @@ async def test_hub_flow( # Watch the progress meter progress = configured_client.watch_spawn_progress() progress_pct = -1 - async with asyncio.timeout(30): - async for message in progress: - if message.ready: - break - assert message.progress > progress_pct - progress_pct = message.progress + async with aclosing(progress): + async with asyncio.timeout(30): + async for message in progress: + if message.ready: + break + assert message.progress > progress_pct + progress_pct = message.progress # Is the lab running? Should be. assert not (await configured_client.is_lab_stopped()) try: diff --git a/client/tests/mock/mock_test.py b/client/tests/mock/mock_test.py index 215e83505..c8fd8ef3d 100644 --- a/client/tests/mock/mock_test.py +++ b/client/tests/mock/mock_test.py @@ -2,6 +2,7 @@ import asyncio import json +from contextlib import aclosing from pathlib import Path import pytest @@ -42,12 +43,13 @@ async def test_register_python( # Watch the progress meter progress = configured_client.watch_spawn_progress() progress_pct = -1 - async with asyncio.timeout(30): - async for message in progress: - if message.ready: - break - assert message.progress > progress_pct - progress_pct = message.progress + async with aclosing(progress): + async with asyncio.timeout(30): + async for message in progress: + if message.ready: + break + assert message.progress > progress_pct + progress_pct = message.progress await configured_client.auth_to_lab() # Now test our mock @@ -93,12 +95,13 @@ async def test_register_python_with_notebook( # Watch the progress meter progress = configured_client.watch_spawn_progress() progress_pct = -1 - async with asyncio.timeout(30): - async for message in progress: - if message.ready: - break - assert message.progress > progress_pct - progress_pct = message.progress + async with aclosing(progress): + async with asyncio.timeout(30): + async for message in progress: + if message.ready: + break + assert message.progress > progress_pct + progress_pct = message.progress await configured_client.auth_to_lab() # Now test our mock @@ -139,12 +142,13 @@ async def test_register_extension( # Watch the progress meter progress = configured_client.watch_spawn_progress() progress_pct = -1 - async with asyncio.timeout(30): - async for message in progress: - if message.ready: - break - assert message.progress > progress_pct - progress_pct = message.progress + async with aclosing(progress): + async with asyncio.timeout(30): + async for message in progress: + if message.ready: + break + assert message.progress > progress_pct + progress_pct = message.progress await configured_client.auth_to_lab() # Now test our mock diff --git a/docs/client/index.rst b/docs/client/index.rst index fa5e12216..00a3ce487 100644 --- a/docs/client/index.rst +++ b/docs/client/index.rst @@ -66,6 +66,7 @@ the Lab: """Ensure there's a running lab for the user.""" import asyncio + from contextlib import aclosing import structlog @@ -95,10 +96,11 @@ the Lab: ) await client.spawn_lab(image) progress = client.watch_spawn_progress() - async with asyncio.timeout(LAB_SPAWN_TIMEOUT): - async for message in progress: - if message.ready: - break + async with 
aclosing(progress): + async with asyncio.timeout(LAB_SPAWN_TIMEOUT): + async for message in progress: + if message.ready: + break asyncio.run(ensure_lab()) @@ -267,7 +269,7 @@ It depends on two other fixtures: ``environment_url`` is a string, representing .. code-block:: python - from collections.abc import Iterator + from collections.abc import AsyncGenerator, Iterator from contextlib import asynccontextmanager from pathlib import Path from unittest.mock import patch @@ -318,7 +320,7 @@ It depends on two other fixtures: ``environment_url`` is a string, representing extra_headers: dict[str, str], max_size: int | None, open_timeout: int, - ) -> AsyncIterator[MockJupyterWebSocket]: + ) -> AsyncGenerator[MockJupyterWebSocket, None]: yield mock_jupyter_websocket(url, extra_headers, jupyter_mock) with patch(websockets, "connect") as mock: From 775a6e55cbf5677e07931f0ca8ccc44c4913d7df Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 24 Jan 2025 10:28:37 -0800 Subject: [PATCH 2/2] Refactor generator closing into a helper class Move the code in the Nublado client to close upstream library async iterators that are actually generators into an `_aclosing_iter` helper class with the same interface as `contextlib.aclosing`. Use that class documentation to more thoroughly document why we're taking this approach. --- .../src/rubin/nublado/client/nubladoclient.py | 107 +++++++++++------- 1 file changed, 67 insertions(+), 40 deletions(-) diff --git a/client/src/rubin/nublado/client/nubladoclient.py b/client/src/rubin/nublado/client/nubladoclient.py index 4508e3c2f..a17810666 100644 --- a/client/src/rubin/nublado/client/nubladoclient.py +++ b/client/src/rubin/nublado/client/nubladoclient.py @@ -8,8 +8,8 @@ import asyncio import json -from collections.abc import AsyncGenerator, Callable, Coroutine -from contextlib import AbstractAsyncContextManager, aclosing, suppress +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine +from contextlib import AbstractAsyncContextManager, aclosing from datetime import UTC, datetime, timedelta from functools import wraps from pathlib import Path @@ -50,6 +50,48 @@ __all__ = ["JupyterLabSession", "NubladoClient"] +class _aclosing_iter[T: AsyncIterator](AbstractAsyncContextManager): # noqa: N801 + """Automatically close async iterators that are generators. + + Python supports two ways of writing an async iterator: a true async + iterator, and an async generator. Generators support additional async + context, such as yielding from inside an async context manager, and + therefore require cleanup by calling their `aclose` method once the + generator is no longer needed. This step is done automatically by the + async loop implementation when the generator is garbage-collected, but + this may happen at an arbitrary point and produces pytest warnings + saying that the `aclose` method on the generator was never called. + + This class provides a variant of `contextlib.aclosing` that can be + used to close generators masquerading as iterators. Many Python libraries + implement `__aiter__` by returning a generator rather than an iterator, + which is equivalent except for this cleanup behavior. Async iterators do + not require this explicit cleanup step because they don't support async + context managers inside the iteration. Since the library is free to change + from a generator to an iterator at any time, and async iterators don't + require this cleanup and don't have `aclose` methods, the `aclose` method + should be called only if it exists. 
+ """ + + def __init__(self, thing: T) -> None: + self.thing = thing + + async def __aenter__(self) -> T: + return self.thing + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> Literal[False]: + # Only call aclose if the method is defined, which we take to mean that + # this iterator is actually a generator. + if getattr(self.thing, "aclose", None): + await self.thing.aclose() # type: ignore[attr-defined] + return False + + class JupyterSpawnProgress: """Async iterator returning spawn progress messages. @@ -83,8 +125,7 @@ async def __aiter__(self) -> AsyncGenerator[SpawnProgressMessage, None]: Raised if a protocol error occurred while connecting to the EventStream API or reading or parsing a message from it. """ - sse_events = self._source.aiter_sse() - try: + async with _aclosing_iter(self._source.aiter_sse()) as sse_events: async for sse in sse_events: try: event_dict = sse.json() @@ -106,13 +147,6 @@ async def __aiter__(self) -> AsyncGenerator[SpawnProgressMessage, None]: msg = f"Spawn {status} ({elapsed}s elapsed): {event.message}" self._logger.info(msg, elapsed=elapsed, status=status) yield event - finally: - # aiter_sse actually returns an asynchronous generator, which - # therefore is supposed to be explicitly closed with alcose() - # after break, unlike a true iterator. Handle this case to suppress - # warnings in Python 3.13. - with suppress(AttributeError): - await sse_events.aclose() # type: ignore[attr-defined] class JupyterLabSession: @@ -372,40 +406,33 @@ async def run_python( result = "" try: await self._socket.send(json.dumps(request)) - messages = aiter(self._socket) - async for message in messages: - try: - output = self._parse_message(message, message_id) - except CodeExecutionError as e: - e.code = code - e.started_at = start - _annotate_exception_from_context(e, context) - raise - except Exception as e: - error = f"{type(e).__name__}: {e!s}" - msg = "Ignoring unparsable web socket message" - self._logger.warning(msg, error=error, message=message) - - # Accumulate the results if they are of interest, and exit and - # return the results if this message indicated the end of - # execution. - if not output: - continue - result += output.content - if output.done: - break + async with _aclosing_iter(aiter(self._socket)) as messages: + async for message in messages: + try: + output = self._parse_message(message, message_id) + except CodeExecutionError as e: + e.code = code + e.started_at = start + _annotate_exception_from_context(e, context) + raise + except Exception as e: + error = f"{type(e).__name__}: {e!s}" + msg = "Ignoring unparsable web socket message" + self._logger.warning(msg, error=error, message=message) + + # Accumulate the results if they are of interest, and exit + # and return the results if this message indicated the end + # of execution. + if not output: + continue + result += output.content + if output.done: + break except WebSocketException as e: user = self._username new_exc = JupyterWebSocketError.from_exception(e, user) new_exc.started_at = start raise new_exc from e - finally: - # websocket.__aiter__ actually returns an asynchronous generator, - # which therefore is supposed to be explicitly closed with alcose() - # after break, unlike a true iterator. Handle this case to suppress - # warnings in Python 3.13. - with suppress(AttributeError): - await messages.aclose() # type: ignore[attr-defined] # Return the accumulated output. return result