Skip to content

Commit

Permalink
[FIX] Implement zero copy writes for TCP socket (sync and async) tran…
Browse files Browse the repository at this point in the history
…sports
  • Loading branch information
francis-clairicia committed Nov 18, 2023
1 parent 8f50e10 commit 1d58bb5
Show file tree
Hide file tree
Showing 21 changed files with 691 additions and 68 deletions.
7 changes: 2 additions & 5 deletions docs/source/howto/serializers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,8 @@ Most of the time, you will have a single :keyword:`yield`. The goal is: each :ke

.. note::

The endpoint implementation can (and most likely will) decide to concatenate all the pieces and do one big send.
This is the optimized way to send a large byte buffer.

However, it may be more attractive to do something else with the returned bytes.
:meth:`~.AbstractIncrementalPacketSerializer.incremental_serialize` is here to give endpoints this freedom.
The endpoint implementation can decide to concatenate all the pieces and do one big send. However, it may be more attractive to do something else
with the returned bytes. :meth:`~.AbstractIncrementalPacketSerializer.incremental_serialize` is here to give endpoints this freedom.


The Purpose Of ``incremental_deserialize()``
Expand Down
33 changes: 32 additions & 1 deletion src/easynetwork/lowlevel/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

__all__ = [
"ElapsedTime",
"adjust_leftover_buffer",
"check_real_socket_state",
"check_socket_family",
"check_socket_no_ssl",
Expand All @@ -30,6 +31,7 @@
"remove_traceback_frames_in_place",
"replace_kwargs",
"set_reuseport",
"supports_socket_sendmsg",
"validate_timeout_delay",
]

Expand All @@ -41,8 +43,10 @@
import socket as _socket
import threading
import time
from abc import abstractmethod
from collections import deque
from collections.abc import Callable, Iterable, Iterator
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeGuard, TypeVar
from typing import TYPE_CHECKING, Any, Concatenate, Final, ParamSpec, Protocol, Self, TypeGuard, TypeVar

try:
import ssl as _ssl
Expand All @@ -58,6 +62,8 @@
if TYPE_CHECKING:
from ssl import SSLError as _SSLError, SSLSocket as _SSLSocket

from _typeshed import ReadableBuffer

from .socket import ISocket, SupportsSocketOptions

_P = ParamSpec("_P")
Expand Down Expand Up @@ -130,6 +136,20 @@ def check_real_socket_state(socket: ISocket) -> None:
raise error_from_errno(errno)


_HAS_SENDMSG: Final[bool] = hasattr(_socket.socket, "sendmsg")


class _SupportsSocketSendMSG(Protocol):
@abstractmethod
def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> int:
...


def supports_socket_sendmsg(sock: _socket.socket) -> TypeGuard[_SupportsSocketSendMSG]:
assert isinstance(sock, _socket.SocketType) # nosec assert_used
return _HAS_SENDMSG


def is_ssl_socket(socket: _socket.socket) -> TypeGuard[_SSLSocket]:
if ssl is None:
return False
Expand Down Expand Up @@ -170,6 +190,17 @@ def iter_bytes(b: bytes | bytearray | memoryview) -> Iterator[bytes]:
return map(int.to_bytes, b)


def adjust_leftover_buffer(buffers: deque[memoryview], nbytes: int) -> None:
    """
    Remove from `buffers` (in place) the first `nbytes` bytes, which have already been sent.

    Buffers that are entirely consumed are dropped; a partially consumed buffer is replaced by
    a view on its remaining bytes. Zero-length buffers left at the front of the queue are
    dropped as well: they carry nothing to send, and keeping them would make a
    ``while buffers:`` send loop spin forever (sending only empty buffers reports 0 bytes
    sent, which would never remove them).

    Parameters:
        buffers: The queue of pending buffers. Lengths are taken with :func:`len`, so
                 the views are expected to be byte-oriented (format ``"B"``).
        nbytes: The number of bytes consumed from the front of the queue.

    Raises:
        IndexError: `nbytes` is greater than the total size of the queued buffers.
    """
    while nbytes > 0:
        head = buffers.popleft()
        head_len = len(head)
        if head_len > nbytes:
            # Partially sent: keep the unsent tail (non-empty by construction) at the front.
            buffers.appendleft(head[nbytes:])
            return
        nbytes -= head_len

    # Everything counted by 'nbytes' has been removed; skip empty buffers now at the front.
    while buffers and len(buffers[0]) == 0:
        buffers.popleft()


def is_socket_connected(sock: ISocket) -> bool:
try:
sock.getpeername()
Expand Down
14 changes: 5 additions & 9 deletions src/easynetwork/lowlevel/api_async/transports/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,15 @@ async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytear
"""
An efficient way to send a bunch of data via the transport.
Currently, the default implementation concatenates the arguments and
calls :meth:`send_all` on the result.
Like :meth:`send_all`, this method continues to send data from bytes until either all data has been sent or an error
occurs. :data:`None` is returned on success. On error, an exception is raised, and there is no way to determine how much
data, if any, was successfully sent.
Parameters:
iterable_of_data: An :term:`iterable` yielding the bytes to send.
"""
iterable_of_data = list(iterable_of_data)
if len(iterable_of_data) == 1:
data = iterable_of_data[0]
else:
data = b"".join(iterable_of_data)
del iterable_of_data
return await self.send_all(data)
for data in iterable_of_data:
await self.send_all(data)


class AsyncStreamTransport(AsyncStreamWriteTransport, AsyncStreamReadTransport):
Expand Down
18 changes: 8 additions & 10 deletions src/easynetwork/lowlevel/api_sync/transports/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def send_all(self, data: bytes | bytearray | memoryview, timeout: float) -> None
"""

total_sent: int = 0
with memoryview(data) as data:
with memoryview(data) as data, data.cast("B") as data:
nb_bytes_to_send = len(data)
if nb_bytes_to_send == 0:
sent = self.send(data, timeout)
Expand All @@ -148,8 +148,9 @@ def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray |
"""
An efficient way to send a bunch of data via the transport.
Currently, the default implementation concatenates the arguments and
calls :meth:`send_all` on the result.
Like :meth:`send_all`, this method continues to send data from bytes until either all data has been sent or an error
occurs. :data:`None` is returned on success. On error, an exception is raised, and there is no way to determine how much
data, if any, was successfully sent.
Parameters:
iterable_of_data: An :term:`iterable` yielding the bytes to send.
Expand All @@ -159,13 +160,10 @@ def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray |
ValueError: Negative `timeout`.
TimeoutError: Operation timed out.
"""
iterable_of_data = list(iterable_of_data)
if len(iterable_of_data) == 1:
data = iterable_of_data[0]
else:
data = b"".join(iterable_of_data)
del iterable_of_data
return self.send_all(data, timeout)
for data in iterable_of_data:
with _utils.ElapsedTime() as elapsed:
self.send_all(data, timeout)
timeout = elapsed.recompute_timeout(timeout)


class StreamTransport(StreamWriteTransport, StreamReadTransport):
Expand Down
29 changes: 27 additions & 2 deletions src/easynetwork/lowlevel/api_sync/transports/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
"SocketStreamTransport",
]

import itertools
import selectors
import socket
from collections import ChainMap
from collections.abc import Callable, Mapping
from collections import ChainMap, deque
from collections.abc import Callable, Iterable, Mapping
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar

try:
Expand Down Expand Up @@ -96,6 +97,30 @@ def send_noblock(self, data: bytes | bytearray | memoryview) -> int:
except (BlockingIOError, InterruptedError):
raise base_selector.WouldBlockOnWrite(self.__socket.fileno()) from None

@_utils.inherit_doc(base_selector.SelectorStreamTransport)
def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview], timeout: float) -> None:
    # Sends the buffers with socket.sendmsg() instead of concatenating them into one bytes object.
    _sock = self.__socket
    # Fall back to the parent implementation when SC_IOV_MAX is not a positive value
    # or when the socket class has no sendmsg().
    if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock):
        return super().send_all_from_iterable(iterable_of_data, timeout)

    # Byte-oriented views, so that len() and slicing are expressed in bytes.
    buffers: deque[memoryview] = deque(memoryview(data).cast("B") for data in iterable_of_data)
    del iterable_of_data

    sock_sendmsg = _sock.sendmsg
    del _sock

    def try_sendmsg() -> int:
        # One non-blocking attempt, passing at most SC_IOV_MAX buffers to sendmsg().
        try:
            return sock_sendmsg(itertools.islice(buffers, constants.SC_IOV_MAX))
        except (BlockingIOError, InterruptedError):
            # Same conversion as in send_noblock().
            # NOTE(review): presumably _retry() waits for writability and calls again - confirm.
            raise base_selector.WouldBlockOnWrite(self.__socket.fileno()) from None

    while buffers:
        with _utils.ElapsedTime() as elapsed:
            sent: int = self._retry(try_sendmsg, timeout)
        # Drop what has been sent, then charge the time spent in this attempt to the timeout.
        _utils.adjust_leftover_buffer(buffers, sent)
        timeout = elapsed.recompute_timeout(timeout)

@_utils.inherit_doc(base_selector.SelectorStreamTransport)
def send_eof(self) -> None:
if self.__socket.fileno() < 0:
Expand Down
35 changes: 34 additions & 1 deletion src/easynetwork/lowlevel/asyncio/_asyncio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@

from __future__ import annotations

__all__ = ["create_connection", "open_listener_sockets_from_getaddrinfo_result"]
__all__ = [
"create_connection",
"open_listener_sockets_from_getaddrinfo_result",
"wait_until_readable",
"wait_until_writable",
]

import asyncio
import contextlib
Expand Down Expand Up @@ -216,3 +221,31 @@ def open_listener_sockets_from_getaddrinfo_result(
socket_exit_stack.pop_all()

return sockets


def wait_until_readable(sock: _socket.socket, loop: asyncio.AbstractEventLoop) -> asyncio.Future[None]:
    """
    Create a future which is completed when `loop` reports `sock` as readable.

    The reader callback registered on the loop is removed as soon as the future is done,
    whether a result has been set or the future has been cancelled.

    Parameters:
        sock: The socket to watch.
        loop: The event loop used to create the future and to watch the socket.

    Returns:
        a future whose result is :data:`None`.
    """
    future: asyncio.Future[None] = loop.create_future()

    def _notify(waiter: asyncio.Future[None]) -> None:
        # The waiter may already be done (e.g. cancelled) when the loop calls back.
        if waiter.done():
            return
        waiter.set_result(None)

    def _unregister(_: asyncio.Future[None]) -> None:
        loop.remove_reader(sock)

    loop.add_reader(sock, _notify, future)
    future.add_done_callback(_unregister)
    return future


def wait_until_writable(sock: _socket.socket, loop: asyncio.AbstractEventLoop) -> asyncio.Future[None]:
    """
    Create a future which is completed when `loop` reports `sock` as writable.

    The writer callback registered on the loop is removed as soon as the future is done,
    whether a result has been set or the future has been cancelled.

    Parameters:
        sock: The socket to watch.
        loop: The event loop used to create the future and to watch the socket.

    Returns:
        a future whose result is :data:`None`.
    """
    future: asyncio.Future[None] = loop.create_future()

    def _notify(waiter: asyncio.Future[None]) -> None:
        # The waiter may already be done (e.g. cancelled) when the loop calls back.
        if waiter.done():
            return
        waiter.set_result(None)

    def _unregister(_: asyncio.Future[None]) -> None:
        loop.remove_writer(sock)

    loop.add_writer(sock, _notify, future)
    future.add_done_callback(_unregister)
    return future
33 changes: 29 additions & 4 deletions src/easynetwork/lowlevel/asyncio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@
import asyncio.trsock
import contextlib
import errno as _errno
import itertools
import socket as _socket
from collections.abc import Iterator
from typing import TYPE_CHECKING, Literal, Self, TypeAlias
from collections import deque
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING, Literal, Self, TypeAlias, cast
from weakref import WeakSet

from .. import _utils
from ...exceptions import UnsupportedOperation
from .. import _utils, constants
from . import _asyncio_utils
from .tasks import CancelScope, TaskUtils

if TYPE_CHECKING:
Expand Down Expand Up @@ -120,6 +124,26 @@ async def sendall(self, data: ReadableBuffer, /) -> None:
socket = self.__check_not_closed()
await self.__loop.sock_sendall(socket, data)

async def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> None:
    """
    Send the content of all the given buffers using the socket's ``sendmsg()``.

    Unlike :meth:`socket.socket.sendmsg`, this method loops until the queue of buffers is empty
    and returns :data:`None`.

    Parameters:
        buffers: An iterable of buffers to send.

    Raises:
        UnsupportedOperation: ``SC_IOV_MAX`` is not a positive value, or the socket has no ``sendmsg()``.
    """
    with self.__conflict_detection("send", abort_errno=_errno.ECONNABORTED):
        socket = self.__check_not_closed()
        # This check happens before 'buffers' is iterated, so the iterable is still untouched
        # when UnsupportedOperation is raised.
        if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock := socket):
            raise UnsupportedOperation("sendmsg() is not supported")

        loop = self.__loop
        # Byte-oriented views, so that len() and slicing are expressed in bytes.
        buffers = cast("deque[memoryview]", deque(memoryview(data).cast("B") for data in buffers))

        sock_sendmsg = _sock.sendmsg
        del _sock

        while buffers:
            try:
                # Pass at most SC_IOV_MAX buffers per call.
                sent: int = sock_sendmsg(itertools.islice(buffers, constants.SC_IOV_MAX))
            except (BlockingIOError, InterruptedError):
                # Nothing was sent: wait for the socket to become writable, then try again.
                await _asyncio_utils.wait_until_writable(socket, loop)
            else:
                # Drop the bytes which have been sent from the front of the queue.
                _utils.adjust_leftover_buffer(buffers, sent)

async def sendto(self, data: ReadableBuffer, address: _socket._Address, /) -> None:
with self.__conflict_detection("send", abort_errno=_errno.ECONNABORTED):
socket = self.__check_not_closed()
Expand Down Expand Up @@ -152,7 +176,8 @@ def __conflict_detection(self, task_id: _SocketTaskId, *, abort_errno: int = _er
if task_id in self.__waiters:
raise _utils.error_from_errno(_errno.EBUSY)

_ = TaskUtils.current_asyncio_task(self.__loop)
# Checks if we are within the bound loop
TaskUtils.current_asyncio_task(self.__loop) # type: ignore[unused-awaitable]

with CancelScope() as scope, contextlib.ExitStack() as stack:
self.__scopes.add(scope)
Expand Down
6 changes: 6 additions & 0 deletions src/easynetwork/lowlevel/asyncio/stream/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ async def send_all(self, data: bytes | bytearray | memoryview) -> None:
async def send_eof(self) -> None:
await self.__socket.shutdown(_socket.SHUT_WR)

async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None:
    """
    Send all the buffers yielded by `iterable_of_data`.

    The socket's ``sendmsg()`` is tried first; if it raises ``UnsupportedOperation``,
    the parent implementation is used instead.

    Parameters:
        iterable_of_data: An :term:`iterable` yielding the bytes to send.
    """
    sock = self.__socket
    try:
        await sock.sendmsg(iterable_of_data)
    except UnsupportedOperation:
        # NOTE(review): relies on sendmsg() raising before it starts consuming the iterable - confirm.
        await super().send_all_from_iterable(iterable_of_data)

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
socket = self.__socket.socket
Expand Down
18 changes: 18 additions & 0 deletions src/easynetwork/lowlevel/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"DEFAULT_STREAM_BUFSIZE",
"MAX_DATAGRAM_BUFSIZE",
"NOT_CONNECTED_SOCKET_ERRNOS",
"SC_IOV_MAX",
"SSL_HANDSHAKE_TIMEOUT",
"SSL_SHUTDOWN_TIMEOUT",
"_DEFAULT_LIMIT",
Expand Down Expand Up @@ -80,3 +81,20 @@

# Buffer size limit when waiting for a byte sequence
_DEFAULT_LIMIT: Final[int] = 64 * 1024 # 64 KiB


def __get_sysconf(name: str, /) -> int:
    """
    Return the integer value of the system configuration variable `name`, or -1 if it cannot be read.

    Parameters:
        name: A configuration name accepted by :func:`os.sysconf` (e.g. ``"SC_IOV_MAX"``).

    Returns:
        the value reported by :func:`os.sysconf` (which can itself be negative), or -1 when
        :func:`os.sysconf` is not available, does not know `name`, or reports an error.
    """
    import os

    try:
        # os.sysconf() can return a negative value if 'name' is not defined
        return os.sysconf(name)  # type: ignore[attr-defined,unused-ignore]
    except (AttributeError, OSError, ValueError):
        # AttributeError: os.sysconf() does not exist on this platform.
        # ValueError: 'name' is not a configuration name known by os.sysconf().
        # OSError: 'name' is known but not supported by the host system.
        return -1


# Maximum number of buffers that sendmsg(2) can accept.
# Can be a negative value (-1 is the fallback when the lookup fails), so check it before use.
SC_IOV_MAX: Final[int] = __get_sysconf("SC_IOV_MAX")

# The helper is only needed at import time; do not leave it in the module namespace.
del __get_sysconf
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def use_ssl(request: Any) -> bool:
case "USE_SSL":
return True
case _:
raise SystemError
pytest.fail(f"Invalid parameter: {request.param}")

@pytest.fixture
@staticmethod
Expand Down
9 changes: 9 additions & 0 deletions tests/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import sys
import time
from collections.abc import Generator
Expand Down Expand Up @@ -56,3 +57,11 @@ def __exit__(self, exc_type: type[Exception] | None, exc_value: Exception | None
return
assert self.start_time >= 0
assert end_time - self.start_time == pytest.approx(self.expected_time, rel=self.approx)


def is_proactor_event_loop(event_loop: asyncio.AbstractEventLoop) -> bool:
    """
    Tell whether `event_loop` is an instance of ``asyncio.ProactorEventLoop``.

    Always returns :data:`False` when :mod:`asyncio` does not define ``ProactorEventLoop``.
    """
    proactor_cls: type[asyncio.AbstractEventLoop] | None = getattr(asyncio, "ProactorEventLoop", None)
    if proactor_cls is None:
        return False
    return isinstance(event_loop, proactor_cls)
Loading

0 comments on commit 1d58bb5

Please sign in to comment.