Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed CancelScope's delayed cancellation system #132

Merged
merged 2 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/easynetwork_asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ class CancelScope(AbstractCancelScope):
"__state",
"__cancel_called",
"__cancelled_caught",
"__task_cancelling",
"__deadline",
"__timeout_handle",
"__delayed_cancellation_on_enter",
Expand All @@ -160,7 +159,6 @@ def __init__(self, *, deadline: float = math.inf) -> None:
self.__state: _ScopeState = _ScopeState.CREATED
self.__cancel_called: bool = False
self.__cancelled_caught: bool = False
self.__task_cancelling: int = 0
self.__deadline: float = math.inf
self.__timeout_handle: asyncio.TimerHandle | None = None
self.reschedule(deadline)
Expand All @@ -180,7 +178,6 @@ def __enter__(self) -> Self:
raise RuntimeError("CancelScope entered twice")

self.__host_task = current_task = TaskUtils.current_asyncio_task()
self.__task_cancelling = current_task.cancelling()

current_task_scope = self.__current_task_scope_dict
if current_task not in current_task_scope:
Expand Down Expand Up @@ -218,13 +215,15 @@ def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException
if self.__cancel_called:
task_cancelling = host_task.uncancel()
if isinstance(exc_val, asyncio.CancelledError):
self.__cancelled_caught = task_cancelling <= self.__task_cancelling or self.__cancellation_id() in exc_val.args
self.__cancelled_caught = self.__cancellation_id() in exc_val.args

delayed_task_cancel = self.__delayed_task_cancel_dict.get(host_task, None)
if delayed_task_cancel is not None and delayed_task_cancel.message == self.__cancellation_id():
del self.__delayed_task_cancel_dict[host_task]
delayed_task_cancel.handle.cancel()
delayed_task_cancel = None

if delayed_task_cancel is None:
for cancel_scope in self._inner_to_outer_task_scopes(host_task):
if cancel_scope.__cancel_called:
self._reschedule_delayed_task_cancel(host_task, cancel_scope.__cancellation_id())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,91 @@ async def coroutine() -> None:

await event_loop.create_task(coroutine())

@pytest.mark.xfail(raises=asyncio.CancelledError, reason="Task.cancel() cannot be erased", strict=True)
async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_3(
self,
event_loop: asyncio.AbstractEventLoop,
backend: AsyncIOBackend,
) -> None:
async def coroutine() -> None:
current_task = asyncio.current_task()
assert current_task is not None

with backend.move_on_after(0) as inner_scope:
pass

await backend.coro_yield()

assert inner_scope.cancel_called()

assert not inner_scope.cancelled_caught()

await event_loop.create_task(coroutine())

async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_4(
self,
event_loop: asyncio.AbstractEventLoop,
backend: AsyncIOBackend,
) -> None:
async def coroutine() -> None:
current_task = asyncio.current_task()
assert current_task is not None

outer_scope = backend.open_cancel_scope()
inner_scope = backend.open_cancel_scope()
with outer_scope:
outer_scope.cancel()

await backend.ignore_cancellation(backend.sleep(0.1))

with inner_scope:
inner_scope.cancel()

await backend.coro_yield()

await backend.coro_yield()

assert outer_scope.cancel_called()
assert inner_scope.cancel_called()

assert not inner_scope.cancelled_caught()
assert outer_scope.cancelled_caught()

await event_loop.create_task(coroutine())

async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_5(
self,
event_loop: asyncio.AbstractEventLoop,
backend: AsyncIOBackend,
) -> None:
async def coroutine() -> None:
current_task = asyncio.current_task()
assert current_task is not None

outer_scope = backend.open_cancel_scope()
inner_scope = backend.open_cancel_scope()
with outer_scope:
with inner_scope:
inner_scope.cancel()

await backend.cancel_shielded_coro_yield()

outer_scope.cancel()

await backend.cancel_shielded_coro_yield()

await backend.coro_yield()

await backend.coro_yield()

assert outer_scope.cancel_called()
assert inner_scope.cancel_called()

assert not inner_scope.cancelled_caught()
assert outer_scope.cancelled_caught()

await event_loop.create_task(coroutine())

async def test____ignore_cancellation____do_not_reschedule_if_inner_task_cancelled_itself(
self,
event_loop: asyncio.AbstractEventLoop,
Expand Down