Skip to content

Commit

Permalink
AsyncSocket: shutdown() was not checking current running loop (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Nov 27, 2023
1 parent 169ebe1 commit 3a6dd4a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ enable_error_code = ["truthy-bool", "ignore-without-code", "unused-awaitable"]

[tool.pytest.ini_options]
asyncio_mode = "strict" # Avoid some unwanted behaviour
addopts = "--strict-markers"
addopts = "--strict-markers -p 'no:anyio'" # hatch CLI dependencies installs anyio
minversion = "7.1.2"
testpaths = ["tests"]
norecursedirs = ["scripts"]
Expand Down
11 changes: 7 additions & 4 deletions src/easynetwork/lowlevel/asyncio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,17 @@ async def recvfrom(self, bufsize: int, /) -> tuple[bytes, _socket._RetAddress]:
return await self.__loop.sock_recvfrom(socket, bufsize)

async def shutdown(self, how: int, /) -> None:
socket: _socket.socket = self.__check_not_closed()
# Checks if we are within the bound loop
TaskUtils.check_current_asyncio_task(self.__loop)

if how in {_socket.SHUT_RDWR, _socket.SHUT_WR}:
while (waiter := self.__waiters.get("send")) is not None:
try:
await asyncio.shield(waiter)
finally:
waiter = None # Breack cyclic reference with raised exception
waiter = None # Break cyclic reference with raised exception

socket: _socket.socket = self.__check_not_closed()
socket.shutdown(how)
await asyncio.sleep(0)

Expand All @@ -177,9 +180,9 @@ def __conflict_detection(self, task_id: _SocketTaskId, *, abort_errno: int = _er
raise _utils.error_from_errno(_errno.EBUSY)

# Checks if we are within the bound loop
TaskUtils.current_asyncio_task(self.__loop) # type: ignore[unused-awaitable]
TaskUtils.check_current_asyncio_task(self.__loop)

with CancelScope() as scope, contextlib.ExitStack() as stack:
with contextlib.ExitStack() as stack, CancelScope() as scope:
self.__scopes.add(scope)
stack.callback(self.__scopes.discard, scope)

Expand Down
4 changes: 4 additions & 0 deletions src/easynetwork/lowlevel/asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ def __cancel_task_unless_done(task: asyncio.Task[Any], cancel_msg: str | None) -

@final
class TaskUtils:
@staticmethod
def check_current_asyncio_task(loop: asyncio.AbstractEventLoop | None = None) -> None:
_ = TaskUtils.current_asyncio_task(loop=loop)

@staticmethod
def current_asyncio_task(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Task[Any]:
t: asyncio.Task[Any] | None = asyncio.current_task(loop=loop)
Expand Down

0 comments on commit 3a6dd4a

Please sign in to comment.