diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 10c1ddfdc..46c8b4d48 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -377,6 +377,46 @@ These transitions are accomplished using two function decorators: poorly-timed :exc:`KeyboardInterrupt` could leave the lock in an inconsistent state and cause a deadlock. + Since KeyboardInterrupt protection is tracked per code object, any attempt to + conditionally protect the same block of code in different ways is unlikely to behave + how you expect. If you try to conditionally protect a closure, it will be + unconditionally protected instead:: + + def example(protect: bool) -> bool: + def inner() -> bool: + return trio.lowlevel.currently_ki_protected() + if protect: + inner = trio.lowlevel.enable_ki_protection(inner) + return inner() + + async def amain(): + assert example(False) == False + assert example(True) == True # once protected ... + assert example(False) == True # ... always protected + + trio.run(amain) + + If you really need conditional protection, you can achieve it by giving each + KI-protected instance of the closure its own code object:: + + def example(protect: bool) -> bool: + def inner() -> bool: + return trio.lowlevel.currently_ki_protected() + if protect: + inner.__code__ = inner.__code__.replace() + inner = trio.lowlevel.enable_ki_protection(inner) + return inner() + + async def amain(): + assert example(False) == False + assert example(True) == True + assert example(False) == False + + trio.run(amain) + + (This isn't done by default because it carries some memory overhead and reduces + the potential for specializing optimizations in recent versions of CPython.) + .. autofunction:: currently_ki_protected diff --git a/newsfragments/2670.bugfix.rst b/newsfragments/2670.bugfix.rst new file mode 100644 index 000000000..cd5ed3b94 --- /dev/null +++ b/newsfragments/2670.bugfix.rst @@ -0,0 +1,2 @@ +:func:`inspect.iscoroutinefunction` and the like now give correct answers when +called on KI-protected functions. diff --git a/newsfragments/3108.bugfix.rst b/newsfragments/3108.bugfix.rst new file mode 100644 index 000000000..16cf46b96 --- /dev/null +++ b/newsfragments/3108.bugfix.rst @@ -0,0 +1,26 @@ +Rework KeyboardInterrupt protection to track code objects, rather than frames, +as protected or not. The new implementation no longer needs to access +``frame.f_locals`` dictionaries, so it won't artificially extend the lifetime of +local variables. Since KeyboardInterrupt protection is now imposed statically +(when a protected function is defined) rather than each time the function runs, +its previously-noticeable performance overhead should now be near zero. +The lack of a call-time wrapper has some other benefits as well: + +* :func:`inspect.iscoroutinefunction` and the like now give correct answers when + called on KI-protected functions. + +* Calling a synchronous KI-protected function no longer pushes an additional stack + frame, so tracebacks are clearer. + +* A synchronous KI-protected function invoked from C code (such as a weakref + finalizer) is now guaranteed to start executing; previously there would be a brief + window in which KeyboardInterrupt could be raised before the protection was + established. + +One minor drawback of the new approach is that multiple instances of the same +closure share a single KeyboardInterrupt protection state (because they share a +single code object). That means that if you apply +`@enable_ki_protection ` to some of them +and not others, you won't get the protection semantics you asked for. See the +documentation of `@enable_ki_protection ` +for more details and a workaround. diff --git a/src/trio/_core/_generated_instrumentation.py b/src/trio/_core/_generated_instrumentation.py index 568b76dff..d03ef9db7 100644 --- a/src/trio/_core/_generated_instrumentation.py +++ b/src/trio/_core/_generated_instrumentation.py @@ -3,10 +3,9 @@ # ************************************************************* from __future__ import annotations -import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -15,6 +14,7 @@ __all__ = ["add_instrument", "remove_instrument"] +@enable_ki_protection def add_instrument(instrument: Instrument) -> None: """Start instrumenting the current run loop with the given instrument. @@ -24,13 +24,13 @@ def add_instrument(instrument: Instrument) -> None: If ``instrument`` is already active, does nothing. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def remove_instrument(instrument: Instrument) -> None: """Stop instrumenting the current run loop with the given instrument. @@ -44,7 +44,6 @@ def remove_instrument(instrument: Instrument) -> None: deactivated. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument) except AttributeError: diff --git a/src/trio/_core/_generated_io_epoll.py b/src/trio/_core/_generated_io_epoll.py index 9f9ad5972..41cbb4065 100644 --- a/src/trio/_core/_generated_io_epoll.py +++ b/src/trio/_core/_generated_io_epoll.py @@ -6,7 +6,7 @@ import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -18,6 +18,7 @@ __all__ = ["notify_closing", "wait_readable", "wait_writable"] +@enable_ki_protection async def wait_readable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is readable. @@ -40,13 +41,13 @@ async def wait_readable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_writable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is writable. @@ -59,13 +60,13 @@ async def wait_writable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def notify_closing(fd: int | _HasFileNo) -> None: """Notify waiters of the given object that it will be closed. @@ -91,7 +92,6 @@ def notify_closing(fd: int | _HasFileNo) -> None: step, so other tasks won't be able to tell what order they happened in anyway. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: diff --git a/src/trio/_core/_generated_io_kqueue.py b/src/trio/_core/_generated_io_kqueue.py index b2bdfc576..016704eac 100644 --- a/src/trio/_core/_generated_io_kqueue.py +++ b/src/trio/_core/_generated_io_kqueue.py @@ -6,7 +6,7 @@ import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -31,18 +31,19 @@ ] +@enable_ki_protection def current_kqueue() -> select.kqueue: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def monitor_kevent( ident: int, filter: int, @@ -51,13 +52,13 @@ def monitor_kevent( anything real. See `#26 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_kevent( ident: int, filter: int, @@ -67,7 +68,6 @@ async def wait_kevent( anything real. See `#26 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent( ident, @@ -78,6 +78,7 @@ async def wait_kevent( raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_readable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is readable. @@ -100,13 +101,13 @@ async def wait_readable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_writable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is writable. @@ -119,13 +120,13 @@ async def wait_writable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def notify_closing(fd: int | _HasFileNo) -> None: """Notify waiters of the given object that it will be closed. @@ -151,7 +152,6 @@ def notify_closing(fd: int | _HasFileNo) -> None: step, so other tasks won't be able to tell what order they happened in anyway. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: diff --git a/src/trio/_core/_generated_io_windows.py b/src/trio/_core/_generated_io_windows.py index d06bb19e0..745fa4fc4 100644 --- a/src/trio/_core/_generated_io_windows.py +++ b/src/trio/_core/_generated_io_windows.py @@ -6,7 +6,7 @@ import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -34,6 +34,7 @@ ] +@enable_ki_protection async def wait_readable(sock: _HasFileNo | int) -> None: """Block until the kernel reports that the given object is readable. @@ -56,13 +57,13 @@ async def wait_readable(sock: _HasFileNo | int) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_writable(sock: _HasFileNo | int) -> None: """Block until the kernel reports that the given object is writable. @@ -75,13 +76,13 @@ async def wait_writable(sock: _HasFileNo | int) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def notify_closing(handle: Handle | int | _HasFileNo) -> None: """Notify waiters of the given object that it will be closed. @@ -107,33 +108,32 @@ def notify_closing(handle: Handle | int | _HasFileNo) -> None: step, so other tasks won't be able to tell what order they happened in anyway. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def register_with_iocp(handle: int | CData) -> None: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> object: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped( handle_, @@ -143,6 +143,7 @@ async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> ob raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def write_overlapped( handle: int | CData, data: Buffer, @@ -153,7 +154,6 @@ async def write_overlapped( `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped( handle, @@ -164,6 +164,7 @@ async def write_overlapped( raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def readinto_overlapped( handle: int | CData, buffer: Buffer, @@ -174,7 +175,6 @@ async def readinto_overlapped( `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped( handle, @@ -185,19 +185,20 @@ async def readinto_overlapped( raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_iocp() -> int: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def monitor_completion_key() -> ( AbstractContextManager[tuple[int, UnboundedQueue[object]]] ): @@ -206,7 +207,6 @@ def monitor_completion_key() -> ( `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() except AttributeError: diff --git a/src/trio/_core/_generated_run.py b/src/trio/_core/_generated_run.py index b5957a134..67d70d907 100644 --- a/src/trio/_core/_generated_run.py +++ b/src/trio/_core/_generated_run.py @@ -3,10 +3,9 @@ # ************************************************************* from __future__ import annotations -import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT, RunStatistics, Task if TYPE_CHECKING: @@ -33,6 +32,7 @@ ] +@enable_ki_protection def current_statistics() -> RunStatistics: """Returns ``RunStatistics``, which contains run-loop-level debugging information. @@ -56,13 +56,13 @@ def current_statistics() -> RunStatistics: other attributes vary between backends. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_statistics() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_time() -> float: """Returns the current time according to Trio's internal clock. @@ -73,36 +73,36 @@ def current_time() -> float: RuntimeError: if not inside a call to :func:`trio.run`. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_time() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_clock() -> Clock: """Returns the current :class:`~trio.abc.Clock`.""" - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_clock() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_root_task() -> Task | None: """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_root_task() except AttributeError: raise RuntimeError("must be called from async context") from None -def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None: +@enable_ki_protection +def reschedule(task: Task, next_send: Outcome[object] = _NO_SEND) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -120,13 +120,13 @@ def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None: raise) from :func:`wait_task_rescheduled`. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def spawn_system_task( async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], *args: Unpack[PosArgT], @@ -184,7 +184,6 @@ def spawn_system_task( Task: the newly spawned task """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.spawn_system_task( async_fn, @@ -196,18 +195,19 @@ def spawn_system_task( raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_trio_token() -> TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_trio_token() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_all_tasks_blocked(cushion: float = 0.0) -> None: """Block until there are no runnable tasks. @@ -266,7 +266,6 @@ async def test_lock_fairness(): print("FAIL") """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion) except AttributeError: diff --git a/src/trio/_core/_ki.py b/src/trio/_core/_ki.py index a8431f89d..672501f75 100644 --- a/src/trio/_core/_ki.py +++ b/src/trio/_core/_ki.py @@ -1,26 +1,21 @@ from __future__ import annotations -import inspect import signal import sys -from functools import wraps -from typing import TYPE_CHECKING, Final, Protocol, TypeVar +import types +import weakref +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar import attrs from .._util import is_main_thread - -CallableT = TypeVar("CallableT", bound="Callable[..., object]") -RetT = TypeVar("RetT") +from ._run_context import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: import types from collections.abc import Callable - from typing_extensions import ParamSpec, TypeGuard - - ArgsT = ParamSpec("ArgsT") - + from typing_extensions import Self, TypeGuard # In ordinary single-threaded Python code, when you hit control-C, it raises # an exception and automatically does all the regular unwinding stuff. # @@ -83,20 +78,112 @@ # for any Python program that's written to catch and ignore # KeyboardInterrupt.) -# We use this special string as a unique key into the frame locals dictionary. -# The @ ensures it is not a valid identifier and can't clash with any possible -# real local name. See: https://github.com/python-trio/trio/issues/469 -LOCALS_KEY_KI_PROTECTION_ENABLED: Final = "@TRIO_KI_PROTECTION_ENABLED" +_T = TypeVar("_T") + + +class _IdRef(weakref.ref[_T]): + __slots__ = ("_hash",) + _hash: int + + def __new__(cls, ob: _T, callback: Callable[[Self], Any] | None = None, /) -> Self: + self: Self = weakref.ref.__new__(cls, ob, callback) + self._hash = object.__hash__(ob) + return self + + def __eq__(self, other: object) -> bool: + if self is other: + return True + + if not isinstance(other, _IdRef): + return NotImplemented + + my_obj = None + try: + my_obj = self() + return my_obj is not None and my_obj is other() + finally: + del my_obj + + # we're overriding a builtin so we do need this + def __ne__(self, other: object) -> bool: + return not self == other + + def __hash__(self) -> int: + return self._hash + + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +# see also: https://github.com/python/cpython/issues/88306 +class WeakKeyIdentityDictionary(Generic[_KT, _VT]): + def __init__(self) -> None: + self._data: dict[_IdRef[_KT], _VT] = {} + + def remove( + k: _IdRef[_KT], + selfref: weakref.ref[ + WeakKeyIdentityDictionary[_KT, _VT] + ] = weakref.ref( # noqa: B008 # function-call-in-default-argument + self, + ), + ) -> None: + self = selfref() + if self is not None: + try: # noqa: SIM105 # supressible-exception + del self._data[k] + except KeyError: + pass + + self._remove = remove + + def __getitem__(self, k: _KT) -> _VT: + return self._data[_IdRef(k)] + + def __setitem__(self, k: _KT, v: _VT) -> None: + self._data[_IdRef(k, self._remove)] = v + + +_CODE_KI_PROTECTION_STATUS_WMAP: WeakKeyIdentityDictionary[ + types.CodeType, + bool, +] = WeakKeyIdentityDictionary() + + +# This is to support the async_generator package necessary for aclosing on <3.10 +# functions decorated @async_generator are given this magic property that's a +# reference to the object itself +# see python-trio/async_generator/async_generator/_impl.py +def legacy_isasyncgenfunction( + obj: object, +) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]: + return getattr(obj, "_async_gen_function", None) == id(obj) # NB: according to the signal.signal docs, 'frame' can be None on entry to # this function: def ki_protection_enabled(frame: types.FrameType | None) -> bool: + try: + task = GLOBAL_RUN_CONTEXT.task + except AttributeError: + task_ki_protected = False + task_frame = None + else: + task_ki_protected = task._ki_protected + task_frame = task.coro.cr_frame + while frame is not None: - if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals: - return bool(frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED]) + try: + v = _CODE_KI_PROTECTION_STATUS_WMAP[frame.f_code] + except KeyError: + pass + else: + return bool(v) if frame.f_code.co_name == "__del__": return True + if frame is task_frame: + return task_ki_protected frame = frame.f_back return True @@ -117,90 +204,33 @@ def currently_ki_protected() -> bool: return ki_protection_enabled(sys._getframe()) -# This is to support the async_generator package necessary for aclosing on <3.10 -# functions decorated @async_generator are given this magic property that's a -# reference to the object itself -# see python-trio/async_generator/async_generator/_impl.py -def legacy_isasyncgenfunction( - obj: object, -) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]: - return getattr(obj, "_async_gen_function", None) == id(obj) +class _SupportsCode(Protocol): + __code__: types.CodeType -def _ki_protection_decorator( - enabled: bool, -) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]: - # The "ignore[return-value]" below is because the inspect functions cast away the - # original return type of fn, making it just CoroutineType[Any, Any, Any] etc. - # ignore[misc] is because @wraps() is passed a callable with Any in the return type. - def decorator(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: - # In some version of Python, isgeneratorfunction returns true for - # coroutine functions, so we have to check for coroutine functions - # first. - if inspect.iscoroutinefunction(fn): - - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] - # See the comment for regular generators below - coro = fn(*args, **kwargs) - assert coro.cr_frame is not None, "Coroutine frame should exist" - coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return coro # type: ignore[return-value] - - return wrapper - elif inspect.isgeneratorfunction(fn): - - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] - # It's important that we inject this directly into the - # generator's locals, as opposed to setting it here and then - # doing 'yield from'. The reason is, if a generator is - # throw()n into, then it may magically pop to the top of the - # stack. And @contextmanager generators in particular are a - # case where we often want KI protection, and which are often - # thrown into! See: - # https://bugs.python.org/issue29590 - gen = fn(*args, **kwargs) - gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return gen # type: ignore[return-value] - - return wrapper - elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): - - @wraps(fn) # type: ignore[arg-type] - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] - # See the comment for regular generators above - agen = fn(*args, **kwargs) - agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return agen # type: ignore[return-value] - - return wrapper - else: - - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return fn(*args, **kwargs) +_T_supports_code = TypeVar("_T_supports_code", bound=_SupportsCode) - return wrapper - return decorator +def enable_ki_protection(f: _T_supports_code, /) -> _T_supports_code: + """Decorator to enable KI protection.""" + orig = f + if legacy_isasyncgenfunction(f): + f = f.__wrapped__ # type: ignore -# pyright workaround: https://github.com/microsoft/pyright/issues/5866 -class KIProtectionSignature(Protocol): - __name__: str + _CODE_KI_PROTECTION_STATUS_WMAP[f.__code__] = True + return orig - def __call__(self, f: CallableT, /) -> CallableT: - pass +def disable_ki_protection(f: _T_supports_code, /) -> _T_supports_code: + """Decorator to disable KI protection.""" + orig = f -# the following `type: ignore`s are because we use ParamSpec internally, but want to allow overloads -enable_ki_protection: KIProtectionSignature = _ki_protection_decorator(True) # type: ignore[assignment] -enable_ki_protection.__name__ = "enable_ki_protection" + if legacy_isasyncgenfunction(f): + f = f.__wrapped__ # type: ignore -disable_ki_protection: KIProtectionSignature = _ki_protection_decorator(False) # type: ignore[assignment] -disable_ki_protection.__name__ = "disable_ki_protection" + _CODE_KI_PROTECTION_STATUS_WMAP[f.__code__] = False + return orig @attrs.define(slots=False) diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index cba7a8dec..3961a6e10 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -7,7 +7,6 @@ import random import select import sys -import threading import warnings from collections import deque from contextlib import AbstractAsyncContextManager, contextmanager, suppress @@ -39,8 +38,9 @@ from ._entry_queue import EntryQueue, TrioToken from ._exceptions import Cancelled, RunFinishedError, TrioInternalError from ._instrumentation import Instruments -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED, KIManager, enable_ki_protection +from ._ki import KIManager, enable_ki_protection from ._parking_lot import GLOBAL_PARKING_LOT_BREAKER +from ._run_context import GLOBAL_RUN_CONTEXT as GLOBAL_RUN_CONTEXT from ._thread_cache import start_thread_soon from ._traps import ( Abort, @@ -1382,6 +1382,7 @@ class Task(metaclass=NoPublicConstructor): name: str context: contextvars.Context _counter: int = attrs.field(init=False, factory=itertools.count().__next__) + _ki_protected: bool # Invariant: # - for unscheduled tasks, _next_send_fn and _next_send are both None @@ -1557,14 +1558,6 @@ def raise_cancel() -> NoReturn: ################################################################ -class RunContext(threading.local): - runner: Runner - task: Task - - -GLOBAL_RUN_CONTEXT: Final = RunContext() - - @attrs.frozen class RunStatistics: """An object containing run-loop-level debugging information. @@ -1780,11 +1773,11 @@ def current_root_task(self) -> Task | None: # Core task handling primitives ################ - @_public # Type-ignore due to use of Any here. - def reschedule( # type: ignore[misc] + @_public + def reschedule( self, task: Task, - next_send: Outcome[Any] = _NO_SEND, + next_send: Outcome[object] = _NO_SEND, ) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -1871,7 +1864,6 @@ async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: coro = python_wrapper(coro) assert coro.cr_frame is not None, "Coroutine frame should exist" - coro.cr_frame.f_locals.setdefault(LOCALS_KEY_KI_PROTECTION_ENABLED, system_task) ###### # Set up the Task object @@ -1882,6 +1874,7 @@ async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: runner=self, name=name, context=context, + ki_protected=system_task, ) self.tasks.add(task) @@ -2573,13 +2566,13 @@ def my_done_callback(run_outcome): # mode", where our core event loop gets unrolled into a series of callbacks on # the host loop. If you're doing a regular trio.run then this gets run # straight through. +@enable_ki_protection def unrolled_run( runner: Runner, async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], args: tuple[Unpack[PosArgT]], host_uses_signal_set_wakeup_fd: bool = False, ) -> Generator[float, EventResult, None]: - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True __tracebackhide__ = True try: diff --git a/src/trio/_core/_run_context.py b/src/trio/_core/_run_context.py new file mode 100644 index 000000000..085bff9a3 --- /dev/null +++ b/src/trio/_core/_run_context.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Final + +if TYPE_CHECKING: + from ._run import Runner, Task + + +class RunContext(threading.local): + runner: Runner + task: Task + + +GLOBAL_RUN_CONTEXT: Final = RunContext() diff --git a/src/trio/_core/_tests/test_ki.py b/src/trio/_core/_tests/test_ki.py index 8582cc0b2..d403cfa7a 100644 --- a/src/trio/_core/_tests/test_ki.py +++ b/src/trio/_core/_tests/test_ki.py @@ -3,14 +3,19 @@ import contextlib import inspect import signal +import sys import threading -from typing import TYPE_CHECKING +import weakref +from collections.abc import AsyncIterator, Iterator +from typing import TYPE_CHECKING, Callable, TypeVar import outcome import pytest from trio.testing import RaisesGroup +from .tutil import gc_collect_harder + try: from async_generator import async_generator, yield_ except ImportError: # pragma: no cover @@ -18,12 +23,19 @@ from ... import _core from ..._abc import Instrument +from ..._core import _ki from ..._timeouts import sleep from ..._util import signal_raise from ...testing import wait_all_tasks_blocked if TYPE_CHECKING: - from collections.abc import AsyncIterator, Callable, Iterator + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Callable, + Generator, + Iterator, + ) from ..._core import Abort, RaiseCancelT @@ -517,3 +529,176 @@ async def inner() -> None: _core.run(inner) finally: threading._active[thread.ident] = original # type: ignore[attr-defined] + + +_T = TypeVar("_T") + + +def _identity(v: _T) -> _T: + return v + + +@pytest.mark.xfail( + strict=True, + raises=AssertionError, + reason=( + "it was decided not to protect against this case, see discussion in: " + "https://github.com/python-trio/trio/pull/3110#discussion_r1802123644" + ), +) +async def test_ki_does_not_leak_across_different_calls_to_inner_functions() -> None: + assert not _core.currently_ki_protected() + + def factory(enabled: bool) -> Callable[[], bool]: + @_core.enable_ki_protection if enabled else _identity + def decorated() -> bool: + return _core.currently_ki_protected() + + return decorated + + decorated_enabled = factory(True) + decorated_disabled = factory(False) + assert decorated_enabled() + assert not decorated_disabled() + + +async def test_ki_protection_check_does_not_freeze_locals() -> None: + class A: + pass + + a = A() + wr_a = weakref.ref(a) + assert not _core.currently_ki_protected() + del a + if sys.implementation.name == "pypy": + gc_collect_harder() + assert wr_a() is None + + +def test_identity_weakref_internals() -> None: + """To cover the parts WeakKeyIdentityDictionary won't ever reach.""" + + class A: + def __eq__(self, other: object) -> bool: + return False + + a = A() + assert a != a + wr = _ki._IdRef(a) + wr_other_is_self = wr + + # dict always checks identity before equality so we need to do it here + # to cover `if self is other` + assert wr == wr_other_is_self + + # we want to cover __ne__ and `return NotImplemented` + assert wr != object() + + +def test_weak_key_identity_dict_remove_callback_keyerror() -> None: + """We need to cover the KeyError in self._remove.""" + + class A: + def __eq__(self, other: object) -> bool: + return False + + a = A() + assert a != a + d: _ki.WeakKeyIdentityDictionary[A, bool] = _ki.WeakKeyIdentityDictionary() + + d[a] = True + + data_copy = d._data.copy() + d._data.clear() + del a + + gc_collect_harder() # would call sys.unraisablehook if there's a problem + assert data_copy + + +def test_weak_key_identity_dict_remove_callback_selfref_expired() -> None: + """We need to cover the KeyError in self._remove.""" + + class A: + def __eq__(self, other: object) -> bool: + return False + + a = A() + assert a != a + d: _ki.WeakKeyIdentityDictionary[A, bool] = _ki.WeakKeyIdentityDictionary() + + d[a] = True + + data_copy = d._data.copy() + wr_d = weakref.ref(d) + del d + gc_collect_harder() # would call sys.unraisablehook if there's a problem + assert wr_d() is None + del a + gc_collect_harder() + assert data_copy + + +@_core.enable_ki_protection +async def _protected_async_gen_fn() -> AsyncGenerator[None, None]: + yield + + +@_core.enable_ki_protection +async def _protected_async_fn() -> None: + pass + + +@_core.enable_ki_protection +def _protected_gen_fn() -> Generator[None, None, None]: + yield + + +@_core.disable_ki_protection +async def _unprotected_async_gen_fn() -> AsyncGenerator[None, None]: + yield + + +@_core.disable_ki_protection +async def _unprotected_async_fn() -> None: + pass + + +@_core.disable_ki_protection +def _unprotected_gen_fn() -> Generator[None, None, None]: + yield + + +async def _consume_async_generator(agen: AsyncGenerator[None, None]) -> None: + try: + with pytest.raises(StopAsyncIteration): + while True: + await agen.asend(None) + finally: + await agen.aclose() + + +def _consume_function_for_coverage(fn: Callable[..., object]) -> None: + result = fn() + if inspect.isasyncgen(result): + result = _consume_async_generator(result) + + assert inspect.isgenerator(result) or inspect.iscoroutine(result) + with pytest.raises(StopIteration): + while True: + result.send(None) + + +def test_enable_disable_ki_protection_passes_on_inspect_flags() -> None: + assert inspect.isasyncgenfunction(_protected_async_gen_fn) + _consume_function_for_coverage(_protected_async_gen_fn) + assert inspect.iscoroutinefunction(_protected_async_fn) + _consume_function_for_coverage(_protected_async_fn) + assert inspect.isgeneratorfunction(_protected_gen_fn) + _consume_function_for_coverage(_protected_gen_fn) + assert inspect.isasyncgenfunction(_unprotected_async_gen_fn) + _consume_function_for_coverage(_unprotected_async_gen_fn) + assert inspect.iscoroutinefunction(_unprotected_async_fn) + _consume_function_for_coverage(_unprotected_async_fn) + assert inspect.isgeneratorfunction(_unprotected_gen_fn) + _consume_function_for_coverage(_unprotected_gen_fn) diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index d87bd2102..f04c95161 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -2778,3 +2778,47 @@ async def spawn_tasks_in_old_nursery(task_status: _core.TaskStatus[None]) -> Non with pytest.raises(_core.TrioInternalError) as excinfo: await nursery.start(spawn_tasks_in_old_nursery) assert RaisesGroup(ValueError, ValueError).matches(excinfo.value.__cause__) + + +if sys.version_info <= (3, 11): + + def no_other_refs() -> list[object]: + return [sys._getframe(1)] + +else: + + def no_other_refs() -> list[object]: + return [] + + +@pytest.mark.skipif( + sys.implementation.name != "cpython", + reason="Only makes sense with refcounting GC", +) +async def test_ki_protection_doesnt_leave_cyclic_garbage() -> None: + class MyException(Exception): + pass + + async def demo() -> None: + async def handle_error() -> None: + try: + raise MyException + except MyException as e: + exceptions.append(e) + + exceptions: list[MyException] = [] + try: + async with _core.open_nursery() as n: + n.start_soon(handle_error) + raise ExceptionGroup("errors", exceptions) + finally: + exceptions = [] + + exc: Exception | None = None + try: + await demo() + except ExceptionGroup as excs: + exc = excs.exceptions[0] + + assert isinstance(exc, MyException) + assert gc.get_referrers(exc) == no_other_refs() diff --git a/src/trio/_tests/_check_type_completeness.json b/src/trio/_tests/_check_type_completeness.json index badb7cba1..72d981f89 100644 --- a/src/trio/_tests/_check_type_completeness.json +++ b/src/trio/_tests/_check_type_completeness.json @@ -40,7 +40,6 @@ "No docstring found for class \"trio._core._local.RunVarToken\"", "No docstring found for class \"trio.lowlevel.RunVarToken\"", "No docstring found for class \"trio.lowlevel.Task\"", - "No docstring found for class \"trio._core._ki.KIProtectionSignature\"", "No docstring found for class \"trio.socket.SocketType\"", "No docstring found for class \"trio.socket.gaierror\"", "No docstring found for class \"trio.socket.herror\"", diff --git a/src/trio/_tools/gen_exports.py b/src/trio/_tools/gen_exports.py index c762ee138..b4db597b6 100755 --- a/src/trio/_tools/gen_exports.py +++ b/src/trio/_tools/gen_exports.py @@ -34,12 +34,11 @@ import sys -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT """ -TEMPLATE = """sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True -try: +TEMPLATE = """try: return{}GLOBAL_RUN_CONTEXT.{}.{} except AttributeError: raise RuntimeError("must be called from async context") from None @@ -237,7 +236,7 @@ def gen_public_wrappers_source(file: File) -> str: is_cm = False # Remove decorators - method.decorator_list = [] + method.decorator_list = [ast.Name("enable_ki_protection")] # Create pass through arguments new_args = create_passthrough_args(method)