Skip to content

Commit

Permalink
Servers: Temporary disable async generator hooks for request handlers (
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Jul 27, 2024
1 parent a15d0d8 commit f5667d5
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 12 deletions.
38 changes: 37 additions & 1 deletion src/easynetwork/lowlevel/_asyncgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
__all__ = [] # type: list[str]

import dataclasses
import sys
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Protocol, TypeVar

from . import _utils

_T_Send = TypeVar("_T_Send")
_T_Yield = TypeVar("_T_Yield")
Expand Down Expand Up @@ -57,3 +60,36 @@ async def asend(self, generator: AsyncGenerator[_T_Yield, Any]) -> _T_Yield:
return await generator.athrow(self.exception)
finally:
del generator, self # Needed to avoid circular reference with raised exception


class _GetAsyncGenHooks(Protocol):
@staticmethod
@abstractmethod
def __call__() -> sys._asyncgen_hooks: ...


class _SetAsyncGenHooks(Protocol):
@staticmethod
@abstractmethod
def __call__(firstiter: sys._AsyncgenHook = ..., finalizer: sys._AsyncgenHook = ...) -> None: ...


async def anext_without_asyncgen_hook(
agen: AsyncGenerator[_T_Yield, Any],
/,
*,
_get_asyncgen_hooks: _GetAsyncGenHooks = sys.get_asyncgen_hooks,
_set_asyncgen_hooks: _SetAsyncGenHooks = sys.set_asyncgen_hooks,
) -> _T_Yield:
previous_firstiter_hook = _get_asyncgen_hooks().firstiter
_set_asyncgen_hooks(firstiter=None)
try:
anext_coroutine = anext(agen)
finally:
_set_asyncgen_hooks(firstiter=previous_firstiter_hook)
previous_firstiter_hook = None
try:
return await anext_coroutine
except BaseException as exc:
_utils.remove_traceback_frames_in_place(exc, 1)
raise
4 changes: 2 additions & 2 deletions src/easynetwork/lowlevel/api_async/servers/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ....exceptions import DatagramProtocolParseError
from ....protocol import DatagramProtocol
from ... import _utils
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction, anext_without_asyncgen_hook
from ..backend.abc import AsyncBackend, ICondition, ILock, TaskGroup
from ..transports import abc as _transports

Expand Down Expand Up @@ -233,7 +233,7 @@ async def __client_coroutine_inner_loop(
datagram: bytes = client_data.pop_datagram_no_wait()
try:
# Ignore sent timeout here, we already have the datagram.
await anext(request_handler_generator)
await anext_without_asyncgen_hook(request_handler_generator)
except StopAsyncIteration:
return
else:
Expand Down
4 changes: 2 additions & 2 deletions src/easynetwork/lowlevel/api_async/servers/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ...._typevars import _T_Request, _T_Response
from ....protocol import AnyStreamProtocolType
from ... import _stream, _utils
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction, anext_without_asyncgen_hook
from ..backend.abc import AsyncBackend, TaskGroup
from ..transports import abc as _transports, utils as _transports_utils

Expand Down Expand Up @@ -229,7 +229,7 @@ async def __client_coroutine(

timeout: float | None
try:
timeout = await anext(request_handler_generator)
timeout = await anext_without_asyncgen_hook(request_handler_generator)
except StopAsyncIteration:
return
else:
Expand Down
10 changes: 5 additions & 5 deletions src/easynetwork/servers/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def build_lowlevel_stream_server_handler(
if logger is None:
logger = logging.getLogger(__name__)

from ..lowlevel._asyncgen import SendAction, ThrowAction
from ..lowlevel._asyncgen import SendAction, ThrowAction, anext_without_asyncgen_hook

async def handler(
lowlevel_client: _lowlevel_stream_server.ConnectedStreamClient[_T_Response], /
Expand All @@ -82,7 +82,7 @@ async def handler(
_on_connection_hook = request_handler.on_connection(client)
if isinstance(_on_connection_hook, AsyncGenerator):
try:
timeout = await anext(_on_connection_hook)
timeout = await anext_without_asyncgen_hook(_on_connection_hook)
except StopAsyncIteration:
pass
else:
Expand Down Expand Up @@ -128,7 +128,7 @@ async def disconnect_client() -> None:
while not client_is_closing():
request_handler_generator = new_request_handler(client)
try:
timeout = await anext(request_handler_generator)
timeout = await anext_without_asyncgen_hook(request_handler_generator)
except StopAsyncIteration:
return
else:
Expand Down Expand Up @@ -181,7 +181,7 @@ def build_lowlevel_datagram_server_handler(
an :term:`asynchronous generator` function.
"""

from ..lowlevel._asyncgen import SendAction, ThrowAction
from ..lowlevel._asyncgen import SendAction, ThrowAction, anext_without_asyncgen_hook

async def handler(
lowlevel_client: _lowlevel_datagram_server.DatagramClientContext[_T_Response, _T_Address], /
Expand All @@ -196,7 +196,7 @@ async def handler(
request_handler_generator = request_handler.handle(client)
timeout: float | None
try:
timeout = await anext(request_handler_generator)
timeout = await anext_without_asyncgen_hook(request_handler_generator)
except StopAsyncIteration:
return
else:
Expand Down
66 changes: 64 additions & 2 deletions tests/unit_test/test_tools/test_asyncgen.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import contextlib
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any

from easynetwork.lowlevel._asyncgen import SendAction, ThrowAction
from easynetwork.lowlevel._asyncgen import SendAction, ThrowAction, anext_without_asyncgen_hook

import pytest
import pytest_asyncio

if TYPE_CHECKING:
from unittest.mock import MagicMock
Expand Down Expand Up @@ -89,3 +91,63 @@ async def test____ThrowAction____does_not_create_reference_cycles(self, mock_gen
unwrap_frame = exc.__traceback__.tb_next.tb_frame
assert unwrap_frame.f_code.co_name == "asend"
assert unwrap_frame.f_locals == {}


@pytest_asyncio.fixture
async def current_asyncgen_hooks() -> AsyncIterator[sys._asyncgen_hooks]:
current_hooks = sys.get_asyncgen_hooks()
yield current_hooks
sys.set_asyncgen_hooks(*current_hooks)


@pytest.mark.asyncio
async def test____anext_without_asyncgen_hook____skips_firstiter_hook(
current_asyncgen_hooks: sys._asyncgen_hooks,
mocker: MockerFixture,
) -> None:
# Arrange
firstiter_stub = mocker.stub("firstiter_hook")
firstiter_stub.side_effect = current_asyncgen_hooks.firstiter
firstiter_stub.return_value = None
sys.set_asyncgen_hooks(firstiter=firstiter_stub)

async def async_generator_function() -> AsyncGenerator[int, None]:
yield 42

async_generator = async_generator_function()

# Act
async with contextlib.aclosing(async_generator):
value = await anext_without_asyncgen_hook(async_generator)

# Assert
assert value == 42
firstiter_stub.assert_not_called()
assert sys.get_asyncgen_hooks() == (firstiter_stub, current_asyncgen_hooks.finalizer)


@pytest.mark.asyncio
async def test____anext_without_asyncgen_hook____remove_frame_on_error() -> None:
# Arrange
exc = ValueError("abc")

async def async_generator_function() -> AsyncGenerator[int, None]:
if False:
yield 42 # type: ignore[unreachable]
raise exc

async_generator = async_generator_function()

# Act
with pytest.raises(ValueError):
await anext_without_asyncgen_hook(async_generator)

# Assert
# Top frame in the traceback is the current test function; we don't care
# about its references
assert exc.__traceback__ is not None
assert exc.__traceback__.tb_frame is sys._getframe()
# The next frame down is the 'async_generator_function' frame
assert exc.__traceback__.tb_next is not None
generator_frame = exc.__traceback__.tb_next.tb_frame
assert generator_frame.f_code.co_name == "async_generator_function"

0 comments on commit f5667d5

Please sign in to comment.