Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Servers: Temporary disable async generator hooks for request handlers #332

Merged
merged 1 commit into from
Jul 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Servers: Temporary disable async generator hooks for request handlers
francis-clairicia committed Jul 27, 2024
commit 0df1dbfa13650f98c6f6f5034a710d9d0551479f
38 changes: 37 additions & 1 deletion src/easynetwork/lowlevel/_asyncgen.py
Original file line number Diff line number Diff line change
@@ -19,9 +19,12 @@
__all__ = [] # type: list[str]

import dataclasses
import sys
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Protocol, TypeVar

from . import _utils

_T_Send = TypeVar("_T_Send")
_T_Yield = TypeVar("_T_Yield")
@@ -57,3 +60,36 @@ async def asend(self, generator: AsyncGenerator[_T_Yield, Any]) -> _T_Yield:
return await generator.athrow(self.exception)
finally:
del generator, self # Needed to avoid circular reference with raised exception


class _GetAsyncGenHooks(Protocol):
@staticmethod
@abstractmethod
def __call__() -> sys._asyncgen_hooks: ...


class _SetAsyncGenHooks(Protocol):
@staticmethod
@abstractmethod
def __call__(firstiter: sys._AsyncgenHook = ..., finalizer: sys._AsyncgenHook = ...) -> None: ...


async def anext_without_asyncgen_hook(
agen: AsyncGenerator[_T_Yield, Any],
/,
*,
_get_asyncgen_hooks: _GetAsyncGenHooks = sys.get_asyncgen_hooks,
_set_asyncgen_hooks: _SetAsyncGenHooks = sys.set_asyncgen_hooks,
) -> _T_Yield:
previous_firstiter_hook = _get_asyncgen_hooks().firstiter
_set_asyncgen_hooks(firstiter=None)
try:
anext_coroutine = anext(agen)
finally:
_set_asyncgen_hooks(firstiter=previous_firstiter_hook)
previous_firstiter_hook = None
try:
return await anext_coroutine
except BaseException as exc:
_utils.remove_traceback_frames_in_place(exc, 1)
raise
4 changes: 2 additions & 2 deletions src/easynetwork/lowlevel/api_async/servers/datagram.py
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@
from ....exceptions import DatagramProtocolParseError
from ....protocol import DatagramProtocol
from ... import _utils
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction, anext_without_asyncgen_hook
from ..backend.abc import AsyncBackend, ICondition, ILock, TaskGroup
from ..transports import abc as _transports

@@ -233,7 +233,7 @@ async def __client_coroutine_inner_loop(
datagram: bytes = client_data.pop_datagram_no_wait()
try:
# Ignore sent timeout here, we already have the datagram.
await anext(request_handler_generator)
await anext_without_asyncgen_hook(request_handler_generator)
except StopAsyncIteration:
return
else:
4 changes: 2 additions & 2 deletions src/easynetwork/lowlevel/api_async/servers/stream.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@
from ...._typevars import _T_Request, _T_Response
from ....protocol import AnyStreamProtocolType
from ... import _stream, _utils
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction, anext_without_asyncgen_hook
from ..backend.abc import AsyncBackend, TaskGroup
from ..transports import abc as _transports, utils as _transports_utils

@@ -229,7 +229,7 @@ async def __client_coroutine(

timeout: float | None
try:
timeout = await anext(request_handler_generator)
timeout = await anext_without_asyncgen_hook(request_handler_generator)
except StopAsyncIteration:
return
else:
10 changes: 5 additions & 5 deletions src/easynetwork/servers/misc.py
Original file line number Diff line number Diff line change
@@ -63,7 +63,7 @@ def build_lowlevel_stream_server_handler(
if logger is None:
logger = logging.getLogger(__name__)

from ..lowlevel._asyncgen import SendAction, ThrowAction
from ..lowlevel._asyncgen import SendAction, ThrowAction, anext_without_asyncgen_hook

async def handler(
lowlevel_client: _lowlevel_stream_server.ConnectedStreamClient[_T_Response], /
@@ -82,7 +82,7 @@ async def handler(
_on_connection_hook = request_handler.on_connection(client)
if isinstance(_on_connection_hook, AsyncGenerator):
try:
timeout = await anext(_on_connection_hook)
timeout = await anext_without_asyncgen_hook(_on_connection_hook)
except StopAsyncIteration:
pass
else:
@@ -128,7 +128,7 @@ async def disconnect_client() -> None:
while not client_is_closing():
request_handler_generator = new_request_handler(client)
try:
timeout = await anext(request_handler_generator)
timeout = await anext_without_asyncgen_hook(request_handler_generator)
except StopAsyncIteration:
return
else:
@@ -181,7 +181,7 @@ def build_lowlevel_datagram_server_handler(
an :term:`asynchronous generator` function.
"""

from ..lowlevel._asyncgen import SendAction, ThrowAction
from ..lowlevel._asyncgen import SendAction, ThrowAction, anext_without_asyncgen_hook

async def handler(
lowlevel_client: _lowlevel_datagram_server.DatagramClientContext[_T_Response, _T_Address], /
@@ -196,7 +196,7 @@ async def handler(
request_handler_generator = request_handler.handle(client)
timeout: float | None
try:
timeout = await anext(request_handler_generator)
timeout = await anext_without_asyncgen_hook(request_handler_generator)
except StopAsyncIteration:
return
else:
66 changes: 64 additions & 2 deletions tests/unit_test/test_tools/test_asyncgen.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import contextlib
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any

from easynetwork.lowlevel._asyncgen import SendAction, ThrowAction
from easynetwork.lowlevel._asyncgen import SendAction, ThrowAction, anext_without_asyncgen_hook

import pytest
import pytest_asyncio

if TYPE_CHECKING:
from unittest.mock import MagicMock
@@ -89,3 +91,63 @@ async def test____ThrowAction____does_not_create_reference_cycles(self, mock_gen
unwrap_frame = exc.__traceback__.tb_next.tb_frame
assert unwrap_frame.f_code.co_name == "asend"
assert unwrap_frame.f_locals == {}


@pytest_asyncio.fixture
async def current_asyncgen_hooks() -> AsyncIterator[sys._asyncgen_hooks]:
current_hooks = sys.get_asyncgen_hooks()
yield current_hooks
sys.set_asyncgen_hooks(*current_hooks)


@pytest.mark.asyncio
async def test____anext_without_asyncgen_hook____skips_firstiter_hook(
current_asyncgen_hooks: sys._asyncgen_hooks,
mocker: MockerFixture,
) -> None:
# Arrange
firstiter_stub = mocker.stub("firstiter_hook")
firstiter_stub.side_effect = current_asyncgen_hooks.firstiter
firstiter_stub.return_value = None
sys.set_asyncgen_hooks(firstiter=firstiter_stub)

async def async_generator_function() -> AsyncGenerator[int, None]:
yield 42

async_generator = async_generator_function()

# Act
async with contextlib.aclosing(async_generator):
value = await anext_without_asyncgen_hook(async_generator)

# Assert
assert value == 42
firstiter_stub.assert_not_called()
assert sys.get_asyncgen_hooks() == (firstiter_stub, current_asyncgen_hooks.finalizer)


@pytest.mark.asyncio
async def test____anext_without_asyncgen_hook____remove_frame_on_error() -> None:
# Arrange
exc = ValueError("abc")

async def async_generator_function() -> AsyncGenerator[int, None]:
if False:
yield 42 # type: ignore[unreachable]
raise exc

async_generator = async_generator_function()

# Act
with pytest.raises(ValueError):
await anext_without_asyncgen_hook(async_generator)

# Assert
# Top frame in the traceback is the current test function; we don't care
# about its references
assert exc.__traceback__ is not None
assert exc.__traceback__.tb_frame is sys._getframe()
# The next frame down is the 'async_generator_function' frame
assert exc.__traceback__.tb_next is not None
generator_frame = exc.__traceback__.tb_next.tb_frame
assert generator_frame.f_code.co_name == "async_generator_function"