From 2f685e59f0641e0a25bf8fc780e2adfff2ba26eb Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sat, 23 Nov 2024 14:58:27 +0100 Subject: [PATCH] AsyncIOBackend: Improved recv_into() implementation --- .../backend/_asyncio/stream/socket.py | 120 +++++++++++------- .../test_asyncio_backend/test_stream.py | 72 +++++++++-- 2 files changed, 134 insertions(+), 58 deletions(-) diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py index de74fb3b..e32f1ba4 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py @@ -26,7 +26,7 @@ import warnings from collections.abc import Callable, Iterable, Mapping from types import MappingProxyType, TracebackType -from typing import TYPE_CHECKING, Any, final +from typing import TYPE_CHECKING, Any, final, overload from ......exceptions import UnsupportedOperation from ..... import _utils, socket as socket_tools @@ -143,6 +143,7 @@ class StreamReaderBufferedProtocol(asyncio.BufferedProtocol): "__buffer", "__buffer_view", "__buffer_nbytes_written", + "__external_buffer_view", "__transport", "__closed", "__write_flow", @@ -170,10 +171,11 @@ def __init__( self.__loop: asyncio.AbstractEventLoop = loop self.__buffer: bytearray | None = bytearray(self.max_size) self.__buffer_view: memoryview = memoryview(self.__buffer) + self.__external_buffer_view: WriteableBuffer | None = None self.__buffer_nbytes_written: int = 0 self.__transport: asyncio.Transport | None = None self.__closed: asyncio.Future[None] = loop.create_future() - self.__read_waiter: asyncio.Future[None] | None = None + self.__read_waiter: asyncio.Future[int | None] | None = None self.__write_flow: WriteFlowControl self.__read_paused: bool = False self.__connection_lost: bool = False @@ -219,6 +221,8 @@ def connection_lost(self, exc: Exception | None) -> None: self.__transport = None def get_buffer(self, sizehint: int) -> WriteableBuffer: + if (external_buffer_view := self.__external_buffer_view) is not None: + return external_buffer_view # Ignore sizehint, the buffer is already at its maximum size. # Returns unused buffer part if self.__buffer is None: @@ -229,12 +233,21 @@ def buffer_updated(self, nbytes: int) -> None: assert not self.__connection_lost, "buffer_updated() after connection_lost()" # nosec assert_used assert not self.__eof_reached, "buffer_updated() after eof_received()" # nosec assert_used + if self.__external_buffer_view is not None: + # Early remove to prevent using this buffer between this point and the wakeup of the task. + self.__external_buffer_view = None + self._read_waiter_fut(lambda waiter: waiter.set_result(nbytes)) + # Call to _maybe_pause_transport() is unnecessary: Did not write in internal buffer. + return + self.__buffer_nbytes_written += nbytes assert 0 <= self.__buffer_nbytes_written <= self.__buffer_view.nbytes # nosec assert_used self._wakeup_read_waiter(None) self._maybe_pause_transport() def eof_received(self) -> bool: + # Early remove to prevent using this buffer between this point and the wakeup of the task. + self.__external_buffer_view = None self.__eof_reached = True self._wakeup_read_waiter(None) if self.__over_ssl: @@ -245,14 +258,13 @@ def eof_received(self) -> bool: return True async def receive_data(self, bufsize: int, /) -> bytes: - if self.__connection_lost_exception is not None: - raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) + self._check_for_connection_lost() if bufsize == 0: return b"" if bufsize < 0: raise ValueError("'bufsize' must be a positive or null integer") - blocked: bool = await self._wait_for_data("receive_data") + await self._wait_for_data("receive_data", None) nbytes_written = self.__buffer_nbytes_written if nbytes_written: @@ -266,66 +278,80 @@ async def receive_data(self, bufsize: int, /) -> bytes: else: data = b"" self._maybe_resume_transport() - if not blocked: - await TaskUtils.cancel_shielded_coro_yield() return data async def receive_data_into(self, buffer: WriteableBuffer, /) -> int: - if self.__connection_lost_exception is not None: - raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) - with memoryview(buffer) as buffer, buffer.cast("B") as buffer: + self._check_for_connection_lost() + + with memoryview(buffer) as buffer: if not buffer: return 0 - blocked: bool = await self._wait_for_data("receive_data_into") - - nbytes_written = self.__buffer_nbytes_written - if nbytes_written: - protocol_buffer_written = self.__buffer_view[:nbytes_written] - bufsize_offset = nbytes_written - buffer.nbytes - if bufsize_offset > 0: - nbytes_written = buffer.nbytes - buffer[:] = protocol_buffer_written[:nbytes_written] - protocol_buffer_written[:bufsize_offset] = protocol_buffer_written[-bufsize_offset:] - self.__buffer_nbytes_written = bufsize_offset - else: - buffer[:nbytes_written] = protocol_buffer_written - self.__buffer_nbytes_written = 0 + with buffer.cast("B") if buffer.itemsize != 1 else buffer as buffer: + nbytes_written = await self._wait_for_data("receive_data_into", buffer) + if nbytes_written is not None: + # Call to _maybe_resume_transport() is unnecessary: Did not write in internal buffer. + return nbytes_written + + nbytes_written = self.__buffer_nbytes_written + if nbytes_written: + protocol_buffer_written = self.__buffer_view[:nbytes_written] + bufsize_offset = nbytes_written - buffer.nbytes + if bufsize_offset > 0: + nbytes_written = buffer.nbytes + buffer[:] = protocol_buffer_written[:nbytes_written] + protocol_buffer_written[:bufsize_offset] = protocol_buffer_written[-bufsize_offset:] + self.__buffer_nbytes_written = bufsize_offset + else: + buffer[:nbytes_written] = protocol_buffer_written + self.__buffer_nbytes_written = 0 self._maybe_resume_transport() - if not blocked: - await TaskUtils.cancel_shielded_coro_yield() return nbytes_written - async def _wait_for_data(self, requester: str) -> bool: - if self.__read_waiter is not None: - raise RuntimeError(f"{requester}() called while another coroutine is already waiting for incoming data") - - if self.__buffer_nbytes_written or self.__eof_reached: - return False + @overload + async def _wait_for_data(self, requester: str, external_buffer: None) -> None: ... - assert not self.__read_paused, "transport reading is paused" # nosec assert_used + @overload + async def _wait_for_data(self, requester: str, external_buffer: WriteableBuffer) -> int | None: ... - if self.__transport is None: - # happening if transport.pause_reading() raises NotImplementedError - raise _utils.error_from_errno(_errno.ECONNABORTED) + async def _wait_for_data(self, requester: str, external_buffer: WriteableBuffer | None) -> int | None: + if self.__read_waiter is not None: + raise RuntimeError(f"{requester}() called while another coroutine is already waiting for incoming data") self.__read_waiter = self.__loop.create_future() try: - await self.__read_waiter + nbytes_written_in_external_buffer: int | None + if self.__buffer_nbytes_written or self.__eof_reached: + self.__read_waiter.set_result(None) + await TaskUtils.coro_yield() + nbytes_written_in_external_buffer = None + else: + assert not self.__read_paused, "transport reading is paused" # nosec assert_used + + if self.__transport is None: + # happening if transport.pause_reading() raises NotImplementedError + raise _utils.error_from_errno(_errno.ECONNABORTED) + + self.__external_buffer_view = external_buffer + try: + nbytes_written_in_external_buffer = await self.__read_waiter + finally: + self.__external_buffer_view = None finally: self.__read_waiter = None - if self.__connection_lost_exception is not None: - raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) - return True - def _wakeup_read_waiter(self, exc: Exception | None) -> None: + if nbytes_written_in_external_buffer is None: + self._check_for_connection_lost() + return nbytes_written_in_external_buffer + + def _read_waiter_fut(self, set_result_cb: Callable[[asyncio.Future[int | None]], None]) -> None: if (waiter := self.__read_waiter) is not None: if not waiter.done(): - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) + set_result_cb(waiter) + + def _wakeup_read_waiter(self, exc: Exception | None) -> None: + self._read_waiter_fut(lambda waiter: waiter.set_result(None) if exc is None else waiter.set_exception(exc)) def _get_read_buffer_size(self) -> int: return self.__buffer_nbytes_written @@ -365,6 +391,10 @@ def _maybe_resume_transport(self) -> None: transport.resume_reading() self.__read_paused = False + def _check_for_connection_lost(self) -> None: + if self.__connection_lost_exception is not None: + raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) + def pause_writing(self) -> None: self.__write_flow.pause_writing() diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py index 5befbe4d..efa26c60 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py @@ -1118,7 +1118,8 @@ async def test____receive_data____default( self.write_in_protocol_buffer(protocol, b"abcdef") # Act - data = await data_receiver(protocol, 1024) + async with asyncio.timeout(5): + data = await data_receiver(protocol, 1024) # Assert assert data == b"abcdef" @@ -1135,8 +1136,9 @@ async def test____receive_data____partial_read( self.write_in_protocol_buffer(protocol, b"abcdef") # Act - first = await data_receiver(protocol, 3) - second = await data_receiver(protocol, 3) + async with asyncio.timeout(5): + first = await data_receiver(protocol, 3) + second = await data_receiver(protocol, 3) # Assert assert first == b"abc" @@ -1144,22 +1146,64 @@ async def test____receive_data____partial_read( mock_asyncio_transport.resume_reading.assert_not_called() @pytest.mark.asyncio - async def test____receive_data____buffer_updated_several_times( + @pytest.mark.parametrize("blocking", [False, True], ids=lambda p: f"blocking=={p}") + @pytest.mark.parametrize("data_receiver", ["data"], indirect=True) + async def test____receive_data____owned_data____buffer_updated_several_times( self, + blocking: bool, protocol: StreamReaderBufferedProtocol, mock_asyncio_transport: MagicMock, data_receiver: _ProtocolDataReceiver, ) -> None: # Arrange event_loop = asyncio.get_running_loop() - event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc") - event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"def") + if blocking: + event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc") + event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"def") + else: + self.write_in_protocol_buffer(protocol, b"abc") + self.write_in_protocol_buffer(protocol, b"def") # Act - data = await data_receiver(protocol, 1024) + async with asyncio.timeout(5): + data = await data_receiver(protocol, 1024) # Assert assert data == b"abcdef" + assert protocol._get_read_buffer_size() == 0 + mock_asyncio_transport.resume_reading.assert_not_called() + + @pytest.mark.asyncio + @pytest.mark.parametrize("blocking", [False, True], ids=lambda p: f"blocking=={p}") + @pytest.mark.parametrize("data_receiver", ["buffer"], indirect=True) + async def test____receive_data____into_buffer____buffer_updated_several_times( + self, + blocking: bool, + protocol: StreamReaderBufferedProtocol, + mock_asyncio_transport: MagicMock, + data_receiver: _ProtocolDataReceiver, + ) -> None: + # Arrange + event_loop = asyncio.get_running_loop() + if blocking: + event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc") + event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"def") + else: + self.write_in_protocol_buffer(protocol, b"abc") + self.write_in_protocol_buffer(protocol, b"def") + + # Act + async with asyncio.timeout(5): + data = await data_receiver(protocol, 1024) + + # Assert + if blocking: + assert data == b"abc" + assert protocol._get_read_buffer_size() == 3 # should be b"def" + assert (await data_receiver(protocol, 1024)) == b"def" + else: + assert data == b"abcdef" + assert protocol._get_read_buffer_size() == 0 mock_asyncio_transport.resume_reading.assert_not_called() @pytest.mark.asyncio @@ -1172,7 +1216,8 @@ async def test____receive_data____null_bufsize( # Arrange # Act - data = await data_receiver(protocol, 0) + async with asyncio.timeout(5): + data = await data_receiver(protocol, 0) # Assert assert data == b"" @@ -1208,7 +1253,8 @@ def protocol_eof_handler() -> None: protocol_eof_handler() # Act - data = await data_receiver(protocol, 1024) + async with asyncio.timeout(5): + data = await data_receiver(protocol, 1024) # Assert assert data == b"" @@ -1243,14 +1289,14 @@ async def test____receive_data____connection_reset( data_receiver: _ProtocolDataReceiver, ) -> None: # Arrange - event_loop = asyncio.get_running_loop() - event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc") - event_loop.call_soon(protocol.connection_lost, None) + self.write_in_protocol_buffer(protocol, b"abc") + protocol.connection_lost(None) # Act & Assert for _ in range(3): with pytest.raises(ConnectionResetError): - _ = await data_receiver(protocol, 1024) + async with asyncio.timeout(5): + _ = await data_receiver(protocol, 1024) @pytest.mark.asyncio async def test____receive_data____invalid_bufsize(