diff --git a/pytcp/lib/socket.py b/pytcp/lib/socket.py index d7a72d6d..bb7f413e 100755 --- a/pytcp/lib/socket.py +++ b/pytcp/lib/socket.py @@ -234,6 +234,7 @@ def _pick_local_port(self) -> int: def _is_address_in_use( self, + *, local_ip_address: IpAddress, local_port: int, ) -> bool: @@ -258,12 +259,12 @@ def _is_address_in_use( return False - def _set_ip_addresses( + def _get_ip_addresses( self, + *, remote_address: tuple[str, int], local_ip_address: IpAddress, local_port: int, - remote_port: int, ) -> tuple[Ip6Address | Ip4Address, Ip6Address | Ip4Address]: """ Validate the remote address and pick appropriate local IP @@ -296,7 +297,7 @@ def _set_ip_addresses( if local_ip_address.is_unspecified: local_ip_address = pick_local_ip_address(remote_ip_address) if local_ip_address.is_unspecified and not ( - local_port == 68 and remote_port == 67 + local_port == 68 and remote_address[1] == 67 ): raise gaierror( "[Errno -2] Name or service not known - " @@ -316,6 +317,8 @@ def bind( The 'bind()' socket API method placeholder. """ + raise NotImplementedError + @abstractmethod def connect( self, @@ -325,6 +328,8 @@ def connect( The 'connect()' socket API method placeholder. """ + raise NotImplementedError + @abstractmethod def send( self, @@ -334,6 +339,8 @@ def send( The 'send()' socket API method placeholder. """ + raise NotImplementedError + @abstractmethod def recv( self, @@ -344,12 +351,16 @@ def recv( The 'recv()' socket API method placeholder. """ + raise NotImplementedError + @abstractmethod def close(self) -> None: """ The 'close()' socket API placeholder. """ + raise NotImplementedError + if TYPE_CHECKING: def listen(self) -> None: diff --git a/pytcp/protocols/tcp/tcp__socket.py b/pytcp/protocols/tcp/tcp__socket.py index 5b704b22..8dd90d00 100755 --- a/pytcp/protocols/tcp/tcp__socket.py +++ b/pytcp/protocols/tcp/tcp__socket.py @@ -38,7 +38,7 @@ from __future__ import annotations import threading -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, override from pytcp.lib import stack from pytcp.lib.logger import log @@ -125,8 +125,10 @@ def state(self) -> FsmState: """ Return FSM state of associated TCP session. """ + if self.tcp_session is not None: return self.tcp_session.state + return FsmState.CLOSED @property @@ -134,6 +136,7 @@ def tcp_session(self) -> TcpSession | None: """ Getter for the '_tcp_session' attribute. """ + return self._tcp_session @property @@ -141,8 +144,10 @@ def parent_socket(self) -> Socket | None: """ Getter for the '_parent_socket' attribute. """ + return self._parent_socket + @override def bind(self, address: tuple[str, int]) -> None: """ Bind the socket to local address. @@ -199,7 +204,10 @@ def bind(self, address: tuple[str, int]) -> None: # Confirm or pick local port number if (local_port := address[1]) > 0: - if self._is_address_in_use(local_ip_address, local_port): + if self._is_address_in_use( + local_ip_address=local_ip_address, + local_port=local_port, + ): raise OSError( "[Errno 98] Address already in use - " "[Local address already in use]" @@ -215,6 +223,7 @@ def bind(self, address: tuple[str, int]) -> None: __debug__ and log("socket", f"[{self}] - Bound socket") + @override def connect(self, address: tuple[str, int]) -> None: """ Connect local socket to remote socket. @@ -236,8 +245,10 @@ def connect(self, address: tuple[str, int]) -> None: local_port = self._pick_local_port() # Set local and remote ip addresses aproprietely - local_ip_address, remote_ip_address = self._set_ip_addresses( - address, self._local_ip_address, local_port, remote_port + local_ip_address, remote_ip_address = self._get_ip_addresses( + remote_address=address, + local_ip_address=self._local_ip_address, + local_port=local_port, ) # Re-register socket with new socket id @@ -319,6 +330,7 @@ def accept(self) -> tuple[Socket, tuple[str, int]]: return socket, (str(socket.remote_ip_address), socket.remote_port) + @override def send(self, data: bytes) -> int: """ Send the data to connected remote host. @@ -344,6 +356,7 @@ def send(self, data: bytes) -> int: ) return bytes_sent + @override def recv( self, bufsize: int | None = None, timeout: float | None = None ) -> bytes: @@ -369,17 +382,22 @@ def recv( return data_rx + @override def close(self) -> None: """ Close socket and the TCP session(s) it owns. """ + assert self._tcp_session is not None + self._tcp_session.close() + __debug__ and log("socket", f"[{self}] - Closed socket") def process_tcp_packet(self, packet_rx_md: TcpMetadata) -> None: """ Process incoming packet's metadata. """ + if self._tcp_session: self._tcp_session.tcp_fsm(packet_rx_md) diff --git a/pytcp/protocols/udp/udp__socket.py b/pytcp/protocols/udp/udp__socket.py index 3b97509c..aa574c35 100755 --- a/pytcp/protocols/udp/udp__socket.py +++ b/pytcp/protocols/udp/udp__socket.py @@ -39,7 +39,7 @@ from __future__ import annotations import threading -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, override from pytcp.lib import stack from pytcp.lib.logger import log @@ -98,6 +98,7 @@ def __init__(self, family: AddressFamily) -> None: __debug__ and log("socket", f"[{self}] - Created socket") + @override def bind(self, address: tuple[str, int]) -> None: """ Bind the socket to local address. @@ -154,7 +155,10 @@ def bind(self, address: tuple[str, int]) -> None: # Confirm or pick local port number if (local_port := address[1]) > 0: - if self._is_address_in_use(local_ip_address, local_port): + if self._is_address_in_use( + local_ip_address=local_ip_address, + local_port=local_port, + ): raise OSError( "[Errno 98] Address already in use - " "[Local address already in use]" @@ -170,6 +174,7 @@ def bind(self, address: tuple[str, int]) -> None: __debug__ and log("socket", f"[{self}] - Bound") + @override def connect(self, address: tuple[str, int]) -> None: """ Connect local socket to remote socket. @@ -191,8 +196,10 @@ def connect(self, address: tuple[str, int]) -> None: local_port = self._pick_local_port() # Set local and remote ip addresses aproprietely - local_ip_address, remote_ip_address = self._set_ip_addresses( - address, self._local_ip_address, local_port, remote_port + local_ip_address, remote_ip_address = self._get_ip_addresses( + remote_address=address, + local_ip_address=self._local_ip_address, + local_port=local_port, ) # Re-register socket with new socket id @@ -205,6 +212,7 @@ def connect(self, address: tuple[str, int]) -> None: __debug__ and log("socket", f"[{self}] - Connected socket") + @override def send(self, data: bytes) -> int: """ Send the data to connected remote host. @@ -267,8 +275,10 @@ def sendto(self, data: bytes, address: tuple[str, int]) -> int: stack.sockets[str(self)] = self # Set local and remote ip addresses aproprietely - local_ip_address, remote_ip_address = self._set_ip_addresses( - address, self._local_ip_address, self._local_port, remote_port + local_ip_address, remote_ip_address = self._get_ip_addresses( + remote_address=address, + local_ip_address=self._local_ip_address, + local_port=self._local_port, ) tx_status = stack.packet_handler.send_udp_packet( @@ -292,10 +302,13 @@ def sendto(self, data: bytes, address: tuple[str, int]) -> int: return sent_data_len + @override def recv( self, bufsize: int | None = None, timeout: float | None = None ) -> bytes: - """Read data from socket""" + """ + Read data from socket. + """ # TODO - Implement support for buffsize @@ -338,17 +351,21 @@ def recvfrom( ) raise ReceiveTimeout + @override def close(self) -> None: """ Close socket. """ + stack.sockets.pop(str(self), None) + __debug__ and log("socket", f"[{self}] - Closed socket") def process_udp_packet(self, packet_rx_md: UdpMetadata) -> None: """ Process incoming packet's metadata. """ + self._packet_rx_md.append(packet_rx_md) self._packet_rx_md_ready.release() @@ -356,4 +373,5 @@ def notify_unreachable(self) -> None: """ Set the unreachable notification. """ + self._unreachable = True