diff --git a/newsfragments/3112.bugfix.rst b/newsfragments/3112.bugfix.rst new file mode 100644 index 000000000..c34d03552 --- /dev/null +++ b/newsfragments/3112.bugfix.rst @@ -0,0 +1,5 @@ +Rework foreign async generator finalization to track async generator +ids rather than mutating ``ag_frame.f_locals``. This fixes an issue +with the previous implementation: locals' lifetimes will no longer be +extended by materialization in the ``ag_frame.f_locals`` dictionary that +the previous finalization dispatcher logic needed to access to do its work. diff --git a/src/trio/_core/_asyncgens.py b/src/trio/_core/_asyncgens.py index 21102ea7d..b3b689575 100644 --- a/src/trio/_core/_asyncgens.py +++ b/src/trio/_core/_asyncgens.py @@ -4,7 +4,7 @@ import sys import warnings import weakref -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, NoReturn, TypeVar import attrs @@ -16,14 +16,31 @@ ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors") if TYPE_CHECKING: + from collections.abc import Callable from types import AsyncGeneratorType + from typing_extensions import ParamSpec + + _P = ParamSpec("_P") + _WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]] _ASYNC_GEN_SET = set[AsyncGeneratorType[object, NoReturn]] else: _WEAK_ASYNC_GEN_SET = weakref.WeakSet _ASYNC_GEN_SET = set +_R = TypeVar("_R") + + +@_core.disable_ki_protection +def _call_without_ki_protection( + f: Callable[_P, _R], + /, + *args: _P.args, + **kwargs: _P.kwargs, +) -> _R: + return f(*args, **kwargs) + @attrs.define(eq=False) class AsyncGenerators: @@ -35,6 +52,11 @@ class AsyncGenerators: # regular set so we don't have to deal with GC firing at # unexpected times. alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET) + # The ids of foreign async generators are added to this set when first + # iterated. Usually it is not safe to refer to ids like this, but because + # we're using a finalizer we can ensure ids in this set do not outlive + # their async generator. + foreign: set[int] = attrs.Factory(set) # This collects async generators that get garbage collected during # the one-tick window between the system nursery closing and the @@ -51,10 +73,10 @@ def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None: # An async generator first iterated outside of a Trio # task doesn't belong to Trio. Probably we're in guest # mode and the async generator belongs to our host. - # The locals dictionary is the only good place to + # A strong set of ids is one of the only good places to # remember this fact, at least until - # https://bugs.python.org/issue40916 is implemented. - agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True + # https://github.com/python/cpython/issues/85093 is implemented. + self.foreign.add(id(agen)) if self.prev_hooks.firstiter is not None: self.prev_hooks.firstiter(agen) @@ -76,13 +98,16 @@ def finalize_in_trio_context( # have hit it. self.trailing_needs_finalize.add(agen) + @_core.enable_ki_protection def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: - agen_name = name_asyncgen(agen) try: - is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen") - except AttributeError: # pragma: no cover + self.foreign.remove(id(agen)) + except KeyError: is_ours = True + else: + is_ours = False + agen_name = name_asyncgen(agen) if is_ours: runner.entry_queue.run_sync_soon( finalize_in_trio_context, @@ -105,8 +130,9 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: ) else: # Not ours -> forward to the host loop's async generator finalizer - if self.prev_hooks.finalizer is not None: - self.prev_hooks.finalizer(agen) + finalizer = self.prev_hooks.finalizer + if finalizer is not None: + _call_without_ki_protection(finalizer, agen) else: # Host has no finalizer. Reimplement the default # Python behavior with no hooks installed: throw in @@ -116,7 +142,7 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: try: # If the next thing is a yield, this will raise RuntimeError # which we allow to propagate - closer.send(None) + _call_without_ki_protection(closer.send, None) except StopIteration: pass else: diff --git a/src/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py index 5beee4b22..93cc32c9e 100644 --- a/src/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -2,7 +2,6 @@ import asyncio import contextlib -import contextvars import queue import signal import socket @@ -11,6 +10,7 @@ import time import traceback import warnings +import weakref from collections.abc import AsyncGenerator, Awaitable, Callable from functools import partial from math import inf @@ -22,6 +22,7 @@ ) import pytest +import sniffio from outcome import Outcome import trio @@ -221,7 +222,8 @@ async def trio_main(in_host: InHost) -> str: def test_guest_mode_sniffio_integration() -> None: - from sniffio import current_async_library, thread_local as sniffio_library + current_async_library = sniffio.current_async_library + sniffio_library = sniffio.thread_local async def trio_main(in_host: InHost) -> str: async def synchronize() -> None: @@ -439,9 +441,9 @@ def aiotrio_run( loop = asyncio.new_event_loop() async def aio_main() -> T: - trio_done_fut = loop.create_future() + trio_done_fut: asyncio.Future[Outcome[T]] = loop.create_future() - def trio_done_callback(main_outcome: Outcome[object]) -> None: + def trio_done_callback(main_outcome: Outcome[T]) -> None: print(f"trio_fn finished: {main_outcome!r}") trio_done_fut.set_result(main_outcome) @@ -455,9 +457,11 @@ def trio_done_callback(main_outcome: Outcome[object]) -> None: **start_guest_run_kwargs, ) - return (await trio_done_fut).unwrap() # type: ignore[no-any-return] + return (await trio_done_fut).unwrap() try: + # can't use asyncio.run because that fails on Windows (3.8, x64, with + # Komodia LSP) and segfaults on Windows (3.9, x64, with Komodia LSP) return loop.run_until_complete(aio_main()) finally: loop.close() @@ -628,8 +632,6 @@ async def trio_main(in_host: InHost) -> None: @restore_unraisablehook() def test_guest_mode_asyncgens() -> None: - import sniffio - record = set() async def agen(label: str) -> AsyncGenerator[int, None]: @@ -656,9 +658,49 @@ async def trio_main() -> None: gc_collect_harder() - # Ensure we don't pollute the thread-level context if run under - # an asyncio without contextvars support (3.6) - context = contextvars.copy_context() - context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) + aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True) assert record == {("asyncio", "asyncio"), ("trio", "trio")} + + +@restore_unraisablehook() +def test_guest_mode_asyncgens_garbage_collection() -> None: + record: set[tuple[str, str, bool]] = set() + + async def agen(label: str) -> AsyncGenerator[int, None]: + class A: + pass + + a = A() + a_wr = weakref.ref(a) + assert sniffio.current_async_library() == label + try: + yield 1 + finally: + library = sniffio.current_async_library() + with contextlib.suppress(trio.Cancelled): + await sys.modules[library].sleep(0) + + del a + if sys.implementation.name == "pypy": + gc_collect_harder() + + record.add((label, library, a_wr() is None)) + + async def iterate_in_aio() -> None: + await agen("asyncio").asend(None) + + async def trio_main() -> None: + task = asyncio.ensure_future(iterate_in_aio()) + done_evt = trio.Event() + task.add_done_callback(lambda _: done_evt.set()) + with trio.fail_after(1): + await done_evt.wait() + + await agen("trio").asend(None) + + gc_collect_harder() + + aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True) + + assert record == {("asyncio", "asyncio", True), ("trio", "trio", True)}