From 1ff534fe8cf84c50152a7e1f0a7a062fa2909e99 Mon Sep 17 00:00:00 2001 From: Sebastian Majewski Date: Sun, 15 Sep 2024 23:22:42 -0500 Subject: [PATCH] Socket code cleanup --- net_addr/__init__.py | 1 + pytcp/socket/raw__metadata.py | 6 ++-- pytcp/socket/raw__socket.py | 48 +++-------------------------- pytcp/socket/socket.py | 34 +++++++++++++++----- pytcp/socket/socket_id.py | 58 +++++++++++++++++++++++++++++++++++ pytcp/socket/tcp__metadata.py | 19 ++++++------ pytcp/socket/tcp__session.py | 2 +- pytcp/socket/tcp__socket.py | 12 ++++---- pytcp/socket/udp__metadata.py | 25 ++++++--------- pytcp/socket/udp__socket.py | 14 ++++----- pytcp/stack/__init__.py | 3 +- 11 files changed, 128 insertions(+), 94 deletions(-) create mode 100755 pytcp/socket/socket_id.py diff --git a/net_addr/__init__.py b/net_addr/__init__.py index 3c7887b5..e6eef7b9 100644 --- a/net_addr/__init__.py +++ b/net_addr/__init__.py @@ -33,6 +33,7 @@ """ +from .address import Address from .click_types import ( ClickTypeIp4Address, ClickTypeIp4Host, diff --git a/pytcp/socket/raw__metadata.py b/pytcp/socket/raw__metadata.py index 12432f12..ed460a35 100755 --- a/pytcp/socket/raw__metadata.py +++ b/pytcp/socket/raw__metadata.py @@ -38,7 +38,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any -from pytcp.socket.socket import AddressFamily, IpProto, SocketType +from pytcp.socket.socket import AddressFamily, SocketType if TYPE_CHECKING: from net_addr import IpAddress @@ -69,16 +69,16 @@ def socket_ids(self) -> list[tuple[Any, ...]]: ( AddressFamily.from_ver(self.ip__ver), SocketType.RAW, - IpProto.ICMP4, self.ip__local_address.unspecified, + 0, self.ip__remote_address.unspecified, 0, ), ( AddressFamily.from_ver(self.ip__ver), SocketType.RAW, - IpProto.ICMP6, self.ip__local_address.unspecified, + 0, self.ip__remote_address.unspecified, 0, ), diff --git a/pytcp/socket/raw__socket.py b/pytcp/socket/raw__socket.py index 815644c9..b0c9bb83 100755 --- a/pytcp/socket/raw__socket.py +++ b/pytcp/socket/raw__socket.py @@ -44,11 +44,9 @@ Ip6Address, Ip6AddressFormatError, ) -from pytcp.lib import stack from pytcp.lib.ip_helper import pick_local_ip_address from pytcp.lib.logger import log from pytcp.lib.tx_status import TxStatus -from pytcp.socket.raw__metadata import RawMetadata from pytcp.socket.socket import ( AddressFamily, IpProto, @@ -60,7 +58,7 @@ if TYPE_CHECKING: from net_addr import IpAddress - from pytcp.socket.udp__metadata import UdpMetadata + from pytcp.socket.raw__metadata import RawMetadata class RawSocket(Socket): @@ -90,48 +88,10 @@ def __init__( self._local_ip_address = Ip4Address() self._remote_ip_address = Ip4Address() - __debug__ and log("socket", f"[{self}] - Created socket") - - @override - def __str__(self) -> str: - """ - Get the UDP log string. - """ - - return ( - f"{self._address_family}/{self._socket_type}/{self._ip_proto}/" - f"{self._local_ip_address}/{self._remote_ip_address}" - ) + self._local_port = int(ip_proto) + self._remote_port = 0 - @property - def id(self) -> tuple[Any, ...]: - """ - Get the socket ID. - """ - - return ( - self._address_family, - self._socket_type, - self._ip_proto, - self._local_ip_address, - self._remote_ip_address, - ) - - @property - def local_ip_address(self) -> IpAddress: - """ - Get the '_local_ip_address' attribute. - """ - - return self._local_ip_address - - @property - def remote_ip_address(self) -> IpAddress: - """ - Get the '_remote_ip_address' attribute. - """ - - return self._remote_ip_address + __debug__ and log("socket", f"[{self}] - Created socket") def _get_ip_addresses( self, diff --git a/pytcp/socket/socket.py b/pytcp/socket/socket.py index 68db9e9e..d580a5b9 100755 --- a/pytcp/socket/socket.py +++ b/pytcp/socket/socket.py @@ -40,6 +40,7 @@ from pytcp.lib.name_enum import NameEnum from pytcp.protocols.enums import IpProto +from pytcp.socket.socket_id import SocketId if TYPE_CHECKING: from net_addr import IpAddress @@ -126,25 +127,44 @@ def __repr__(self) -> str: return self.__str__() @property - def id( - self, - ) -> tuple[ - AddressFamily, SocketType, IpProto, IpAddress, int, IpAddress, int - ]: + def socket_id(self) -> SocketId: """ Get the socket ID. """ - return ( + return SocketId( self._address_family, self._socket_type, - self._ip_proto, self._local_ip_address, self._local_port, self._remote_ip_address, self._remote_port, ) + @property + def address_family(self) -> AddressFamily: + """ + Get the '_family' attribute. + """ + + return self._address_family + + @property + def socket_type(self) -> SocketType: + """ + Get the '_type' attribute. + """ + + return self._socket_type + + @property + def ip_proto(self) -> IpProto: + """ + Get the '_proto' attribute. + """ + + return self._ip_proto + @property def local_ip_address(self) -> IpAddress: """ diff --git a/pytcp/socket/socket_id.py b/pytcp/socket/socket_id.py new file mode 100755 index 00000000..190aae84 --- /dev/null +++ b/pytcp/socket/socket_id.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +################################################################################ +## ## +## PyTCP - Python TCP/IP stack ## +## Copyright (C) 2020-present Sebastian Majewski ## +## ## +## This program is free software: you can redistribute it and/or modify ## +## it under the terms of the GNU General Public License as published by ## +## the Free Software Foundation, either version 3 of the License, or ## +## (at your option) any later version. ## +## ## +## This program is distributed in the hope that it will be useful, ## +## but WITHOUT ANY WARRANTY; without even the implied warranty of ## +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the ## +## GNU General Public License for more details. ## +## ## +## You should have received a copy of the GNU General Public License ## +## along with this program. If not, see . ## +## ## +## Author's email: ccie18643@gmail.com ## +## Github repository: https://github.com/ccie18643/PyTCP ## +## ## +################################################################################ + + +""" +Module contains class representing the Socket identificator. + +pytcp/lib/socket.py + +ver 3.0.2 +""" + + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from net_addr import Address + + from .socket import AddressFamily, SocketType + + +@dataclass(frozen=True) +class SocketId: + """ + Store the Socket identificator data. + """ + + address_family: AddressFamily + socket_type: SocketType + local_address: Address + local_port: int + remote_address: Address + remote_port: int diff --git a/pytcp/socket/tcp__metadata.py b/pytcp/socket/tcp__metadata.py index b012b8fc..66155927 100755 --- a/pytcp/socket/tcp__metadata.py +++ b/pytcp/socket/tcp__metadata.py @@ -36,9 +36,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -from .socket import AddressFamily, IpProto, SocketType +from pytcp.socket.socket_id import SocketId + +from .socket import AddressFamily, SocketType if TYPE_CHECKING: from net_addr import IpAddress @@ -71,15 +73,14 @@ class TcpMetadata: tracker: Tracker | None @property - def socket_id(self) -> tuple[Any, ...]: + def socket_id(self) -> SocketId: """ Get the exact match socket ID. """ - return ( + return SocketId( AddressFamily.from_ver(self.ip__ver), SocketType.STREAM, - IpProto.TCP, self.ip__local_address, self.tcp__local_port, self.ip__remote_address, @@ -87,25 +88,23 @@ def socket_id(self) -> tuple[Any, ...]: ) @property - def listening_socket_ids(self) -> list[tuple[Any, ...]]: + def listening_socket_ids(self) -> list[SocketId]: """ Get list of the listening socket IDs that match the metadata. """ return [ - ( + SocketId( AddressFamily.from_ver(self.ip__ver), SocketType.STREAM, - IpProto.TCP, self.ip__local_address, self.tcp__local_port, self.ip__remote_address.unspecified, 0, ), - ( + SocketId( AddressFamily.from_ver(self.ip__ver), SocketType.STREAM, - IpProto.TCP, self.ip__local_address.unspecified, self.tcp__local_port, self.ip__remote_address.unspecified, diff --git a/pytcp/socket/tcp__session.py b/pytcp/socket/tcp__session.py index 41c56d6b..9cab00f9 100755 --- a/pytcp/socket/tcp__session.py +++ b/pytcp/socket/tcp__session.py @@ -472,7 +472,7 @@ def _change_state(self, state: FsmState) -> None: # Unregister session if self._state in {FsmState.CLOSED}: - stack.sockets.pop(self._socket.id) + stack.sockets.pop(self._socket.socket_id) __debug__ and log( "tcp-ss", f"[{self}] - Unregister associated socket" ) diff --git a/pytcp/socket/tcp__socket.py b/pytcp/socket/tcp__socket.py index 9fd260fc..9fb813e9 100755 --- a/pytcp/socket/tcp__socket.py +++ b/pytcp/socket/tcp__socket.py @@ -99,7 +99,7 @@ def __init__( self._local_port = tcp_session.local_port self._remote_port = tcp_session.remote_port self._parent_socket = tcp_session.socket - stack.sockets[self.id] = self + stack.sockets[self.socket_id] = self # Fresh socket initialization else: @@ -259,10 +259,10 @@ def bind(self, address: tuple[str, int]) -> None: local_port = pick_local_port() # Assigning local port makes socket "bound" - stack.sockets.pop(self.id, None) + stack.sockets.pop(self.socket_id, None) self._local_ip_address = local_ip_address self._local_port = local_port - stack.sockets[self.id] = self + stack.sockets[self.socket_id] = self __debug__ and log("socket", f"[{self}] - Bound socket") @@ -292,12 +292,12 @@ def connect(self, address: tuple[str, int]) -> None: ) # Re-register socket with new socket id - stack.sockets.pop(self.id, None) + stack.sockets.pop(self.socket_id, None) self._local_ip_address = local_ip_address self._local_port = local_port self._remote_ip_address = remote_ip_address self._remote_port = remote_port - stack.sockets[self.id] = self + stack.sockets[self.socket_id] = self self._tcp_session = TcpSession( local_ip_address=self._local_ip_address, @@ -346,7 +346,7 @@ def listen(self) -> None: "connections", ) - stack.sockets[self.id] = self + stack.sockets[self.socket_id] = self self._tcp_session.listen() def accept(self) -> tuple[Socket, tuple[str, int]]: diff --git a/pytcp/socket/udp__metadata.py b/pytcp/socket/udp__metadata.py index b1785602..44068d9c 100755 --- a/pytcp/socket/udp__metadata.py +++ b/pytcp/socket/udp__metadata.py @@ -36,12 +36,13 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from net_addr.ip6_address import Ip6Address from net_addr import Ip4Address -from pytcp.socket.socket import AddressFamily, IpProto, SocketType +from pytcp.socket.socket import AddressFamily, SocketType +from pytcp.socket.socket_id import SocketId if TYPE_CHECKING: from net_addr import IpAddress @@ -65,7 +66,7 @@ class UdpMetadata: tracker: Tracker | None = None @property - def socket_ids(self) -> list[tuple[Any, ...]]: + def socket_ids(self) -> list[SocketId]: """ Get list of the listening socket IDs that match the metadata. """ @@ -73,10 +74,9 @@ def socket_ids(self) -> list[tuple[Any, ...]]: match self.ip__ver, self.udp__local_port, self.udp__remote_port: case 4, 68, 67: return [ - ( + SocketId( AddressFamily.INET4, SocketType.DGRAM, - IpProto.UDP, Ip4Address(), 68, Ip4Address("255.255.255.255"), @@ -85,19 +85,17 @@ def socket_ids(self) -> list[tuple[Any, ...]]: ] case 6, 546, 547: return [ - ( + SocketId( AddressFamily.INET6, SocketType.DGRAM, - IpProto.UDP, Ip6Address(), 546, Ip6Address("ff02::1:2"), 547, ), # ID for the DHCPv6 client operation. - ( + SocketId( AddressFamily.INET6, SocketType.DGRAM, - IpProto.UDP, Ip6Address(), 546, Ip6Address("ff02::1:3"), @@ -106,28 +104,25 @@ def socket_ids(self) -> list[tuple[Any, ...]]: ] case _: return [ - ( + SocketId( AddressFamily.from_ver(self.ip__ver), SocketType.DGRAM, - IpProto.UDP, self.ip__local_address, self.udp__local_port, self.ip__remote_address, self.udp__remote_port, ), - ( + SocketId( AddressFamily.from_ver(self.ip__ver), SocketType.DGRAM, - IpProto.UDP, self.ip__local_address, self.udp__local_port, self.ip__remote_address.unspecified, 0, ), - ( + SocketId( AddressFamily.from_ver(self.ip__ver), SocketType.DGRAM, - IpProto.UDP, self.ip__local_address.unspecified, self.udp__local_port, self.ip__remote_address.unspecified, diff --git a/pytcp/socket/udp__socket.py b/pytcp/socket/udp__socket.py index b8852ff8..c9d74f8c 100755 --- a/pytcp/socket/udp__socket.py +++ b/pytcp/socket/udp__socket.py @@ -219,10 +219,10 @@ def bind(self, address: tuple[str, int]) -> None: local_port = pick_local_port() # Assigning local port makes socket "bound" - stack.sockets.pop(self.id, None) + stack.sockets.pop(self.socket_id, None) self._local_ip_address = local_ip_address self._local_port = local_port - stack.sockets[self.id] = self + stack.sockets[self.socket_id] = self __debug__ and log("socket", f"[{self}] - Bound") @@ -252,12 +252,12 @@ def connect(self, address: tuple[str, int]) -> None: ) # Re-register socket with new socket id - stack.sockets.pop(self.id, None) + stack.sockets.pop(self.socket_id, None) self._local_ip_address = local_ip_address self._local_port = local_port self._remote_ip_address = remote_ip_address self._remote_port = remote_port - stack.sockets[self.id] = self + stack.sockets[self.socket_id] = self __debug__ and log("socket", f"[{self}] - Connected socket") @@ -318,9 +318,9 @@ def sendto(self, data: bytes, address: tuple[str, int]) -> int: # Assigning local port makes socket "bound" if not "bound" already if self._local_port not in range(1, 65536): - stack.sockets.pop(self.id, None) + stack.sockets.pop(self.socket_id, None) self._local_port = pick_local_port() - stack.sockets[self.id] = self + stack.sockets[self.socket_id] = self # Set local and remote ip addresses aproprietely local_ip_address, remote_ip_address = self._get_ip_addresses( @@ -404,7 +404,7 @@ def close(self) -> None: Close socket. """ - stack.sockets.pop(self.id, None) + stack.sockets.pop(self.socket_id, None) __debug__ and log("socket", f"[{self}] - Closed socket") diff --git a/pytcp/stack/__init__.py b/pytcp/stack/__init__.py index 3908cb8d..84575e6a 100755 --- a/pytcp/stack/__init__.py +++ b/pytcp/stack/__init__.py @@ -46,6 +46,7 @@ from net_addr.mac_address import MacAddress from pytcp.lib.logger import log +from pytcp.socket.socket_id import SocketId from .arp_cache import ArpCache from .nd_cache import NdCache @@ -143,7 +144,7 @@ # Stack shared data. stack_initialized: bool = False interface_mtu: int -sockets: dict[tuple[Any, ...], Socket] = {} +sockets: dict[SocketId, Socket] = {} arp_probe_unicast_conflict: set[Ip4Address] = set()