Skip to content

Commit

Permalink
AsyncIOBackend: Implemented happy eyeballs delay internally (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Dec 17, 2023
1 parent 065d35b commit ae4d4ed
Show file tree
Hide file tree
Showing 7 changed files with 445 additions and 222 deletions.
4 changes: 4 additions & 0 deletions src/easynetwork/lowlevel/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@
# The default timeout mimics lingering_time
SSL_SHUTDOWN_TIMEOUT: Final[float] = 30.0

# "Connection Attempt Delay" for concurrent connections
# Recommended value by the RFC 6555
HAPPY_EYEBALLS_DELAY = 0.25

# Buffer size limit when waiting for a byte sequence
_DEFAULT_LIMIT: Final[int] = 64 * 1024 # 64 KiB

Expand Down
89 changes: 87 additions & 2 deletions src/easynetwork/lowlevel/std_asyncio/_asyncio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
import asyncio
import contextlib
import itertools
import math
import socket as _socket
from collections import OrderedDict
from collections.abc import Iterable, Sequence
from typing import Any
from typing import Any, cast

from .. import _utils

Expand Down Expand Up @@ -144,11 +146,93 @@ async def _create_connection_impl(
errors.clear()


# Taken from asyncio library
def _interleave_addrinfos(
addrinfos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]]
) -> list[tuple[int, int, int, str, tuple[Any, ...]]]:
"""Interleave list of addrinfo tuples by family."""
# Group addresses by family
addrinfos_by_family: OrderedDict[int, list[tuple[Any, ...]]] = OrderedDict()
for addr in addrinfos:
family = addr[0]
if family not in addrinfos_by_family:
addrinfos_by_family[family] = []
addrinfos_by_family[family].append(addr)
addrinfos_lists = list(addrinfos_by_family.values())
return [addr for addr in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists)) if addr is not None]


# Taken from anyio project
def _prioritize_ipv6_over_ipv4(
addrinfos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]]
) -> list[tuple[int, int, int, str, tuple[Any, ...]]]:
# Organize the list so that the first address is an IPv6 address (if available)
# and the second one is an IPv4 addresses. The rest can be in whatever order.
v6_found = v4_found = False
reordered: list[tuple[int, int, int, str, tuple[Any, ...]]] = []
for addr in addrinfos:
family = addr[0]
if family == _socket.AF_INET6 and not v6_found:
v6_found = True
reordered.insert(0, addr)
elif family == _socket.AF_INET and not v4_found and v6_found:
v4_found = True
reordered.insert(1, addr)
else:
reordered.append(addr)
return reordered


async def _staggered_race_connection_impl(
*,
loop: asyncio.AbstractEventLoop,
remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]],
local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None,
happy_eyeballs_delay: float,
) -> _socket.socket:
from .tasks import CancelScope

remote_addrinfo = _interleave_addrinfos(_prioritize_ipv6_over_ipv4(remote_addrinfo))
winner: _socket.socket | None = cast(_socket.socket | None, None)
errors: list[OSError | BaseExceptionGroup[OSError]] = []

async def try_connect(addr: tuple[int, int, int, str, tuple[Any, ...]]) -> None:
nonlocal winner
try:
socket = await _create_connection_impl(loop=loop, remote_addrinfo=[addr], local_addrinfo=local_addrinfo)
except* OSError as excgrp:
errors.extend(excgrp.exceptions)
else:
if winner is None:
winner = socket
connection_scope.cancel()
else:
socket.close()

try:
with CancelScope() as connection_scope:
async with asyncio.TaskGroup() as task_group:
for addr in remote_addrinfo:
await asyncio.wait({task_group.create_task(try_connect(addr))}, timeout=happy_eyeballs_delay)

if winner is None:
raise BaseExceptionGroup("create_connection() failed", errors)
return winner
except BaseException:
if winner is not None:
winner.close()
raise
finally:
errors.clear()


async def create_connection(
host: str,
port: int,
loop: asyncio.AbstractEventLoop,
local_address: tuple[str, int] | None = None,
*,
happy_eyeballs_delay: float = math.inf,
) -> _socket.socket:
remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await ensure_resolved(
host,
Expand All @@ -168,10 +252,11 @@ async def create_connection(
loop=loop,
)

return await _create_connection_impl(
return await _staggered_race_connection_impl(
loop=loop,
remote_addrinfo=remote_addrinfo,
local_addrinfo=local_addrinfo,
happy_eyeballs_delay=happy_eyeballs_delay,
)


Expand Down
84 changes: 32 additions & 52 deletions src/easynetwork/lowlevel/std_asyncio/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ...exceptions import UnsupportedOperation
from ..api_async.backend import _sniffio_helpers
from ..api_async.backend.abc import AsyncBackend as AbstractAsyncBackend
from ..constants import HAPPY_EYEBALLS_DELAY as _DEFAULT_HAPPY_EYEBALLS_DELAY
from ._asyncio_utils import (
create_connection,
create_datagram_connection,
Expand Down Expand Up @@ -122,30 +123,19 @@ async def create_tcp_connection(
local_address: tuple[str, int] | None = None,
happy_eyeballs_delay: float | None = None,
) -> AsyncioTransportStreamSocketAdapter | RawStreamSocketAdapter:
if happy_eyeballs_delay is not None:
self._check_asyncio_transport("'happy_eyeballs_delay' option")

if not self.__use_asyncio_transport:
loop = asyncio.get_running_loop()
socket = await create_connection(host, port, loop, local_address=local_address)
return RawStreamSocketAdapter(socket, loop)
if happy_eyeballs_delay is None:
happy_eyeballs_delay = _DEFAULT_HAPPY_EYEBALLS_DELAY

happy_eyeballs_delay = self._default_happy_eyeballs_delay(happy_eyeballs_delay)
loop = asyncio.get_running_loop()
socket = await create_connection(
host,
port,
loop,
local_address=local_address,
happy_eyeballs_delay=happy_eyeballs_delay,
)

if happy_eyeballs_delay is None:
reader, writer = await asyncio.open_connection(
host,
port,
local_addr=local_address,
)
else:
reader, writer = await asyncio.open_connection(
host,
port,
local_addr=local_address,
happy_eyeballs_delay=happy_eyeballs_delay,
)
return AsyncioTransportStreamSocketAdapter(reader, writer)
return await self.wrap_stream_socket(socket)

async def create_ssl_over_tcp_connection(
self,
Expand All @@ -162,38 +152,28 @@ async def create_ssl_over_tcp_connection(
self._check_ssl_support()
self.__verify_ssl_context(ssl_context)

happy_eyeballs_delay = self._default_happy_eyeballs_delay(happy_eyeballs_delay)

if happy_eyeballs_delay is None:
reader, writer = await asyncio.open_connection(
host,
port,
ssl=ssl_context,
server_hostname=server_hostname,
ssl_handshake_timeout=float(ssl_handshake_timeout),
ssl_shutdown_timeout=float(ssl_shutdown_timeout),
local_addr=local_address,
)
else:
reader, writer = await asyncio.open_connection(
host,
port,
ssl=ssl_context,
server_hostname=server_hostname,
ssl_handshake_timeout=float(ssl_handshake_timeout),
ssl_shutdown_timeout=float(ssl_shutdown_timeout),
local_addr=local_address,
happy_eyeballs_delay=happy_eyeballs_delay,
)
return AsyncioTransportStreamSocketAdapter(reader, writer)
happy_eyeballs_delay = _DEFAULT_HAPPY_EYEBALLS_DELAY

@staticmethod
def _default_happy_eyeballs_delay(happy_eyeballs_delay: float | None) -> float | None:
if happy_eyeballs_delay is None:
running_loop = asyncio.get_running_loop()
if isinstance(running_loop, asyncio.base_events.BaseEventLoop): # Base class of standard implementation
happy_eyeballs_delay = 0.25 # Recommended value by the RFC 6555
return happy_eyeballs_delay
if server_hostname is None:
server_hostname = host

loop = asyncio.get_running_loop()
socket = await create_connection(
host,
port,
loop,
local_address=local_address,
happy_eyeballs_delay=happy_eyeballs_delay,
)

return await self.wrap_ssl_over_stream_socket_client_side(
socket,
ssl_context=ssl_context,
server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
)

async def wrap_stream_socket(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ async def on_disconnection(self, client: AsyncStreamClient[str]) -> None:

async def handle(self, client: AsyncStreamClient[str]) -> AsyncGenerator[None, str]:
if self.request_count[client] >= self.refuse_after:
await asyncio.sleep(0.2)
return
request = yield
self.request_count[client] += 1
Expand Down Expand Up @@ -274,6 +275,7 @@ async def on_connection(self, client: AsyncStreamClient[str]) -> None:
await client.send_packet("milk")

async def handle(self, client: AsyncStreamClient[str]) -> AsyncGenerator[None, str]:
await asyncio.sleep(0.2)
raise RandomError("An error occurred")
request = yield # type: ignore[unreachable]
await client.send_packet(request)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_test/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from types import TracebackType
from typing import Any

_DEFAULT_FAMILIES: Sequence[int] = (AF_INET, AF_INET6)
_DEFAULT_FAMILIES: Sequence[int] = (AF_INET6, AF_INET)


class _LockMixin:
Expand Down
Loading

0 comments on commit ae4d4ed

Please sign in to comment.