Skip to content

Commit

Permalink
AsyncIOBackend: Improved recv_into() implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia committed Nov 23, 2024
1 parent e69a892 commit 2f685e5
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 58 deletions.
120 changes: 75 additions & 45 deletions src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import warnings
from collections.abc import Callable, Iterable, Mapping
from types import MappingProxyType, TracebackType
from typing import TYPE_CHECKING, Any, final
from typing import TYPE_CHECKING, Any, final, overload

from ......exceptions import UnsupportedOperation
from ..... import _utils, socket as socket_tools
Expand Down Expand Up @@ -143,6 +143,7 @@ class StreamReaderBufferedProtocol(asyncio.BufferedProtocol):
"__buffer",
"__buffer_view",
"__buffer_nbytes_written",
"__external_buffer_view",
"__transport",
"__closed",
"__write_flow",
Expand Down Expand Up @@ -170,10 +171,11 @@ def __init__(
self.__loop: asyncio.AbstractEventLoop = loop
self.__buffer: bytearray | None = bytearray(self.max_size)
self.__buffer_view: memoryview = memoryview(self.__buffer)
self.__external_buffer_view: WriteableBuffer | None = None
self.__buffer_nbytes_written: int = 0
self.__transport: asyncio.Transport | None = None
self.__closed: asyncio.Future[None] = loop.create_future()
self.__read_waiter: asyncio.Future[None] | None = None
self.__read_waiter: asyncio.Future[int | None] | None = None
self.__write_flow: WriteFlowControl
self.__read_paused: bool = False
self.__connection_lost: bool = False
Expand Down Expand Up @@ -219,6 +221,8 @@ def connection_lost(self, exc: Exception | None) -> None:
self.__transport = None

def get_buffer(self, sizehint: int) -> WriteableBuffer:
if (external_buffer_view := self.__external_buffer_view) is not None:
return external_buffer_view
# Ignore sizehint, the buffer is already at its maximum size.
# Returns unused buffer part
if self.__buffer is None:
Expand All @@ -229,12 +233,21 @@ def buffer_updated(self, nbytes: int) -> None:
assert not self.__connection_lost, "buffer_updated() after connection_lost()" # nosec assert_used
assert not self.__eof_reached, "buffer_updated() after eof_received()" # nosec assert_used

if self.__external_buffer_view is not None:
# Early remove to prevent using this buffer between this point and the wakeup of the task.
self.__external_buffer_view = None
self._read_waiter_fut(lambda waiter: waiter.set_result(nbytes))
# Call to _maybe_pause_transport() is unnecessary: Did not write in internal buffer.
return

self.__buffer_nbytes_written += nbytes
assert 0 <= self.__buffer_nbytes_written <= self.__buffer_view.nbytes # nosec assert_used
self._wakeup_read_waiter(None)
self._maybe_pause_transport()

def eof_received(self) -> bool:
# Early remove to prevent using this buffer between this point and the wakeup of the task.
self.__external_buffer_view = None
self.__eof_reached = True
self._wakeup_read_waiter(None)
if self.__over_ssl:
Expand All @@ -245,14 +258,13 @@ def eof_received(self) -> bool:
return True

async def receive_data(self, bufsize: int, /) -> bytes:
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
self._check_for_connection_lost()
if bufsize == 0:
return b""
if bufsize < 0:
raise ValueError("'bufsize' must be a positive or null integer")

blocked: bool = await self._wait_for_data("receive_data")
await self._wait_for_data("receive_data", None)

nbytes_written = self.__buffer_nbytes_written
if nbytes_written:
Expand All @@ -266,66 +278,80 @@ async def receive_data(self, bufsize: int, /) -> bytes:
else:
data = b""
self._maybe_resume_transport()
if not blocked:
await TaskUtils.cancel_shielded_coro_yield()
return data

async def receive_data_into(self, buffer: WriteableBuffer, /) -> int:
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
with memoryview(buffer) as buffer, buffer.cast("B") as buffer:
self._check_for_connection_lost()

with memoryview(buffer) as buffer:
if not buffer:
return 0

blocked: bool = await self._wait_for_data("receive_data_into")

nbytes_written = self.__buffer_nbytes_written
if nbytes_written:
protocol_buffer_written = self.__buffer_view[:nbytes_written]
bufsize_offset = nbytes_written - buffer.nbytes
if bufsize_offset > 0:
nbytes_written = buffer.nbytes
buffer[:] = protocol_buffer_written[:nbytes_written]
protocol_buffer_written[:bufsize_offset] = protocol_buffer_written[-bufsize_offset:]
self.__buffer_nbytes_written = bufsize_offset
else:
buffer[:nbytes_written] = protocol_buffer_written
self.__buffer_nbytes_written = 0
with buffer.cast("B") if buffer.itemsize != 1 else buffer as buffer:
nbytes_written = await self._wait_for_data("receive_data_into", buffer)
if nbytes_written is not None:
# Call to _maybe_resume_transport() is unnecessary: Did not write in internal buffer.
return nbytes_written

nbytes_written = self.__buffer_nbytes_written
if nbytes_written:
protocol_buffer_written = self.__buffer_view[:nbytes_written]
bufsize_offset = nbytes_written - buffer.nbytes
if bufsize_offset > 0:
nbytes_written = buffer.nbytes
buffer[:] = protocol_buffer_written[:nbytes_written]
protocol_buffer_written[:bufsize_offset] = protocol_buffer_written[-bufsize_offset:]
self.__buffer_nbytes_written = bufsize_offset
else:
buffer[:nbytes_written] = protocol_buffer_written
self.__buffer_nbytes_written = 0

self._maybe_resume_transport()
if not blocked:
await TaskUtils.cancel_shielded_coro_yield()
return nbytes_written

async def _wait_for_data(self, requester: str) -> bool:
if self.__read_waiter is not None:
raise RuntimeError(f"{requester}() called while another coroutine is already waiting for incoming data")

if self.__buffer_nbytes_written or self.__eof_reached:
return False
@overload
async def _wait_for_data(self, requester: str, external_buffer: None) -> None: ...

assert not self.__read_paused, "transport reading is paused" # nosec assert_used
@overload
async def _wait_for_data(self, requester: str, external_buffer: WriteableBuffer) -> int | None: ...

if self.__transport is None:
# happening if transport.pause_reading() raises NotImplementedError
raise _utils.error_from_errno(_errno.ECONNABORTED)
async def _wait_for_data(self, requester: str, external_buffer: WriteableBuffer | None) -> int | None:
if self.__read_waiter is not None:
raise RuntimeError(f"{requester}() called while another coroutine is already waiting for incoming data")

self.__read_waiter = self.__loop.create_future()
try:
await self.__read_waiter
nbytes_written_in_external_buffer: int | None
if self.__buffer_nbytes_written or self.__eof_reached:
self.__read_waiter.set_result(None)
await TaskUtils.coro_yield()
nbytes_written_in_external_buffer = None
else:
assert not self.__read_paused, "transport reading is paused" # nosec assert_used

if self.__transport is None:
# happening if transport.pause_reading() raises NotImplementedError
raise _utils.error_from_errno(_errno.ECONNABORTED)

self.__external_buffer_view = external_buffer
try:
nbytes_written_in_external_buffer = await self.__read_waiter
finally:
self.__external_buffer_view = None
finally:
self.__read_waiter = None
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
return True

def _wakeup_read_waiter(self, exc: Exception | None) -> None:
if nbytes_written_in_external_buffer is None:
self._check_for_connection_lost()
return nbytes_written_in_external_buffer

def _read_waiter_fut(self, set_result_cb: Callable[[asyncio.Future[int | None]], None]) -> None:
if (waiter := self.__read_waiter) is not None:
if not waiter.done():
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
set_result_cb(waiter)

def _wakeup_read_waiter(self, exc: Exception | None) -> None:
self._read_waiter_fut(lambda waiter: waiter.set_result(None) if exc is None else waiter.set_exception(exc))

def _get_read_buffer_size(self) -> int:
return self.__buffer_nbytes_written
Expand Down Expand Up @@ -365,6 +391,10 @@ def _maybe_resume_transport(self) -> None:
transport.resume_reading()
self.__read_paused = False

def _check_for_connection_lost(self) -> None:
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)

def pause_writing(self) -> None:
self.__write_flow.pause_writing()

Expand Down
72 changes: 59 additions & 13 deletions tests/unit_test/test_async/test_asyncio_backend/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,8 @@ async def test____receive_data____default(
self.write_in_protocol_buffer(protocol, b"abcdef")

# Act
data = await data_receiver(protocol, 1024)
async with asyncio.timeout(5):
data = await data_receiver(protocol, 1024)

# Assert
assert data == b"abcdef"
Expand All @@ -1135,31 +1136,74 @@ async def test____receive_data____partial_read(
self.write_in_protocol_buffer(protocol, b"abcdef")

# Act
first = await data_receiver(protocol, 3)
second = await data_receiver(protocol, 3)
async with asyncio.timeout(5):
first = await data_receiver(protocol, 3)
second = await data_receiver(protocol, 3)

# Assert
assert first == b"abc"
assert second == b"def"
mock_asyncio_transport.resume_reading.assert_not_called()

@pytest.mark.asyncio
async def test____receive_data____buffer_updated_several_times(
@pytest.mark.parametrize("blocking", [False, True], ids=lambda p: f"blocking=={p}")
@pytest.mark.parametrize("data_receiver", ["data"], indirect=True)
async def test____receive_data____owned_data____buffer_updated_several_times(
self,
blocking: bool,
protocol: StreamReaderBufferedProtocol,
mock_asyncio_transport: MagicMock,
data_receiver: _ProtocolDataReceiver,
) -> None:
# Arrange
event_loop = asyncio.get_running_loop()
event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc")
event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"def")
if blocking:
event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc")
event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"def")
else:
self.write_in_protocol_buffer(protocol, b"abc")
self.write_in_protocol_buffer(protocol, b"def")

# Act
data = await data_receiver(protocol, 1024)
async with asyncio.timeout(5):
data = await data_receiver(protocol, 1024)

# Assert
assert data == b"abcdef"
assert protocol._get_read_buffer_size() == 0
mock_asyncio_transport.resume_reading.assert_not_called()

@pytest.mark.asyncio
@pytest.mark.parametrize("blocking", [False, True], ids=lambda p: f"blocking=={p}")
@pytest.mark.parametrize("data_receiver", ["buffer"], indirect=True)
async def test____receive_data____into_buffer____buffer_updated_several_times(
self,
blocking: bool,
protocol: StreamReaderBufferedProtocol,
mock_asyncio_transport: MagicMock,
data_receiver: _ProtocolDataReceiver,
) -> None:
# Arrange
event_loop = asyncio.get_running_loop()
if blocking:
event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc")
event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"def")
else:
self.write_in_protocol_buffer(protocol, b"abc")
self.write_in_protocol_buffer(protocol, b"def")

# Act
async with asyncio.timeout(5):
data = await data_receiver(protocol, 1024)

# Assert
if blocking:
assert data == b"abc"
assert protocol._get_read_buffer_size() == 3 # should be b"def"
assert (await data_receiver(protocol, 1024)) == b"def"
else:
assert data == b"abcdef"
assert protocol._get_read_buffer_size() == 0
mock_asyncio_transport.resume_reading.assert_not_called()

@pytest.mark.asyncio
Expand All @@ -1172,7 +1216,8 @@ async def test____receive_data____null_bufsize(
# Arrange

# Act
data = await data_receiver(protocol, 0)
async with asyncio.timeout(5):
data = await data_receiver(protocol, 0)

# Assert
assert data == b""
Expand Down Expand Up @@ -1208,7 +1253,8 @@ def protocol_eof_handler() -> None:
protocol_eof_handler()

# Act
data = await data_receiver(protocol, 1024)
async with asyncio.timeout(5):
data = await data_receiver(protocol, 1024)

# Assert
assert data == b""
Expand Down Expand Up @@ -1243,14 +1289,14 @@ async def test____receive_data____connection_reset(
data_receiver: _ProtocolDataReceiver,
) -> None:
# Arrange
event_loop = asyncio.get_running_loop()
event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc")
event_loop.call_soon(protocol.connection_lost, None)
self.write_in_protocol_buffer(protocol, b"abc")
protocol.connection_lost(None)

# Act & Assert
for _ in range(3):
with pytest.raises(ConnectionResetError):
_ = await data_receiver(protocol, 1024)
async with asyncio.timeout(5):
_ = await data_receiver(protocol, 1024)

@pytest.mark.asyncio
async def test____receive_data____invalid_bufsize(
Expand Down

0 comments on commit 2f685e5

Please sign in to comment.