Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-48576: Suppress some client warnings for async generators #460

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ Source = "https://github.com/lsst-sqre/nublado"
"Issue tracker" = "https://github.com/lsst-sqre/nublado/issues"

[build-system]
requires = [
"setuptools>=61",
"setuptools_scm[toml]>=6.2"
]
requires = ["setuptools>=61", "setuptools_scm[toml]>=6.2"]
build-backend = "setuptools.build_meta"

[tool.coverage.run]
Expand All @@ -77,15 +74,18 @@ exclude_lines = [
"raise NotImplementedError",
"if 0:",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:"
"if TYPE_CHECKING:",
]

[tool.pytest.ini_options]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = [
# Bug in aiojobs
"ignore:with timeout\\(\\) is deprecated:DeprecationWarning"
"ignore:with timeout\\(\\) is deprecated:DeprecationWarning",
# Arguably a bug in Python 3.13 with async iterators implemented using
# generators
"ignore:.*method 'aclose' of 'Response.aiter_lines':RuntimeWarning",
]
# The python_files setting is not for test detection (pytest will pick up any
# test files named *_test.py without this setting) but to enable special
Expand Down
168 changes: 111 additions & 57 deletions client/src/rubin/nublado/client/nubladoclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import asyncio
import json
from collections.abc import AsyncIterator, Callable, Coroutine
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine
from contextlib import AbstractAsyncContextManager, aclosing
from datetime import UTC, datetime, timedelta
from functools import wraps
from pathlib import Path
Expand Down Expand Up @@ -49,6 +50,48 @@
__all__ = ["JupyterLabSession", "NubladoClient"]


class _aclosing_iter[T: AsyncIterator](AbstractAsyncContextManager): # noqa: N801
"""Automatically close async iterators that are generators.

Python supports two ways of writing an async iterator: a true async
iterator, and an async generator. Generators support additional async
context, such as yielding from inside an async context manager, and
therefore require cleanup by calling their `aclose` method once the
generator is no longer needed. This step is done automatically by the
async loop implementation when the generator is garbage-collected, but
this may happen at an arbitrary point and produces pytest warnings
saying that the `aclose` method on the generator was never called.

This class provides a variant of `contextlib.aclosing` that can be
used to close generators masquerading as iterators. Many Python libraries
implement `__aiter__` by returning a generator rather than an iterator,
which is equivalent except for this cleanup behavior. Async iterators do
not require this explicit cleanup step because they don't support async
context managers inside the iteration. Since the library is free to change
from a generator to an iterator at any time, and async iterators don't
require this cleanup and don't have `aclose` methods, the `aclose` method
should be called only if it exists.
"""

def __init__(self, thing: T) -> None:
self.thing = thing

async def __aenter__(self) -> T:
return self.thing

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> Literal[False]:
# Only call aclose if the method is defined, which we take to mean that
# this iterator is actually a generator.
if getattr(self.thing, "aclose", None):
await self.thing.aclose() # type: ignore[attr-defined]
return False


class JupyterSpawnProgress:
"""Async iterator returning spawn progress messages.

Expand All @@ -68,7 +111,7 @@ def __init__(self, event_source: EventSource, logger: BoundLogger) -> None:
self._logger = logger
self._start = datetime.now(tz=UTC)

async def __aiter__(self) -> AsyncIterator[SpawnProgressMessage]:
async def __aiter__(self) -> AsyncGenerator[SpawnProgressMessage, None]:
"""Iterate over spawn progress events.

Yields
Expand All @@ -82,27 +125,28 @@ async def __aiter__(self) -> AsyncIterator[SpawnProgressMessage]:
Raised if a protocol error occurred while connecting to the
EventStream API or reading or parsing a message from it.
"""
async for sse in self._source.aiter_sse():
try:
event_dict = sse.json()
event = SpawnProgressMessage(
progress=event_dict["progress"],
message=event_dict["message"],
ready=event_dict.get("ready", False),
)
except Exception as e:
err = f"{type(e).__name__}: {e!s}"
msg = f"Error parsing progress event, ignoring: {err}"
self._logger.warning(msg, type=sse.event, data=sse.data)
continue
async with _aclosing_iter(self._source.aiter_sse()) as sse_events:
async for sse in sse_events:
try:
event_dict = sse.json()
event = SpawnProgressMessage(
progress=event_dict["progress"],
message=event_dict["message"],
ready=event_dict.get("ready", False),
)
except Exception as e:
err = f"{type(e).__name__}: {e!s}"
msg = f"Error parsing progress event, ignoring: {err}"
self._logger.warning(msg, type=sse.event, data=sse.data)
continue

# Log the event and yield it.
now = datetime.now(tz=UTC)
elapsed = int((now - self._start).total_seconds())
status = "complete" if event.ready else "in progress"
msg = f"Spawn {status} ({elapsed}s elapsed): {event.message}"
self._logger.info(msg, elapsed=elapsed, status=status)
yield event
# Log the event and yield it.
now = datetime.now(tz=UTC)
elapsed = int((now - self._start).total_seconds())
status = "complete" if event.ready else "in progress"
msg = f"Spawn {status} ({elapsed}s elapsed): {event.message}"
self._logger.info(msg, elapsed=elapsed, status=status)
yield event


class JupyterLabSession:
Expand Down Expand Up @@ -173,6 +217,9 @@ def __init__(
self._logger = logger

self._session_id: str | None = None
self._socket_manager: (
AbstractAsyncContextManager[ClientConnection] | None
) = None
self._socket: ClientConnection | None = None

async def __aenter__(self) -> Self:
Expand Down Expand Up @@ -238,12 +285,13 @@ async def __aenter__(self) -> Self:
self._logger.debug("Opening WebSocket connection")
start = datetime.now(tz=UTC)
try:
self._socket = await websockets.connect(
self._socket_manager = websockets.connect(
self._url_for_websocket(url),
extra_headers=headers,
open_timeout=WEBSOCKET_OPEN_TIMEOUT,
max_size=self._max_websocket_size,
).__aenter__()
)
self._socket = await self._socket_manager.__aenter__()
except WebSocketException as e:
user = self._username
raise JupyterWebSocketError.from_exception(
Expand All @@ -266,14 +314,15 @@ async def __aexit__(
session_id = self._session_id

# Close the WebSocket.
if self._socket:
if self._socket_manager:
start = datetime.now(tz=UTC)
try:
await self._socket.close()
await self._socket_manager.__aexit__(exc_type, exc_val, exc_tb)
except WebSocketException as e:
raise JupyterWebSocketError.from_exception(
e, username, started_at=start
) from e
self._socket_manager = None
self._socket = None

# Delete the lab session.
Expand Down Expand Up @@ -357,27 +406,28 @@ async def run_python(
result = ""
try:
await self._socket.send(json.dumps(request))
async for message in self._socket:
try:
output = self._parse_message(message, message_id)
except CodeExecutionError as e:
e.code = code
e.started_at = start
_annotate_exception_from_context(e, context)
raise
except Exception as e:
error = f"{type(e).__name__}: {e!s}"
msg = "Ignoring unparsable web socket message"
self._logger.warning(msg, error=error, message=message)

# Accumulate the results if they are of interest, and exit and
# return the results if this message indicated the end of
# execution.
if not output:
continue
result += output.content
if output.done:
break
async with _aclosing_iter(aiter(self._socket)) as messages:
async for message in messages:
try:
output = self._parse_message(message, message_id)
except CodeExecutionError as e:
e.code = code
e.started_at = start
_annotate_exception_from_context(e, context)
raise
except Exception as e:
error = f"{type(e).__name__}: {e!s}"
msg = "Ignoring unparsable web socket message"
self._logger.warning(msg, error=error, message=message)

# Accumulate the results if they are of interest, and exit
# and return the results if this message indicated the end
# of execution.
if not output:
continue
result += output.content
if output.done:
break
except WebSocketException as e:
user = self._username
new_exc = JupyterWebSocketError.from_exception(e, user)
Expand Down Expand Up @@ -734,9 +784,9 @@ async def wrapper(
return wrapper


def _convert_iterator_exception[**P, T](
f: Callable[Concatenate[NubladoClient, P], AsyncIterator[T]],
) -> Callable[Concatenate[NubladoClient, P], AsyncIterator[T]]:
def _convert_generator_exception[**P, T](
f: Callable[Concatenate[NubladoClient, P], AsyncGenerator[T, None]],
) -> Callable[Concatenate[NubladoClient, P], AsyncGenerator[T, None]]:
"""Convert web errors to a `~rubin.nublado.client.JupyterWebError`.

This can only be used as a decorator on `JupyterClientSession` or another
Expand All @@ -747,11 +797,13 @@ def _convert_iterator_exception[**P, T](
@wraps(f)
async def wrapper(
client: NubladoClient, *args: P.args, **kwargs: P.kwargs
) -> AsyncIterator[T]:
) -> AsyncGenerator[T, None]:
start = datetime.now(tz=UTC)
generator = f(client, *args, **kwargs)
try:
async for result in f(client, *args, **kwargs):
yield result
async with aclosing(generator):
async for result in generator:
yield result
except HTTPError as e:
username = client.user.username
raise JupyterWebError.raise_from_exception_with_timestamps(
Expand Down Expand Up @@ -1086,10 +1138,10 @@ async def stop_lab(self) -> None:
r = await self._client.delete(url, headers=headers)
r.raise_for_status()

@_convert_iterator_exception
@_convert_generator_exception
async def watch_spawn_progress(
self,
) -> AsyncIterator[SpawnProgressMessage]:
) -> AsyncGenerator[SpawnProgressMessage, None]:
"""Monitor lab spawn progress.

This is an EventStream API, which provides a stream of events until
Expand All @@ -1108,8 +1160,10 @@ async def watch_spawn_progress(
headers["X-XSRFToken"] = self._hub_xsrf
while True:
async with aconnect_sse(client, "GET", url, headers=headers) as s:
async for message in JupyterSpawnProgress(s, self._logger):
yield message
progress = aiter(JupyterSpawnProgress(s, self._logger))
async with aclosing(progress):
async for message in progress:
yield message

# Sometimes we get only the initial request message and then the
# progress API immediately closes the connection. If that happens,
Expand Down
14 changes: 8 additions & 6 deletions client/tests/client/client_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the NubladoClient object."""

import asyncio
from contextlib import aclosing
from pathlib import Path

import pytest
Expand Down Expand Up @@ -37,12 +38,13 @@ async def test_hub_flow(
# Watch the progress meter
progress = configured_client.watch_spawn_progress()
progress_pct = -1
async with asyncio.timeout(30):
async for message in progress:
if message.ready:
break
assert message.progress > progress_pct
progress_pct = message.progress
async with aclosing(progress):
async with asyncio.timeout(30):
Comment on lines +41 to +42
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious if the aclosing has to be outside the timeout, or if it doesn't matter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't matter. It will close the generator in the finally block when exited by the exception raised by asyncio.timeout either way.

async for message in progress:
if message.ready:
break
assert message.progress > progress_pct
progress_pct = message.progress
# Is the lab running? Should be.
assert not (await configured_client.is_lab_stopped())
try:
Expand Down
40 changes: 22 additions & 18 deletions client/tests/mock/mock_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import json
from contextlib import aclosing
from pathlib import Path

import pytest
Expand Down Expand Up @@ -42,12 +43,13 @@ async def test_register_python(
# Watch the progress meter
progress = configured_client.watch_spawn_progress()
progress_pct = -1
async with asyncio.timeout(30):
async for message in progress:
if message.ready:
break
assert message.progress > progress_pct
progress_pct = message.progress
async with aclosing(progress):
async with asyncio.timeout(30):
async for message in progress:
if message.ready:
break
assert message.progress > progress_pct
progress_pct = message.progress
await configured_client.auth_to_lab()

# Now test our mock
Expand Down Expand Up @@ -93,12 +95,13 @@ async def test_register_python_with_notebook(
# Watch the progress meter
progress = configured_client.watch_spawn_progress()
progress_pct = -1
async with asyncio.timeout(30):
async for message in progress:
if message.ready:
break
assert message.progress > progress_pct
progress_pct = message.progress
async with aclosing(progress):
async with asyncio.timeout(30):
async for message in progress:
if message.ready:
break
assert message.progress > progress_pct
progress_pct = message.progress
await configured_client.auth_to_lab()

# Now test our mock
Expand Down Expand Up @@ -139,12 +142,13 @@ async def test_register_extension(
# Watch the progress meter
progress = configured_client.watch_spawn_progress()
progress_pct = -1
async with asyncio.timeout(30):
async for message in progress:
if message.ready:
break
assert message.progress > progress_pct
progress_pct = message.progress
async with aclosing(progress):
async with asyncio.timeout(30):
async for message in progress:
if message.ready:
break
assert message.progress > progress_pct
progress_pct = message.progress
await configured_client.auth_to_lab()

# Now test our mock
Expand Down
Loading
Loading