Skip to content

Commit

Permalink
Fixed socket binding on local addresses (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Nov 4, 2023
1 parent 50aa224 commit bc8fa58
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 50 deletions.
35 changes: 33 additions & 2 deletions src/easynetwork/lowlevel/asyncio/_asyncio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import asyncio
import contextlib
import itertools
import socket as _socket
from collections.abc import Iterable, Sequence
from typing import Any
Expand Down Expand Up @@ -50,6 +51,32 @@ async def ensure_resolved(
return info


async def resolve_local_addresses(
hosts: Sequence[str | None],
port: int,
socktype: int,
loop: asyncio.AbstractEventLoop,
) -> Sequence[tuple[int, int, int, str, tuple[Any, ...]]]:
infos: set[tuple[int, int, int, str, tuple[Any, ...]]] = set(
itertools.chain.from_iterable(
await asyncio.gather(
*[
ensure_resolved(
host,
port,
_socket.AF_UNSPEC,
socktype,
loop,
flags=_socket.AI_PASSIVE | _socket.AI_ADDRCONFIG,
)
for host in hosts
]
)
)
)
return sorted(infos)


async def create_connection(
host: str,
port: int,
Expand Down Expand Up @@ -163,8 +190,12 @@ def open_listener_sockets_from_getaddrinfo_result(
# Disable IPv4/IPv6 dual stack support (enabled by
# default on Linux) which makes a single socket
# listen on both address families.
if _socket.has_ipv6 and af == _socket.AF_INET6 and hasattr(_socket, "IPPROTO_IPV6"):
sock.setsockopt(_socket.IPPROTO_IPV6, _socket.IPV6_V6ONLY, True)
if af == _socket.AF_INET6:
if hasattr(_socket, "IPPROTO_IPV6"):
sock.setsockopt(_socket.IPPROTO_IPV6, _socket.IPV6_V6ONLY, True)
if "%" in sa[0]:
addr, scope_id = sa[0].split("%", 1)
sa = (addr, sa[1], 0, int(scope_id))
try:
sock.bind(sa)
except OSError as exc:
Expand Down
32 changes: 15 additions & 17 deletions src/easynetwork/lowlevel/asyncio/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import asyncio.base_events
import contextvars
import functools
import itertools
import math
import os
import socket as _socket
Expand All @@ -42,7 +41,7 @@

from ..api_async.backend.abc import AsyncBackend as AbstractAsyncBackend
from ..api_async.backend.sniffio import current_async_library_cvar as _sniffio_current_async_library_cvar
from ._asyncio_utils import create_connection, ensure_resolved, open_listener_sockets_from_getaddrinfo_result
from ._asyncio_utils import create_connection, open_listener_sockets_from_getaddrinfo_result, resolve_local_addresses
from .datagram.endpoint import create_datagram_endpoint
from .datagram.listener import AsyncioTransportDatagramListenerSocketAdapter, RawDatagramListenerSocketAdapter
from .datagram.socket import AsyncioTransportDatagramSocketAdapter, RawDatagramSocketAdapter
Expand Down Expand Up @@ -285,15 +284,13 @@ async def _create_tcp_socket_listeners(
else:
hosts = host

infos: set[tuple[int, int, int, str, tuple[Any, ...]]] = set(
itertools.chain.from_iterable(
await asyncio.gather(
*[
ensure_resolved(host, port, _socket.AF_UNSPEC, _socket.SOCK_STREAM, loop, flags=_socket.AI_PASSIVE)
for host in hosts
]
)
)
del host

infos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await resolve_local_addresses(
hosts,
port,
_socket.SOCK_STREAM,
loop,
)

sockets: list[_socket.socket] = open_listener_sockets_from_getaddrinfo_result(
Expand Down Expand Up @@ -351,12 +348,13 @@ async def create_udp_listeners(
else:
hosts = host

infos: set[tuple[int, int, int, str, tuple[Any, ...]]] = set(
itertools.chain.from_iterable(
await asyncio.gather(
*[ensure_resolved(host, port, _socket.AF_UNSPEC, _socket.SOCK_DGRAM, loop) for host in hosts]
)
)
del host

infos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await resolve_local_addresses(
hosts,
port,
_socket.SOCK_DGRAM,
loop,
)

sockets: list[_socket.socket] = open_listener_sockets_from_getaddrinfo_result(
Expand Down
16 changes: 5 additions & 11 deletions tests/functional_test/test_concurrency/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from easynetwork.api_sync.server.abc import AbstractNetworkServer
from easynetwork.api_sync.server.tcp import StandaloneTCPNetworkServer
from easynetwork.api_sync.server.udp import StandaloneUDPNetworkServer
from easynetwork.lowlevel.socket import IPv4SocketAddress
from easynetwork.protocol import DatagramProtocol, StreamProtocol
from easynetwork.serializers.line import StringLineSerializer

Expand Down Expand Up @@ -53,17 +54,10 @@ def _run_server(server: AbstractNetworkServer) -> None:


def _retrieve_server_address(server: AbstractNetworkServer) -> tuple[str, int]:
match server:
case StandaloneTCPNetworkServer():
addresses = server.get_addresses()
assert addresses
return "localhost", addresses[0].port
case StandaloneUDPNetworkServer():
addresses = server.get_addresses()
assert addresses
return addresses[0].for_connection()
case _:
pytest.fail("Cannot retrieve server port")
address = server.get_addresses()[0]
if isinstance(address, IPv4SocketAddress):
return "127.0.0.1", address.port
return "::1", address.port


@pytest.fixture
Expand Down
30 changes: 15 additions & 15 deletions tests/unit_test/test_async/test_asyncio_backend/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contextlib
import contextvars
from collections.abc import Callable, Coroutine, Sequence
from socket import AF_INET, AF_INET6, AF_UNSPEC, AI_PASSIVE, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM
from socket import AF_INET, AF_INET6, AF_UNSPEC, AI_ADDRCONFIG, AI_PASSIVE, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM
from typing import TYPE_CHECKING, Any, cast

from easynetwork.lowlevel.asyncio import AsyncIOBackend
Expand Down Expand Up @@ -829,10 +829,10 @@ async def test____create_tcp_listeners____open_listener_sockets(
family=AF_UNSPEC,
type=SOCK_STREAM,
proto=0,
flags=AI_PASSIVE,
flags=AI_PASSIVE | AI_ADDRCONFIG,
)
mock_open_listeners.assert_called_once_with(
set(addrinfo_list),
sorted(set(addrinfo_list)),
backlog=123456789,
reuse_address=mocker.ANY, # Determined according to OS
reuse_port=mocker.sentinel.reuse_port,
Expand Down Expand Up @@ -933,10 +933,10 @@ async def test____create_tcp_listeners____bind_to_any_interfaces(
family=AF_UNSPEC,
type=SOCK_STREAM,
proto=0,
flags=AI_PASSIVE,
flags=AI_PASSIVE | AI_ADDRCONFIG,
)
mock_open_listeners.assert_called_once_with(
set(addrinfo_list),
sorted(set(addrinfo_list)),
backlog=123456789,
reuse_address=mocker.ANY, # Determined according to OS
reuse_port=mocker.sentinel.reuse_port,
Expand Down Expand Up @@ -1039,12 +1039,12 @@ async def test____create_tcp_listeners____bind_to_several_hosts(
family=AF_UNSPEC,
type=SOCK_STREAM,
proto=0,
flags=AI_PASSIVE,
flags=AI_PASSIVE | AI_ADDRCONFIG,
)
for host in remote_hosts
]
mock_open_listeners.assert_called_once_with(
set(addrinfo_list),
sorted(set(addrinfo_list)),
backlog=123456789,
reuse_address=mocker.ANY, # Determined according to OS
reuse_port=mocker.sentinel.reuse_port,
Expand Down Expand Up @@ -1119,7 +1119,7 @@ async def test____create_tcp_listeners____error_getaddrinfo_returns_empty_list(
family=AF_UNSPEC,
type=SOCK_STREAM,
proto=0,
flags=AI_PASSIVE,
flags=AI_PASSIVE | AI_ADDRCONFIG,
)
mock_open_listeners.assert_not_called()
mock_ListenerSocketAdapter.assert_not_called()
Expand Down Expand Up @@ -1367,10 +1367,10 @@ async def test____create_udp_listeners____open_listener_sockets(
family=AF_UNSPEC,
type=SOCK_DGRAM,
proto=0,
flags=0,
flags=AI_PASSIVE | AI_ADDRCONFIG,
)
mock_open_listeners.assert_called_once_with(
set(addrinfo_list),
sorted(set(addrinfo_list)),
backlog=None,
reuse_address=False,
reuse_port=mocker.sentinel.reuse_port,
Expand Down Expand Up @@ -1456,10 +1456,10 @@ async def test____create_udp_listeners____bind_to_local_interfaces(
family=AF_UNSPEC,
type=SOCK_DGRAM,
proto=0,
flags=0,
flags=AI_PASSIVE | AI_ADDRCONFIG,
)
mock_open_listeners.assert_called_once_with(
set(addrinfo_list),
sorted(set(addrinfo_list)),
backlog=None,
reuse_address=False,
reuse_port=mocker.sentinel.reuse_port,
Expand Down Expand Up @@ -1549,12 +1549,12 @@ async def test____create_udp_listeners____bind_to_several_hosts(
family=AF_UNSPEC,
type=SOCK_DGRAM,
proto=0,
flags=0,
flags=AI_PASSIVE | AI_ADDRCONFIG,
)
for host in remote_hosts
]
mock_open_listeners.assert_called_once_with(
set(addrinfo_list),
sorted(set(addrinfo_list)),
backlog=None,
reuse_address=False,
reuse_port=mocker.sentinel.reuse_port,
Expand Down Expand Up @@ -1624,7 +1624,7 @@ async def test____create_udp_listeners____error_getaddrinfo_returns_empty_list(
family=AF_UNSPEC,
type=SOCK_DGRAM,
proto=0,
flags=0,
flags=AI_PASSIVE | AI_ADDRCONFIG,
)
mock_open_listeners.assert_not_called()
mock_RawDatagramListenerSocketAdapter.assert_not_called()
Expand Down
36 changes: 31 additions & 5 deletions tests/unit_test/test_async/test_asyncio_backend/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,25 +511,29 @@ def addrinfo_list() -> Sequence[tuple[int, int, int, str, tuple[Any, ...]]]:
@pytest.mark.parametrize("reuse_address", [False, True], ids=lambda boolean: f"reuse_address=={boolean}")
@pytest.mark.parametrize("SO_REUSEADDR_available", [False, True], ids=lambda boolean: f"SO_REUSEADDR_available=={boolean}")
@pytest.mark.parametrize("SO_REUSEADDR_raise_error", [False, True], ids=lambda boolean: f"SO_REUSEADDR_raise_error=={boolean}")
@pytest.mark.parametrize("IPPROTO_IPV6_available", [False, True], ids=lambda boolean: f"IPPROTO_IPV6_available=={boolean}")
@pytest.mark.parametrize("reuse_port", [False, True], ids=lambda boolean: f"reuse_port=={boolean}")
@pytest.mark.parametrize("backlog", [123456, None], ids=lambda value: f"backlog=={value}")
def test____open_listener_sockets_from_getaddrinfo_result____create_listener_sockets(
reuse_address: bool,
backlog: int | None,
SO_REUSEADDR_available: bool,
SO_REUSEADDR_raise_error: bool,
IPPROTO_IPV6_available: bool,
reuse_port: bool,
mock_socket_cls: MagicMock,
mock_socket_ipv4: MagicMock,
mock_socket_ipv6: MagicMock,
mocker: MockerFixture,
addrinfo_list: Sequence[tuple[int, int, int, str, tuple[Any, ...]]],
monkeypatch: pytest.MonkeyPatch,
SO_REUSEPORT: int,
mocker: MockerFixture,
) -> None:
# Arrange
if not SO_REUSEADDR_available:
monkeypatch.delattr("socket.SO_REUSEADDR", raising=True)
if not IPPROTO_IPV6_available:
monkeypatch.delattr("socket.IPPROTO_IPV6", raising=False)
if SO_REUSEADDR_raise_error:

def setsockopt(level: int, opt: int, value: int, /) -> None:
Expand All @@ -554,12 +558,16 @@ def setsockopt(level: int, opt: int, value: int, /) -> None:
assert len(sockets) == len(addrinfo_list)
assert mock_socket_cls.call_args_list == [mocker.call(f, t, p) for f, t, p, _, _ in addrinfo_list]
for socket, (sock_family, _, _, _, sock_addr) in zip(sockets, addrinfo_list, strict=True):
expected_setsockopt_calls: list[Any] = []
if reuse_address and SO_REUSEADDR_available:
socket.setsockopt.assert_any_call(SOL_SOCKET, SO_REUSEADDR, True)
expected_setsockopt_calls.append(mocker.call(SOL_SOCKET, SO_REUSEADDR, True))
if reuse_port:
socket.setsockopt.assert_any_call(SOL_SOCKET, SO_REUSEPORT, True)
if sock_family == AF_INET6:
socket.setsockopt.assert_any_call(IPPROTO_IPV6, IPV6_V6ONLY, True)
expected_setsockopt_calls.append(mocker.call(SOL_SOCKET, SO_REUSEPORT, True))
if sock_family == AF_INET6 and IPPROTO_IPV6_available:
expected_setsockopt_calls.append(mocker.call(IPPROTO_IPV6, IPV6_V6ONLY, True))

assert socket.setsockopt.mock_calls == expected_setsockopt_calls

socket.bind.assert_called_once_with(sock_addr)
if backlog is None:
socket.listen.assert_not_called()
Expand Down Expand Up @@ -609,3 +617,21 @@ def test____open_listener_sockets_from_getaddrinfo_result____bind_failed(

s1.close.assert_called_once_with()
s2.close.assert_called_once_with()


def test____open_listener_sockets_from_getaddrinfo_result____ipv6_scope_id_not_properly_extracted_from_address(
mock_socket_cls: MagicMock,
mock_socket_ipv6: MagicMock,
) -> None:
# Arrange
addrinfo_list: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = [
(AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", ("4e76:f928:6bbc:53ce:c01e:00d5:cdd5:6bbb%6", 65432, 0, 0)),
]
mock_socket_cls.side_effect = [mock_socket_ipv6]

# Act
sockets = open_listener_sockets_from_getaddrinfo_result(addrinfo_list, backlog=10, reuse_address=True, reuse_port=False)

# Assert
assert sockets == [mock_socket_ipv6]
mock_socket_ipv6.bind.assert_called_once_with(("4e76:f928:6bbc:53ce:c01e:00d5:cdd5:6bbb", 65432, 0, 6))

0 comments on commit bc8fa58

Please sign in to comment.