diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py index 57096c74..1c2cf38a 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py @@ -32,6 +32,7 @@ from ....transports.abc import AsyncStreamTransport from ...abc import AsyncBackend from .._flow_control import WriteFlowControl, add_flowcontrol_defaults +from ..tasks import TaskUtils if TYPE_CHECKING: from _typeshed import WriteableBuffer @@ -248,7 +249,7 @@ async def receive_data(self, bufsize: int, /) -> bytes: if bufsize < 0: raise ValueError("'bufsize' must be a positive or null integer") - await self._wait_for_data("receive_data") + blocked: bool = await self._wait_for_data("receive_data") nbytes_written = self.__buffer_nbytes_written if nbytes_written: @@ -262,6 +263,8 @@ async def receive_data(self, bufsize: int, /) -> bytes: else: data = b"" self._maybe_resume_transport() + if not blocked: + await TaskUtils.cancel_shielded_coro_yield() return data async def receive_data_into(self, buffer: WriteableBuffer, /) -> int: @@ -271,7 +274,7 @@ async def receive_data_into(self, buffer: WriteableBuffer, /) -> int: if not buffer: return 0 - await self._wait_for_data("receive_data_into") + blocked: bool = await self._wait_for_data("receive_data_into") nbytes_written = self.__buffer_nbytes_written if nbytes_written: @@ -287,14 +290,16 @@ async def receive_data_into(self, buffer: WriteableBuffer, /) -> int: self.__buffer_nbytes_written = 0 self._maybe_resume_transport() + if not blocked: + await TaskUtils.cancel_shielded_coro_yield() return nbytes_written - async def _wait_for_data(self, requester: str) -> None: + async def _wait_for_data(self, requester: str) -> bool: if self.__read_waiter is not None: raise RuntimeError(f"{requester}() called while another coroutine is already waiting for incoming data") if self.__buffer_nbytes_written or self.__eof_reached: - return + return False assert not self.__read_paused, "transport reading is paused" # nosec assert_used @@ -309,6 +314,7 @@ async def _wait_for_data(self, requester: str) -> None: self.__read_waiter = None if self.__connection_lost_exception is not None: raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) + return True def _wakeup_read_waiter(self, exc: Exception | None) -> None: if (waiter := self.__read_waiter) is not None: diff --git a/src/easynetwork/lowlevel/api_async/servers/stream.py b/src/easynetwork/lowlevel/api_async/servers/stream.py index 342e3edc..b94feabb 100644 --- a/src/easynetwork/lowlevel/api_async/servers/stream.py +++ b/src/easynetwork/lowlevel/api_async/servers/stream.py @@ -275,6 +275,8 @@ async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]: except StopIteration: pass else: + if data is None: + await self.__backend.cancel_shielded_coro_yield() return SendAction(request) finally: data = None @@ -307,6 +309,8 @@ async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]: except StopIteration: pass else: + if nbytes is None: + await self.__backend.cancel_shielded_coro_yield() return SendAction(request) nbytes = await self.transport.recv_into(consumer.get_write_buffer()) if not nbytes: