Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed race condition in ThreadsPortal.run_coroutine_soon() #148

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions src/easynetwork/lowlevel/asyncio/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import asyncio
import concurrent.futures
import contextlib
import contextvars
import inspect
from collections.abc import Awaitable, Callable
Expand Down Expand Up @@ -85,21 +86,21 @@ def schedule_task() -> concurrent.futures.Future[_T]:
async def coroutine() -> None:
def on_fut_done(future: concurrent.futures.Future[_T]) -> None:
if future.cancelled():
try:
if self.__is_in_this_loop_thread(loop):
if not cancelling:
task.cancel()
return
with contextlib.suppress(RuntimeError):
self.run_sync(task.cancel)
except RuntimeError:
# on_fut_done() called from coroutine()
# or the portal is already shut down
pass

task = TaskUtils.current_asyncio_task()
loop = task.get_loop()
cancelling: bool = False
try:
if future.cancelled():
task.cancel()
else:
future.add_done_callback(on_fut_done)
future.add_done_callback(on_fut_done)
result = await coro_func(*args, **kwargs)
except asyncio.CancelledError:
cancelling = True
future.cancel()
future.set_running_or_notify_cancel()
raise
Expand Down Expand Up @@ -158,13 +159,17 @@ def __check_loop(self) -> asyncio.AbstractEventLoop:
loop = self.__loop
if loop is None:
raise RuntimeError("ThreadsPortal not running.")
if self.__is_in_this_loop_thread(loop):
raise RuntimeError("This function must be called in a different OS thread")
return loop

@staticmethod
def __is_in_this_loop_thread(loop: asyncio.AbstractEventLoop) -> bool:
try:
running_loop = asyncio.get_running_loop()
except RuntimeError:
return loop
if running_loop is loop:
raise RuntimeError("This function must be called in a different OS thread")
return loop
return False
return running_loop is loop

@staticmethod
def __register_waiter(waiters: set[asyncio.Future[None]], loop: asyncio.AbstractEventLoop) -> asyncio.Future[None]:
Expand Down