Skip to content

Commit

Permalink
Removed easynetwork.lowlevel.socket.AddressFamily (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Dec 3, 2023
1 parent dc98071 commit 7072238
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 51 deletions.
7 changes: 2 additions & 5 deletions src/easynetwork/lowlevel/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
del _ssl

from . import constants
from .socket import AddressFamily

if TYPE_CHECKING:
from ssl import SSLError as _SSLError, SSLSocket as _SSLSocket
Expand Down Expand Up @@ -113,10 +112,8 @@ def error_from_errno(errno: int) -> OSError:


def check_socket_family(family: int) -> None:
supported_families = AddressFamily.__members__

if family not in supported_families.values():
raise ValueError(f"Only these families are supported: {', '.join(supported_families)}")
if family not in {_socket.AF_INET, _socket.AF_INET6}:
raise ValueError("Only these families are supported: AF_INET, AF_INET6")


def check_real_socket_state(socket: ISocket) -> None:
Expand Down
34 changes: 7 additions & 27 deletions src/easynetwork/lowlevel/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

__all__ = [
"AddressFamily",
"INETSocketAttribute",
"IPv4SocketAddress",
"IPv6SocketAddress",
Expand All @@ -43,7 +42,6 @@
import threading
from abc import abstractmethod
from collections.abc import Callable
from enum import IntEnum, unique
from struct import Struct
from typing import (
TYPE_CHECKING,
Expand All @@ -54,7 +52,6 @@
Protocol,
TypeAlias,
TypeVar,
assert_never,
final,
overload,
runtime_checkable,
Expand Down Expand Up @@ -122,22 +119,6 @@ class TLSAttribute(typed_attr.TypedAttributeSet):
"""the TLS protocol version (e.g. TLSv1.2)"""


@unique
class AddressFamily(IntEnum):
"""
Enumeration of supported socket address families.
"""

AF_INET = _socket.AF_INET
AF_INET6 = _socket.AF_INET6

def __repr__(self) -> str:
return f"{type(self).__name__}.{self.name}"

def __str__(self) -> str: # pragma: no cover
return repr(self)


class IPv4SocketAddress(NamedTuple):
host: str
port: int
Expand Down Expand Up @@ -175,13 +156,13 @@ def for_connection(self) -> tuple[str, int]:


@overload
def new_socket_address(addr: tuple[str, int], family: Literal[AddressFamily.AF_INET]) -> IPv4SocketAddress:
def new_socket_address(addr: tuple[str, int], family: Literal[_socket.AddressFamily.AF_INET]) -> IPv4SocketAddress:
...


@overload
def new_socket_address(
addr: tuple[str, int] | tuple[str, int, int, int], family: Literal[AddressFamily.AF_INET6]
addr: tuple[str, int] | tuple[str, int, int, int], family: Literal[_socket.AddressFamily.AF_INET6]
) -> IPv6SocketAddress:
...

Expand All @@ -206,7 +187,7 @@ def new_socket_address(addr: tuple[Any, ...], family: int) -> SocketAddress:
>>> new_socket_address(("127.0.0.1", 12345), socket.AF_APPLETALK)
Traceback (most recent call last):
...
ValueError: <AddressFamily.AF_APPLETALK: 5> is not a valid AddressFamily
ValueError: Unsupported address family <AddressFamily.AF_APPLETALK: 5>
Parameters:
addr: The address in the form ``(host, port)`` or ``(host, port, flow, scope_id)``.
Expand All @@ -219,14 +200,13 @@ def new_socket_address(addr: tuple[Any, ...], family: int) -> SocketAddress:
Returns:
a :data:`SocketAddress` named tuple.
"""
family = AddressFamily(family)
match family:
case AddressFamily.AF_INET:
case _socket.AddressFamily.AF_INET:
return IPv4SocketAddress(*addr)
case AddressFamily.AF_INET6:
case _socket.AddressFamily.AF_INET6:
return IPv6SocketAddress(*addr)
case _: # pragma: no cover
assert_never(family)
case _:
raise ValueError(f"Unsupported address family {family!r}")


@runtime_checkable
Expand Down
4 changes: 1 addition & 3 deletions tests/unit_test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from socket import AF_INET, AF_INET6, socket as Socket
from typing import TYPE_CHECKING

from easynetwork.lowlevel.socket import AddressFamily

import pytest

from ._utils import get_all_socket_families
Expand All @@ -14,7 +12,7 @@

from pytest_mock import MockerFixture

SUPPORTED_FAMILIES: tuple[str, ...] = tuple(sorted(AddressFamily.__members__))
SUPPORTED_FAMILIES: tuple[str, ...] = tuple(sorted(("AF_INET", "AF_INET6")))
UNSUPPORTED_FAMILIES: tuple[str, ...] = tuple(sorted(get_all_socket_families().difference(SUPPORTED_FAMILIES)))


Expand Down
36 changes: 20 additions & 16 deletions tests/unit_test/test_tools/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import socket
from collections.abc import Callable
from socket import IPPROTO_TCP, SO_KEEPALIVE, SO_LINGER, SOL_SOCKET, TCP_NODELAY
from socket import AF_INET, AF_INET6, IPPROTO_TCP, SO_KEEPALIVE, SO_LINGER, SOL_SOCKET, TCP_NODELAY
from typing import TYPE_CHECKING, Any

from easynetwork.lowlevel.socket import (
AddressFamily,
IPv4SocketAddress,
IPv6SocketAddress,
SocketAddress,
Expand All @@ -22,23 +21,14 @@
import pytest

from .._utils import partial_eq
from ..base import UNSUPPORTED_FAMILIES

if TYPE_CHECKING:
from unittest.mock import MagicMock

from pytest_mock import MockerFixture


@pytest.mark.parametrize("name", list(AddressFamily.__members__))
def test____AddressFamily____constants(name: str) -> None:
# Arrange
enum = AddressFamily[name]
constant = getattr(socket, name)

# Act & Assert
assert enum.value == constant


class TestSocketAddress:
@pytest.mark.parametrize("address", [IPv4SocketAddress("127.0.0.1", 3000), IPv6SocketAddress("127.0.0.1", 3000)])
def test____for_connection____return_host_port_tuple(self, address: SocketAddress) -> None:
Expand All @@ -54,15 +44,15 @@ def test____for_connection____return_host_port_tuple(self, address: SocketAddres
@pytest.mark.parametrize(
["address", "family", "expected_type"],
[
pytest.param(("127.0.0.1", 3000), AddressFamily.AF_INET, IPv4SocketAddress),
pytest.param(("127.0.0.1", 3000), AddressFamily.AF_INET6, IPv6SocketAddress),
pytest.param(("127.0.0.1", 3000, 0, 0), AddressFamily.AF_INET6, IPv6SocketAddress),
pytest.param(("127.0.0.1", 3000), AF_INET, IPv4SocketAddress),
pytest.param(("127.0.0.1", 3000), AF_INET6, IPv6SocketAddress),
pytest.param(("127.0.0.1", 3000, 0, 0), AF_INET6, IPv6SocketAddress),
],
)
def test____new_socket_address____factory(
self,
address: tuple[Any, ...],
family: AddressFamily,
family: int,
expected_type: type[SocketAddress],
) -> None:
# Arrange
Expand All @@ -73,6 +63,20 @@ def test____new_socket_address____factory(
# Assert
assert isinstance(socket_address, expected_type)

@pytest.mark.parametrize("socket_family_name", list(UNSUPPORTED_FAMILIES))
def test____new_socket_address____unsupported_family(
self,
socket_family_name: str,
) -> None:
# Arrange
import socket

family: int = getattr(socket, socket_family_name)

# Act & Assert
with pytest.raises(ValueError, match=r"^Unsupported address family .+$"):
_ = new_socket_address(("127.0.0.1", 12345), family)


class TestSocketProxy:
@pytest.fixture(
Expand Down

0 comments on commit 7072238

Please sign in to comment.