Skip to content

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ccie18643 committed Sep 10, 2024
1 parent 4735f3d commit 6f0cb69
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 14 deletions.
17 changes: 14 additions & 3 deletions pytcp/lib/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def _pick_local_port(self) -> int:

def _is_address_in_use(
self,
*,
local_ip_address: IpAddress,
local_port: int,
) -> bool:
Expand All @@ -258,12 +259,12 @@ def _is_address_in_use(

return False

def _set_ip_addresses(
def _get_ip_addresses(
self,
*,
remote_address: tuple[str, int],
local_ip_address: IpAddress,
local_port: int,
remote_port: int,
) -> tuple[Ip6Address | Ip4Address, Ip6Address | Ip4Address]:
"""
Validate the remote address and pick appropriate local IP
Expand Down Expand Up @@ -296,7 +297,7 @@ def _set_ip_addresses(
if local_ip_address.is_unspecified:
local_ip_address = pick_local_ip_address(remote_ip_address)
if local_ip_address.is_unspecified and not (
local_port == 68 and remote_port == 67
local_port == 68 and remote_address[1] == 67
):
raise gaierror(
"[Errno -2] Name or service not known - "
Expand All @@ -316,6 +317,8 @@ def bind(
The 'bind()' socket API method placeholder.
"""

raise NotImplementedError

@abstractmethod
def connect(
self,
Expand All @@ -325,6 +328,8 @@ def connect(
The 'connect()' socket API method placeholder.
"""

raise NotImplementedError

@abstractmethod
def send(
self,
Expand All @@ -334,6 +339,8 @@ def send(
The 'send()' socket API method placeholder.
"""

raise NotImplementedError

@abstractmethod
def recv(
self,
Expand All @@ -344,12 +351,16 @@ def recv(
The 'recv()' socket API method placeholder.
"""

raise NotImplementedError

@abstractmethod
def close(self) -> None:
"""
The 'close()' socket API placeholder.
"""

raise NotImplementedError

if TYPE_CHECKING:

def listen(self) -> None:
Expand Down
26 changes: 22 additions & 4 deletions pytcp/protocols/tcp/tcp__socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from __future__ import annotations

import threading
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, override

from pytcp.lib import stack
from pytcp.lib.logger import log
Expand Down Expand Up @@ -125,24 +125,29 @@ def state(self) -> FsmState:
"""
Return FSM state of associated TCP session.
"""

if self.tcp_session is not None:
return self.tcp_session.state

return FsmState.CLOSED

@property
def tcp_session(self) -> TcpSession | None:
"""
Getter for the '_tcp_session' attribute.
"""

return self._tcp_session

@property
def parent_socket(self) -> Socket | None:
"""
Getter for the '_parent_socket' attribute.
"""

return self._parent_socket

@override
def bind(self, address: tuple[str, int]) -> None:
"""
Bind the socket to local address.
Expand Down Expand Up @@ -199,7 +204,10 @@ def bind(self, address: tuple[str, int]) -> None:

# Confirm or pick local port number
if (local_port := address[1]) > 0:
if self._is_address_in_use(local_ip_address, local_port):
if self._is_address_in_use(
local_ip_address=local_ip_address,
local_port=local_port,
):
raise OSError(
"[Errno 98] Address already in use - "
"[Local address already in use]"
Expand All @@ -215,6 +223,7 @@ def bind(self, address: tuple[str, int]) -> None:

__debug__ and log("socket", f"<g>[{self}]</> - Bound socket")

@override
def connect(self, address: tuple[str, int]) -> None:
"""
Connect local socket to remote socket.
Expand All @@ -236,8 +245,10 @@ def connect(self, address: tuple[str, int]) -> None:
local_port = self._pick_local_port()

# Set local and remote ip addresses aproprietely
local_ip_address, remote_ip_address = self._set_ip_addresses(
address, self._local_ip_address, local_port, remote_port
local_ip_address, remote_ip_address = self._get_ip_addresses(
remote_address=address,
local_ip_address=self._local_ip_address,
local_port=local_port,
)

# Re-register socket with new socket id
Expand Down Expand Up @@ -319,6 +330,7 @@ def accept(self) -> tuple[Socket, tuple[str, int]]:

return socket, (str(socket.remote_ip_address), socket.remote_port)

@override
def send(self, data: bytes) -> int:
"""
Send the data to connected remote host.
Expand All @@ -344,6 +356,7 @@ def send(self, data: bytes) -> int:
)
return bytes_sent

@override
def recv(
self, bufsize: int | None = None, timeout: float | None = None
) -> bytes:
Expand All @@ -369,17 +382,22 @@ def recv(

return data_rx

@override
def close(self) -> None:
"""
Close socket and the TCP session(s) it owns.
"""

assert self._tcp_session is not None

self._tcp_session.close()

__debug__ and log("socket", f"<g>[{self}]</> - Closed socket")

def process_tcp_packet(self, packet_rx_md: TcpMetadata) -> None:
"""
Process incoming packet's metadata.
"""

if self._tcp_session:
self._tcp_session.tcp_fsm(packet_rx_md)
32 changes: 25 additions & 7 deletions pytcp/protocols/udp/udp__socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from __future__ import annotations

import threading
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, override

from pytcp.lib import stack
from pytcp.lib.logger import log
Expand Down Expand Up @@ -98,6 +98,7 @@ def __init__(self, family: AddressFamily) -> None:

__debug__ and log("socket", f"<g>[{self}]</> - Created socket")

@override
def bind(self, address: tuple[str, int]) -> None:
"""
Bind the socket to local address.
Expand Down Expand Up @@ -154,7 +155,10 @@ def bind(self, address: tuple[str, int]) -> None:

# Confirm or pick local port number
if (local_port := address[1]) > 0:
if self._is_address_in_use(local_ip_address, local_port):
if self._is_address_in_use(
local_ip_address=local_ip_address,
local_port=local_port,
):
raise OSError(
"[Errno 98] Address already in use - "
"[Local address already in use]"
Expand All @@ -170,6 +174,7 @@ def bind(self, address: tuple[str, int]) -> None:

__debug__ and log("socket", f"<g>[{self}]</> - Bound")

@override
def connect(self, address: tuple[str, int]) -> None:
"""
Connect local socket to remote socket.
Expand All @@ -191,8 +196,10 @@ def connect(self, address: tuple[str, int]) -> None:
local_port = self._pick_local_port()

# Set local and remote ip addresses aproprietely
local_ip_address, remote_ip_address = self._set_ip_addresses(
address, self._local_ip_address, local_port, remote_port
local_ip_address, remote_ip_address = self._get_ip_addresses(
remote_address=address,
local_ip_address=self._local_ip_address,
local_port=local_port,
)

# Re-register socket with new socket id
Expand All @@ -205,6 +212,7 @@ def connect(self, address: tuple[str, int]) -> None:

__debug__ and log("socket", f"<g>[{self}]</> - Connected socket")

@override
def send(self, data: bytes) -> int:
"""
Send the data to connected remote host.
Expand Down Expand Up @@ -267,8 +275,10 @@ def sendto(self, data: bytes, address: tuple[str, int]) -> int:
stack.sockets[str(self)] = self

# Set local and remote ip addresses aproprietely
local_ip_address, remote_ip_address = self._set_ip_addresses(
address, self._local_ip_address, self._local_port, remote_port
local_ip_address, remote_ip_address = self._get_ip_addresses(
remote_address=address,
local_ip_address=self._local_ip_address,
local_port=self._local_port,
)

tx_status = stack.packet_handler.send_udp_packet(
Expand All @@ -292,10 +302,13 @@ def sendto(self, data: bytes, address: tuple[str, int]) -> int:

return sent_data_len

@override
def recv(
self, bufsize: int | None = None, timeout: float | None = None
) -> bytes:
"""Read data from socket"""
"""
Read data from socket.
"""

# TODO - Implement support for buffsize

Expand Down Expand Up @@ -338,22 +351,27 @@ def recvfrom(
)
raise ReceiveTimeout

@override
def close(self) -> None:
"""
Close socket.
"""

stack.sockets.pop(str(self), None)

__debug__ and log("socket", f"<g>[{self}]</> - Closed socket")

def process_udp_packet(self, packet_rx_md: UdpMetadata) -> None:
"""
Process incoming packet's metadata.
"""

self._packet_rx_md.append(packet_rx_md)
self._packet_rx_md_ready.release()

def notify_unreachable(self) -> None:
"""
Set the unreachable notification.
"""

self._unreachable = True

0 comments on commit 6f0cb69

Please sign in to comment.