diff --git a/src/easynetwork/lowlevel/_utils.py b/src/easynetwork/lowlevel/_utils.py index ec58e492..de330408 100644 --- a/src/easynetwork/lowlevel/_utils.py +++ b/src/easynetwork/lowlevel/_utils.py @@ -57,7 +57,6 @@ del _ssl from . import constants -from .socket import AddressFamily if TYPE_CHECKING: from ssl import SSLError as _SSLError, SSLSocket as _SSLSocket @@ -113,10 +112,8 @@ def error_from_errno(errno: int) -> OSError: def check_socket_family(family: int) -> None: - supported_families = AddressFamily.__members__ - - if family not in supported_families.values(): - raise ValueError(f"Only these families are supported: {', '.join(supported_families)}") + if family not in {_socket.AF_INET, _socket.AF_INET6}: + raise ValueError("Only these families are supported: AF_INET, AF_INET6") def check_real_socket_state(socket: ISocket) -> None: diff --git a/src/easynetwork/lowlevel/socket.py b/src/easynetwork/lowlevel/socket.py index 3464ec4e..04f5ce5b 100644 --- a/src/easynetwork/lowlevel/socket.py +++ b/src/easynetwork/lowlevel/socket.py @@ -17,7 +17,6 @@ from __future__ import annotations __all__ = [ - "AddressFamily", "INETSocketAttribute", "IPv4SocketAddress", "IPv6SocketAddress", @@ -43,7 +42,6 @@ import threading from abc import abstractmethod from collections.abc import Callable -from enum import IntEnum, unique from struct import Struct from typing import ( TYPE_CHECKING, @@ -54,7 +52,6 @@ Protocol, TypeAlias, TypeVar, - assert_never, final, overload, runtime_checkable, @@ -122,22 +119,6 @@ class TLSAttribute(typed_attr.TypedAttributeSet): """the TLS protocol version (e.g. TLSv1.2)""" -@unique -class AddressFamily(IntEnum): - """ - Enumeration of supported socket address families. - """ - - AF_INET = _socket.AF_INET - AF_INET6 = _socket.AF_INET6 - - def __repr__(self) -> str: - return f"{type(self).__name__}.{self.name}" - - def __str__(self) -> str: # pragma: no cover - return repr(self) - - class IPv4SocketAddress(NamedTuple): host: str port: int @@ -175,13 +156,13 @@ def for_connection(self) -> tuple[str, int]: @overload -def new_socket_address(addr: tuple[str, int], family: Literal[AddressFamily.AF_INET]) -> IPv4SocketAddress: +def new_socket_address(addr: tuple[str, int], family: Literal[_socket.AddressFamily.AF_INET]) -> IPv4SocketAddress: ... @overload def new_socket_address( - addr: tuple[str, int] | tuple[str, int, int, int], family: Literal[AddressFamily.AF_INET6] + addr: tuple[str, int] | tuple[str, int, int, int], family: Literal[_socket.AddressFamily.AF_INET6] ) -> IPv6SocketAddress: ... @@ -206,7 +187,7 @@ def new_socket_address(addr: tuple[Any, ...], family: int) -> SocketAddress: >>> new_socket_address(("127.0.0.1", 12345), socket.AF_APPLETALK) Traceback (most recent call last): ... - ValueError: is not a valid AddressFamily + ValueError: Unsupported address family Parameters: addr: The address in the form ``(host, port)`` or ``(host, port, flow, scope_id)``. @@ -219,14 +200,13 @@ def new_socket_address(addr: tuple[Any, ...], family: int) -> SocketAddress: Returns: a :data:`SocketAddress` named tuple. """ - family = AddressFamily(family) match family: - case AddressFamily.AF_INET: + case _socket.AddressFamily.AF_INET: return IPv4SocketAddress(*addr) - case AddressFamily.AF_INET6: + case _socket.AddressFamily.AF_INET6: return IPv6SocketAddress(*addr) - case _: # pragma: no cover - assert_never(family) + case _: + raise ValueError(f"Unsupported address family {family!r}") @runtime_checkable diff --git a/tests/unit_test/base.py b/tests/unit_test/base.py index a13bd93f..5ac0ee22 100644 --- a/tests/unit_test/base.py +++ b/tests/unit_test/base.py @@ -3,8 +3,6 @@ from socket import AF_INET, AF_INET6, socket as Socket from typing import TYPE_CHECKING -from easynetwork.lowlevel.socket import AddressFamily - import pytest from ._utils import get_all_socket_families @@ -14,7 +12,7 @@ from pytest_mock import MockerFixture -SUPPORTED_FAMILIES: tuple[str, ...] = tuple(sorted(AddressFamily.__members__)) +SUPPORTED_FAMILIES: tuple[str, ...] = tuple(sorted(("AF_INET", "AF_INET6"))) UNSUPPORTED_FAMILIES: tuple[str, ...] = tuple(sorted(get_all_socket_families().difference(SUPPORTED_FAMILIES))) diff --git a/tests/unit_test/test_tools/test_socket.py b/tests/unit_test/test_tools/test_socket.py index a338a0a6..a6f1e3ec 100644 --- a/tests/unit_test/test_tools/test_socket.py +++ b/tests/unit_test/test_tools/test_socket.py @@ -2,11 +2,10 @@ import socket from collections.abc import Callable -from socket import IPPROTO_TCP, SO_KEEPALIVE, SO_LINGER, SOL_SOCKET, TCP_NODELAY +from socket import AF_INET, AF_INET6, IPPROTO_TCP, SO_KEEPALIVE, SO_LINGER, SOL_SOCKET, TCP_NODELAY from typing import TYPE_CHECKING, Any from easynetwork.lowlevel.socket import ( - AddressFamily, IPv4SocketAddress, IPv6SocketAddress, SocketAddress, @@ -22,6 +21,7 @@ import pytest from .._utils import partial_eq +from ..base import UNSUPPORTED_FAMILIES if TYPE_CHECKING: from unittest.mock import MagicMock @@ -29,16 +29,6 @@ from pytest_mock import MockerFixture -@pytest.mark.parametrize("name", list(AddressFamily.__members__)) -def test____AddressFamily____constants(name: str) -> None: - # Arrange - enum = AddressFamily[name] - constant = getattr(socket, name) - - # Act & Assert - assert enum.value == constant - - class TestSocketAddress: @pytest.mark.parametrize("address", [IPv4SocketAddress("127.0.0.1", 3000), IPv6SocketAddress("127.0.0.1", 3000)]) def test____for_connection____return_host_port_tuple(self, address: SocketAddress) -> None: @@ -54,15 +44,15 @@ def test____for_connection____return_host_port_tuple(self, address: SocketAddres @pytest.mark.parametrize( ["address", "family", "expected_type"], [ - pytest.param(("127.0.0.1", 3000), AddressFamily.AF_INET, IPv4SocketAddress), - pytest.param(("127.0.0.1", 3000), AddressFamily.AF_INET6, IPv6SocketAddress), - pytest.param(("127.0.0.1", 3000, 0, 0), AddressFamily.AF_INET6, IPv6SocketAddress), + pytest.param(("127.0.0.1", 3000), AF_INET, IPv4SocketAddress), + pytest.param(("127.0.0.1", 3000), AF_INET6, IPv6SocketAddress), + pytest.param(("127.0.0.1", 3000, 0, 0), AF_INET6, IPv6SocketAddress), ], ) def test____new_socket_address____factory( self, address: tuple[Any, ...], - family: AddressFamily, + family: int, expected_type: type[SocketAddress], ) -> None: # Arrange @@ -73,6 +63,20 @@ def test____new_socket_address____factory( # Assert assert isinstance(socket_address, expected_type) + @pytest.mark.parametrize("socket_family_name", list(UNSUPPORTED_FAMILIES)) + def test____new_socket_address____unsupported_family( + self, + socket_family_name: str, + ) -> None: + # Arrange + import socket + + family: int = getattr(socket, socket_family_name) + + # Act & Assert + with pytest.raises(ValueError, match=r"^Unsupported address family .+$"): + _ = new_socket_address(("127.0.0.1", 12345), family) + class TestSocketProxy: @pytest.fixture(