diff --git a/src/easynetwork/lowlevel/_asyncgen.py b/src/easynetwork/lowlevel/_asyncgen.py index 89919921..029f0813 100644 --- a/src/easynetwork/lowlevel/_asyncgen.py +++ b/src/easynetwork/lowlevel/_asyncgen.py @@ -19,9 +19,12 @@ __all__ = [] # type: list[str] import dataclasses +import sys from abc import ABCMeta, abstractmethod from collections.abc import AsyncGenerator -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Protocol, TypeVar + +from . import _utils _T_Send = TypeVar("_T_Send") _T_Yield = TypeVar("_T_Yield") @@ -57,3 +60,36 @@ async def asend(self, generator: AsyncGenerator[_T_Yield, Any]) -> _T_Yield: return await generator.athrow(self.exception) finally: del generator, self # Needed to avoid circular reference with raised exception + + +class _GetAsyncGenHooks(Protocol): + @staticmethod + @abstractmethod + def __call__() -> sys._asyncgen_hooks: ... + + +class _SetAsyncGenHooks(Protocol): + @staticmethod + @abstractmethod + def __call__(firstiter: sys._AsyncgenHook = ..., finalizer: sys._AsyncgenHook = ...) -> None: ... + + +async def anext_without_asyncgen_hook( + agen: AsyncGenerator[_T_Yield, Any], + /, + *, + _get_asyncgen_hooks: _GetAsyncGenHooks = sys.get_asyncgen_hooks, + _set_asyncgen_hooks: _SetAsyncGenHooks = sys.set_asyncgen_hooks, +) -> _T_Yield: + previous_firstiter_hook = _get_asyncgen_hooks().firstiter + _set_asyncgen_hooks(firstiter=None) + try: + anext_coroutine = anext(agen) + finally: + _set_asyncgen_hooks(firstiter=previous_firstiter_hook) + previous_firstiter_hook = None + try: + return await anext_coroutine + except BaseException as exc: + _utils.remove_traceback_frames_in_place(exc, 1) + raise diff --git a/src/easynetwork/lowlevel/api_async/servers/datagram.py b/src/easynetwork/lowlevel/api_async/servers/datagram.py index 6c0e8c0b..3a07f130 100644 --- a/src/easynetwork/lowlevel/api_async/servers/datagram.py +++ b/src/easynetwork/lowlevel/api_async/servers/datagram.py @@ -32,7 +32,7 @@ from ....exceptions import DatagramProtocolParseError from ....protocol import DatagramProtocol from ... import _utils -from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction +from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction, anext_without_asyncgen_hook from ..backend.abc import AsyncBackend, ICondition, ILock, TaskGroup from ..transports import abc as _transports @@ -233,7 +233,7 @@ async def __client_coroutine_inner_loop( datagram: bytes = client_data.pop_datagram_no_wait() try: # Ignore sent timeout here, we already have the datagram. - await anext(request_handler_generator) + await anext_without_asyncgen_hook(request_handler_generator) except StopAsyncIteration: return else: diff --git a/src/easynetwork/lowlevel/api_async/servers/stream.py b/src/easynetwork/lowlevel/api_async/servers/stream.py index de50a937..342e3edc 100644 --- a/src/easynetwork/lowlevel/api_async/servers/stream.py +++ b/src/easynetwork/lowlevel/api_async/servers/stream.py @@ -27,7 +27,7 @@ from ...._typevars import _T_Request, _T_Response from ....protocol import AnyStreamProtocolType from ... import _stream, _utils -from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction +from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction, anext_without_asyncgen_hook from ..backend.abc import AsyncBackend, TaskGroup from ..transports import abc as _transports, utils as _transports_utils @@ -229,7 +229,7 @@ async def __client_coroutine( timeout: float | None try: - timeout = await anext(request_handler_generator) + timeout = await anext_without_asyncgen_hook(request_handler_generator) except StopAsyncIteration: return else: diff --git a/src/easynetwork/servers/misc.py b/src/easynetwork/servers/misc.py index 55d5e576..f2c024d1 100644 --- a/src/easynetwork/servers/misc.py +++ b/src/easynetwork/servers/misc.py @@ -63,7 +63,7 @@ def build_lowlevel_stream_server_handler( if logger is None: logger = logging.getLogger(__name__) - from ..lowlevel._asyncgen import SendAction, ThrowAction + from ..lowlevel._asyncgen import SendAction, ThrowAction, anext_without_asyncgen_hook async def handler( lowlevel_client: _lowlevel_stream_server.ConnectedStreamClient[_T_Response], / @@ -82,7 +82,7 @@ async def handler( _on_connection_hook = request_handler.on_connection(client) if isinstance(_on_connection_hook, AsyncGenerator): try: - timeout = await anext(_on_connection_hook) + timeout = await anext_without_asyncgen_hook(_on_connection_hook) except StopAsyncIteration: pass else: @@ -128,7 +128,7 @@ async def disconnect_client() -> None: while not client_is_closing(): request_handler_generator = new_request_handler(client) try: - timeout = await anext(request_handler_generator) + timeout = await anext_without_asyncgen_hook(request_handler_generator) except StopAsyncIteration: return else: @@ -181,7 +181,7 @@ def build_lowlevel_datagram_server_handler( an :term:`asynchronous generator` function. """ - from ..lowlevel._asyncgen import SendAction, ThrowAction + from ..lowlevel._asyncgen import SendAction, ThrowAction, anext_without_asyncgen_hook async def handler( lowlevel_client: _lowlevel_datagram_server.DatagramClientContext[_T_Response, _T_Address], / @@ -196,7 +196,7 @@ async def handler( request_handler_generator = request_handler.handle(client) timeout: float | None try: - timeout = await anext(request_handler_generator) + timeout = await anext_without_asyncgen_hook(request_handler_generator) except StopAsyncIteration: return else: diff --git a/tests/unit_test/test_tools/test_asyncgen.py b/tests/unit_test/test_tools/test_asyncgen.py index 66165050..52a85924 100644 --- a/tests/unit_test/test_tools/test_asyncgen.py +++ b/tests/unit_test/test_tools/test_asyncgen.py @@ -1,12 +1,14 @@ from __future__ import annotations +import contextlib import sys -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator from typing import TYPE_CHECKING, Any -from easynetwork.lowlevel._asyncgen import SendAction, ThrowAction +from easynetwork.lowlevel._asyncgen import SendAction, ThrowAction, anext_without_asyncgen_hook import pytest +import pytest_asyncio if TYPE_CHECKING: from unittest.mock import MagicMock @@ -89,3 +91,63 @@ async def test____ThrowAction____does_not_create_reference_cycles(self, mock_gen unwrap_frame = exc.__traceback__.tb_next.tb_frame assert unwrap_frame.f_code.co_name == "asend" assert unwrap_frame.f_locals == {} + + +@pytest_asyncio.fixture +async def current_asyncgen_hooks() -> AsyncIterator[sys._asyncgen_hooks]: + current_hooks = sys.get_asyncgen_hooks() + yield current_hooks + sys.set_asyncgen_hooks(*current_hooks) + + +@pytest.mark.asyncio +async def test____anext_without_asyncgen_hook____skips_firstiter_hook( + current_asyncgen_hooks: sys._asyncgen_hooks, + mocker: MockerFixture, +) -> None: + # Arrange + firstiter_stub = mocker.stub("firstiter_hook") + firstiter_stub.side_effect = current_asyncgen_hooks.firstiter + firstiter_stub.return_value = None + sys.set_asyncgen_hooks(firstiter=firstiter_stub) + + async def async_generator_function() -> AsyncGenerator[int, None]: + yield 42 + + async_generator = async_generator_function() + + # Act + async with contextlib.aclosing(async_generator): + value = await anext_without_asyncgen_hook(async_generator) + + # Assert + assert value == 42 + firstiter_stub.assert_not_called() + assert sys.get_asyncgen_hooks() == (firstiter_stub, current_asyncgen_hooks.finalizer) + + +@pytest.mark.asyncio +async def test____anext_without_asyncgen_hook____remove_frame_on_error() -> None: + # Arrange + exc = ValueError("abc") + + async def async_generator_function() -> AsyncGenerator[int, None]: + if False: + yield 42 # type: ignore[unreachable] + raise exc + + async_generator = async_generator_function() + + # Act + with pytest.raises(ValueError): + await anext_without_asyncgen_hook(async_generator) + + # Assert + # Top frame in the traceback is the current test function; we don't care + # about its references + assert exc.__traceback__ is not None + assert exc.__traceback__.tb_frame is sys._getframe() + # The next frame down is the 'async_generator_function' frame + assert exc.__traceback__.tb_next is not None + generator_frame = exc.__traceback__.tb_next.tb_frame + assert generator_frame.f_code.co_name == "async_generator_function"