Skip to content

Commit

Permalink
Socket code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ccie18643 committed Sep 16, 2024
1 parent b37577b commit 1ff534f
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 94 deletions.
1 change: 1 addition & 0 deletions net_addr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"""


from .address import Address
from .click_types import (
ClickTypeIp4Address,
ClickTypeIp4Host,
Expand Down
6 changes: 3 additions & 3 deletions pytcp/socket/raw__metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from pytcp.socket.socket import AddressFamily, IpProto, SocketType
from pytcp.socket.socket import AddressFamily, SocketType

if TYPE_CHECKING:
from net_addr import IpAddress
Expand Down Expand Up @@ -69,16 +69,16 @@ def socket_ids(self) -> list[tuple[Any, ...]]:
(
AddressFamily.from_ver(self.ip__ver),
SocketType.RAW,
IpProto.ICMP4,
self.ip__local_address.unspecified,
0,
self.ip__remote_address.unspecified,
0,
),
(
AddressFamily.from_ver(self.ip__ver),
SocketType.RAW,
IpProto.ICMP6,
self.ip__local_address.unspecified,
0,
self.ip__remote_address.unspecified,
0,
),
Expand Down
48 changes: 4 additions & 44 deletions pytcp/socket/raw__socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@
Ip6Address,
Ip6AddressFormatError,
)
from pytcp.lib import stack
from pytcp.lib.ip_helper import pick_local_ip_address
from pytcp.lib.logger import log
from pytcp.lib.tx_status import TxStatus
from pytcp.socket.raw__metadata import RawMetadata
from pytcp.socket.socket import (
AddressFamily,
IpProto,
Expand All @@ -60,7 +58,7 @@
if TYPE_CHECKING:
from net_addr import IpAddress
from pytcp.socket.udp__metadata import UdpMetadata
from pytcp.socket.raw__metadata import RawMetadata
class RawSocket(Socket):
Expand Down Expand Up @@ -90,48 +88,10 @@ def __init__(
self._local_ip_address = Ip4Address()
self._remote_ip_address = Ip4Address()
__debug__ and log("socket", f"<g>[{self}]</> - Created socket")
@override
def __str__(self) -> str:
"""
Get the UDP log string.
"""
return (
f"{self._address_family}/{self._socket_type}/{self._ip_proto}/"
f"{self._local_ip_address}/{self._remote_ip_address}"
)
self._local_port = int(ip_proto)
self._remote_port = 0
@property
def id(self) -> tuple[Any, ...]:
"""
Get the socket ID.
"""
return (
self._address_family,
self._socket_type,
self._ip_proto,
self._local_ip_address,
self._remote_ip_address,
)
@property
def local_ip_address(self) -> IpAddress:
"""
Get the '_local_ip_address' attribute.
"""
return self._local_ip_address
@property
def remote_ip_address(self) -> IpAddress:
"""
Get the '_remote_ip_address' attribute.
"""
return self._remote_ip_address
__debug__ and log("socket", f"<g>[{self}]</> - Created socket")
def _get_ip_addresses(
self,
Expand Down
34 changes: 27 additions & 7 deletions pytcp/socket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from pytcp.lib.name_enum import NameEnum
from pytcp.protocols.enums import IpProto
from pytcp.socket.socket_id import SocketId

if TYPE_CHECKING:
from net_addr import IpAddress
Expand Down Expand Up @@ -126,25 +127,44 @@ def __repr__(self) -> str:
return self.__str__()

@property
def id(
self,
) -> tuple[
AddressFamily, SocketType, IpProto, IpAddress, int, IpAddress, int
]:
def socket_id(self) -> SocketId:
"""
Get the socket ID.
"""

return (
return SocketId(
self._address_family,
self._socket_type,
self._ip_proto,
self._local_ip_address,
self._local_port,
self._remote_ip_address,
self._remote_port,
)

@property
def address_family(self) -> AddressFamily:
"""
Get the '_family' attribute.
"""

return self._address_family

@property
def socket_type(self) -> SocketType:
"""
Get the '_type' attribute.
"""

return self._socket_type

@property
def ip_proto(self) -> IpProto:
"""
Get the '_proto' attribute.
"""

return self._ip_proto

@property
def local_ip_address(self) -> IpAddress:
"""
Expand Down
58 changes: 58 additions & 0 deletions pytcp/socket/socket_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python3

################################################################################
## ##
## PyTCP - Python TCP/IP stack ##
## Copyright (C) 2020-present Sebastian Majewski ##
## ##
## This program is free software: you can redistribute it and/or modify ##
## it under the terms of the GNU General Public License as published by ##
## the Free Software Foundation, either version 3 of the License, or ##
## (at your option) any later version. ##
## ##
## This program is distributed in the hope that it will be useful, ##
## but WITHOUT ANY WARRANTY; without even the implied warranty of ##
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the ##
## GNU General Public License for more details. ##
## ##
## You should have received a copy of the GNU General Public License ##
## along with this program. If not, see <https://www.gnu.org/licenses/>. ##
## ##
## Author's email: [email protected] ##
## Github repository: https://github.com/ccie18643/PyTCP ##
## ##
################################################################################


"""
Module contains class representing the Socket identificator.
pytcp/lib/socket.py
ver 3.0.2
"""


from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from net_addr import Address

from .socket import AddressFamily, SocketType


@dataclass(frozen=True)
class SocketId:
"""
Store the Socket identificator data.
"""

address_family: AddressFamily
socket_type: SocketType
local_address: Address
local_port: int
remote_address: Address
remote_port: int
19 changes: 9 additions & 10 deletions pytcp/socket/tcp__metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from .socket import AddressFamily, IpProto, SocketType
from pytcp.socket.socket_id import SocketId

from .socket import AddressFamily, SocketType

if TYPE_CHECKING:
from net_addr import IpAddress
Expand Down Expand Up @@ -71,41 +73,38 @@ class TcpMetadata:
tracker: Tracker | None

@property
def socket_id(self) -> tuple[Any, ...]:
def socket_id(self) -> SocketId:
"""
Get the exact match socket ID.
"""

return (
return SocketId(
AddressFamily.from_ver(self.ip__ver),
SocketType.STREAM,
IpProto.TCP,
self.ip__local_address,
self.tcp__local_port,
self.ip__remote_address,
self.tcp__remote_port,
)

@property
def listening_socket_ids(self) -> list[tuple[Any, ...]]:
def listening_socket_ids(self) -> list[SocketId]:
"""
Get list of the listening socket IDs that match the metadata.
"""

return [
(
SocketId(
AddressFamily.from_ver(self.ip__ver),
SocketType.STREAM,
IpProto.TCP,
self.ip__local_address,
self.tcp__local_port,
self.ip__remote_address.unspecified,
0,
),
(
SocketId(
AddressFamily.from_ver(self.ip__ver),
SocketType.STREAM,
IpProto.TCP,
self.ip__local_address.unspecified,
self.tcp__local_port,
self.ip__remote_address.unspecified,
Expand Down
2 changes: 1 addition & 1 deletion pytcp/socket/tcp__session.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def _change_state(self, state: FsmState) -> None:

# Unregister session
if self._state in {FsmState.CLOSED}:
stack.sockets.pop(self._socket.id)
stack.sockets.pop(self._socket.socket_id)
__debug__ and log(
"tcp-ss", f"[{self}] - Unregister associated socket"
)
Expand Down
12 changes: 6 additions & 6 deletions pytcp/socket/tcp__socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
self._local_port = tcp_session.local_port
self._remote_port = tcp_session.remote_port
self._parent_socket = tcp_session.socket
stack.sockets[self.id] = self
stack.sockets[self.socket_id] = self

# Fresh socket initialization
else:
Expand Down Expand Up @@ -259,10 +259,10 @@ def bind(self, address: tuple[str, int]) -> None:
local_port = pick_local_port()

# Assigning local port makes socket "bound"
stack.sockets.pop(self.id, None)
stack.sockets.pop(self.socket_id, None)
self._local_ip_address = local_ip_address
self._local_port = local_port
stack.sockets[self.id] = self
stack.sockets[self.socket_id] = self

__debug__ and log("socket", f"<g>[{self}]</> - Bound socket")

Expand Down Expand Up @@ -292,12 +292,12 @@ def connect(self, address: tuple[str, int]) -> None:
)

# Re-register socket with new socket id
stack.sockets.pop(self.id, None)
stack.sockets.pop(self.socket_id, None)
self._local_ip_address = local_ip_address
self._local_port = local_port
self._remote_ip_address = remote_ip_address
self._remote_port = remote_port
stack.sockets[self.id] = self
stack.sockets[self.socket_id] = self

self._tcp_session = TcpSession(
local_ip_address=self._local_ip_address,
Expand Down Expand Up @@ -346,7 +346,7 @@ def listen(self) -> None:
"connections",
)

stack.sockets[self.id] = self
stack.sockets[self.socket_id] = self
self._tcp_session.listen()

def accept(self) -> tuple[Socket, tuple[str, int]]:
Expand Down
Loading

0 comments on commit 1ff534f

Please sign in to comment.