Skip to content

Commit

Permalink
Move to src and add type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
CoolCat467 committed Jan 14, 2025
1 parent 40211e0 commit d5327ca
Showing 1 changed file with 95 additions and 29 deletions.
124 changes: 95 additions & 29 deletions trio/_shared_task.py → src/trio/_shared_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Generic, TypeVar, cast

import attr
import outcome

from trio._core._ki import disable_ki_protection, enable_ki_protection
from trio._core._run import CancelScope, spawn_system_task
from trio._sync import Event

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable

from typing_extensions import ParamSpec, TypeVarTuple, Unpack

PS = ParamSpec("PS")
PosArgT = TypeVarTuple("PosArgT")

RetT = TypeVar("RetT")


__all__ = ["SharedTaskRegistry"]


Expand All @@ -13,70 +36,112 @@
# ? This is less important b/c we can document that if you want magic key
# generation then you should be careful to make your matching calls obviously
# matching.
def _unpack_call(fn, args, kwargs):
def _unpack_call(
fn: Callable[PS, RetT],
args: PS.args,
kwargs: PS.kwargs | dict[str, object],
) -> tuple[Callable[PS, RetT], PS.args, PS.kwargs | dict[str, object]]:
if isinstance(fn, functools.partial):
inner_fn, inner_args, inner_kwargs = _call_to_key(fn.func, fn.args, fn.kwargs)
inner_fn, inner_args, inner_kwargs = _unpack_call(fn.func, fn.args, {})
fn = inner_fn
args = (*inner_args, *args)
kwargs = {**inner_kwargs, **kwargs}
return fn, args, kwargs


def call_to_hashable_key(fn, args):
def call_to_hashable_key(
fn: Callable[[Unpack[PosArgT]], RetT],
args: tuple[Unpack[PosArgT]],
) -> tuple[
Callable[[Unpack[PosArgT]], RetT],
tuple[Unpack[PosArgT]],
tuple[tuple[str, object], ...],
]:
fn, args, kwargs = _unpack_call(fn, args, {})
return (fn, args, tuple(sorted(kwargs.items())))


class BaseSharedTask:
__slots__ = ()


@attr.s
class SharedTask:
registry = attr.ib()
key = attr.ib()
cancel_scope = attr.ib(default=None)
class SharedTask(BaseSharedTask, Generic["Unpack[PosArgT]", RetT]):
registry: SharedTaskRegistry = attr.ib()
key: tuple[
Callable[[Unpack[PosArgT]], Awaitable[RetT]],
tuple[Unpack[PosArgT]],
tuple[tuple[str, object], ...],
] = attr.ib()
cancel_scope: CancelScope | None = attr.ib(default=None)
# Needed to work around a race condition, where we realize we want to
# cancel the child before it's even created the cancel scope
cancelled_early = attr.ib(default=False)
cancelled_early: bool = attr.ib(default=False)
# Reference count
waiter_count = attr.ib(default=0)
waiter_count: int = attr.ib(default=0)
# Reporting back
finished = attr.ib(default=attr.Factory(trio.Event))
result = attr.ib(default=None)
finished: Event = attr.ib(default=attr.Factory(Event))
result: outcome.Value[RetT] | outcome.Error = attr.ib(default=None)

# This runs in system task context, so it has KI protection enabled and
# any exceptions will crash the whole program.
async def run(self, async_fn, args):
async def run(
self,
async_fn: Callable[[Unpack[PosArgT]], Awaitable[RetT]],
args: tuple[Unpack[PosArgT]],
) -> None:
@disable_ki_protection
async def ki_unprotected_runner() -> RetT:
return await async_fn(*args)

async def cancellable_runner():
with trio.open_cancel_scope() as cancel_scope:
async def cancellable_runner() -> RetT:
with CancelScope() as cancel_scope:
self.cancel_scope = cancel_scope
if self.cancelled_early:
self.cancel_scope.cancel()
return await ki_unprotected_runner()
raise RuntimeError("Should be unreachable.")

@trio.hazmat.disable_ki_protection
async def ki_unprotected_runner():
return await async_fn(*args)

self.result = await Result.acapture(cancellable_runner)
self.result = await outcome.acapture(cancellable_runner)
self.finished.set()
if self.registry._tasks.get(self.key) is self:
del self.registry._tasks[self.key]


@attr.s(slots=True, frozen=True, hash=False, cmp=False, repr=False)
class SharedTaskRegistry:
_tasks = attr.ib(default=attr.Factory(dict))

@trio.hazmat.enable_ki_protection
async def run(self, async_fn, *args, key=None):
class SharedTaskRegistry: # type: ignore[misc]
_tasks: dict[ # type: ignore[misc]
tuple[
Callable[..., Awaitable[object]],
tuple[object, ...],
tuple[tuple[str, object], ...],
],
BaseSharedTask,
] = attr.ib(default=attr.Factory(dict))

@enable_ki_protection
async def run(
self,
async_fn: Callable[[Unpack[PosArgT]], Awaitable[RetT]],
*args: Unpack[PosArgT],
key: (
tuple[
Callable[[Unpack[PosArgT]], Awaitable[RetT]],
tuple[Unpack[PosArgT]],
tuple[tuple[str, object], ...],
]
| None
) = None,
) -> RetT:
if key is None:
key = call_to_hashable_key(async_fn, args)

if key not in self._tasks:
shared_task = SharedTask(self, key)
shared_task = SharedTask["Unpack[PosArgT]", RetT](self, key)
self._tasks[key] = shared_task
trio.hazmat.spawn_system_task(shared_task.run, async_fn, args)
spawn_system_task(shared_task.run, async_fn, args)

shared_task = self._tasks[key]
shared_task = cast("SharedTask[Unpack[PosArgT], RetT]", self._tasks[key])
shared_task.waiter_count += 1

try:
Expand All @@ -94,8 +159,8 @@ async def run(self, async_fn, *args, key=None):
else:
shared_task.cancel_scope.cancel()

with trio.open_cancel_scope(shield=True) as cancel_scope:
await shared_task.finished()
with CancelScope(shield=True):
await shared_task.finished.wait()
# Some possibilities:
# - they raised Cancelled. The cancellation we injected is
# absorbed internally, though, so this can only happen
Expand All @@ -107,6 +172,7 @@ async def run(self, async_fn, *args, key=None):
# - they raise some other error: we should propagate
# - they return nothing (most common, b/c cancelled was
# raised and then
assert shared_task.cancel_scope is not None
if not shared_task.cancel_scope.cancelled_caught:
return shared_task.result.unwrap()
else:
Expand Down

0 comments on commit d5327ca

Please sign in to comment.