From eb60a3b513e2c1e88634ff1ed2bb721e1ffe64cb Mon Sep 17 00:00:00 2001 From: Sebastian Majewski Date: Fri, 10 Nov 2023 22:38:47 -0600 Subject: [PATCH] Refactored ARP related code. --- .gitignore | 2 +- .vscode/settings.json | 5 +- pytcp/config.py | 19 +- pytcp/lib/errors.py | 79 ++++++ pytcp/lib/protocol_enum.py | 59 ++++ pytcp/protocols/arp/fpa.py | 72 +++-- pytcp/protocols/arp/fpp.py | 209 ++++++-------- pytcp/protocols/arp/phrx.py | 185 +++++++------ pytcp/protocols/arp/phtx.py | 14 +- pytcp/protocols/arp/ps.py | 174 +++++++++++- pytcp/protocols/ether/fpa.py | 69 +---- pytcp/protocols/ether/fpp.py | 133 ++++----- pytcp/protocols/ether/phrx.py | 22 +- pytcp/protocols/ether/ps.py | 110 +++++++- pytcp/protocols/icmp4/fpa.py | 223 +++++++-------- pytcp/protocols/icmp4/fpp.py | 265 +++++++++--------- pytcp/protocols/icmp4/phrx.py | 220 +++++++++------ pytcp/protocols/icmp4/phtx.py | 75 +++--- pytcp/protocols/icmp4/ps.py | 301 ++++++++++++++++++++- pytcp/protocols/icmp6/fpp.py | 257 +++++++++++------- pytcp/protocols/icmp6/phrx.py | 9 +- pytcp/protocols/ip4/fpa.py | 18 +- pytcp/protocols/ip4/fpp.py | 102 ++++--- pytcp/protocols/ip4/phrx.py | 9 +- pytcp/protocols/ip6/fpa.py | 4 +- pytcp/protocols/ip6/fpp.py | 61 +++-- pytcp/protocols/ip6/phrx.py | 11 +- pytcp/protocols/ip6_ext_frag/fpp.py | 46 ++-- pytcp/protocols/raw/fpa.py | 4 +- pytcp/protocols/tcp/fpp.py | 103 ++++--- pytcp/protocols/tcp/phrx.py | 9 +- pytcp/protocols/udp/fpp.py | 60 +++-- pytcp/protocols/udp/phrx.py | 20 +- pytcp/subsystems/arp_cache.py | 4 +- pytcp/subsystems/packet_handler.py | 31 ++- tests/integration/packet_flows_rx.py | 2 - tests/integration/packet_flows_rx_tx.py | 2 - tests/unit/mock_network.py | 2 - tests/unit/protocols__arp__fpa.py | 30 +-- tests/unit/protocols__arp__phtx.py | 6 +- tests/unit/protocols__ether__fpa.py | 18 +- tests/unit/protocols__icmp4__fpa.py | 345 +++++++++++------------- tests/unit/protocols__icmp4__phtx.py | 43 +-- tests/unit/protocols__ip4__fpa.py | 20 +- tests/unit/protocols__ip6__fpa.py | 4 +- 45 files changed, 2089 insertions(+), 1367 deletions(-) create mode 100755 pytcp/lib/errors.py create mode 100755 pytcp/lib/protocol_enum.py diff --git a/.gitignore b/.gitignore index 1b057405..92afa22f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ - __pycache__/ +venv/ diff --git a/.vscode/settings.json b/.vscode/settings.json index 49d4639f..176ca276 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,5 +2,8 @@ "python.analysis.extraPaths": [ "./pytcp" ], - "python.formatting.provider": "black" + "python.formatting.provider": "none", + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + } } \ No newline at end of file diff --git a/pytcp/config.py b/pytcp/config.py index 985ffb0f..2cd4ddeb 100755 --- a/pytcp/config.py +++ b/pytcp/config.py @@ -73,22 +73,6 @@ } LOG_DEBUG = False -# Packet integrity sanity check, if enabled it protects the protocol parsers -# from being exposed to malformed or malicious packets that could cause them -# to crash during packet parsing. It progessively check appropriate length -# fields and ensure they are set within sane boundaries. It also checks -# packet's actual header/options/data lengths against above values and default -# minimum/maximum lengths for given protocol. Also packet options (if any) are -# checked in similar fashion to ensure they will not exploit or crash parser. -PACKET_INTEGRITY_CHECK = True - -# Packet sanity check, if enabled it validates packet's fields to detect invalid -# values or invalid combinations of values. For example in TCP/UDP it drops -# packets with port set to 0, in TCP it drop packet with SYN and FIN flags set -# simultaneously, for ICMPv6 it provides very detailed check of messages -# integrity. -PACKET_SANITY_CHECK = True - # Drop IPv4 packets containing options - this seems to be widely adopted # security feature. Stack parses but doesn't support IPv4 options as they are # mostly useless anyway. @@ -144,3 +128,6 @@ # Native support for UDP Echo (used for packet flow unit testing only and should # always be disabled). UDP_ECHO_NATIVE_DISABLE = True + +# LRU cache size, used by packet parsers to cache parsed field values. +LRU_CACHE_SIZE = 16 diff --git a/pytcp/lib/errors.py b/pytcp/lib/errors.py new file mode 100755 index 00000000..358f2ac5 --- /dev/null +++ b/pytcp/lib/errors.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 + +############################################################################ +# # +# PyTCP - Python TCP/IP stack # +# Copyright (C) 2020-present Sebastian Majewski # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +# Author's email: ccie18643@gmail.com # +# Github repository: https://github.com/ccie18643/PyTCP # +# # +############################################################################ + + +""" +Module contains methods supporting errors. + +pytcp/lib/errors.py + +ver 2.7 +""" + + +from __future__ import annotations + + +class PyTcpError(Exception): + """ + Base class for all PyTCP exceptions. + """ + + ... + + +class UnsupportedCaseError(PyTcpError): + """ + Exception raised when the not supposed to be reached + 'match' case is being reached for whatever reason. + """ + + ... + + +class PacketValidationError(PyTcpError): + """ + Exception raised when packet validation fails. + """ + + ... + + +class PacketIntegrityError(PacketValidationError): + """ + Exception raised when integrity check fails. + """ + + def __init__(self, message: str): + super().__init__("[INTEGRITY]" + message) + + +class PacketSanityError(PacketValidationError): + """ + Exception raised when sanity check fails. + """ + + def __init__(self, message: str): + super().__init__("[SANITY]" + message) diff --git a/pytcp/lib/protocol_enum.py b/pytcp/lib/protocol_enum.py new file mode 100755 index 00000000..59507074 --- /dev/null +++ b/pytcp/lib/protocol_enum.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +############################################################################ +# # +# PyTCP - Python TCP/IP stack # +# Copyright (C) 2020-present Sebastian Majewski # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +# Author's email: ccie18643@gmail.com # +# Github repository: https://github.com/ccie18643/PyTCP # +# # +############################################################################ + + +""" +Module contains the ProtocolEnum class. + +pytcp/lib/protocol_enum.py + +ver 2.7 +""" + + +from __future__ import annotations + +from enum import Enum +from typing import Self + + +class ProtocolEnum(Enum): + def __int__(self) -> int: + return int(self.value) + + def __str__(self) -> str: + return str(self.value) + + @staticmethod + def _extract(frame: bytes) -> int: + raise NotImplementedError + + @classmethod + def from_frame(cls, /, frame: bytes) -> Self: + return cls(cls._extract(frame)) + + @classmethod + def sanity_check(cls, /, frame: bytes) -> bool: + return cls._extract(frame) in cls diff --git a/pytcp/protocols/arp/fpa.py b/pytcp/protocols/arp/fpa.py index 5cc0cef5..2423fdd5 100755 --- a/pytcp/protocols/arp/fpa.py +++ b/pytcp/protocols/arp/fpa.py @@ -41,17 +41,22 @@ from pytcp.lib.ip4_address import Ip4Address from pytcp.lib.mac_address import MacAddress from pytcp.lib.tracker import Tracker -from pytcp.protocols.arp.ps import ARP_HEADER_LEN, ARP_OP_REPLY, ARP_OP_REQUEST -from pytcp.protocols.ether.ps import ETHER_TYPE_ARP - - -class ArpAssembler: +from pytcp.protocols.arp.ps import ( + ARP_HEADER_LEN, + Arp, + ArpHardwareLength, + ArpHardwareType, + ArpOperation, + ArpProtocolLength, + ArpProtocolType, +) + + +class ArpAssembler(Arp): """ ARP packet assembler support class. """ - ether_type = ETHER_TYPE_ARP - def __init__( self, *, @@ -59,69 +64,54 @@ def __init__( spa: Ip4Address = Ip4Address(0), tha: MacAddress = MacAddress(0), tpa: Ip4Address = Ip4Address(0), - oper: int = ARP_OP_REQUEST, + oper: ArpOperation = ArpOperation.REQUEST, echo_tracker: Tracker | None = None, ) -> None: """ Class constructor. """ - assert oper in (ARP_OP_REQUEST, ARP_OP_REPLY), f"{oper=}" - self._tracker = Tracker(prefix="TX", echo_tracker=echo_tracker) - self._hrtype: int = 1 - self._prtype: int = 0x0800 - self._hrlen: int = 6 - self._prlen: int = 4 - self._oper: int = oper - self._sha: MacAddress = sha - self._spa: Ip4Address = spa - self._tha: MacAddress = tha - self._tpa: Ip4Address = tpa + self._hrtype = ArpHardwareType.ETHERNET + self._prtype = ArpProtocolType.IP4 + self._hrlen = ArpHardwareLength.ETHERNET + self._prlen = ArpProtocolLength.IP4 + self._oper = oper + self._sha = sha + self._spa = spa + self._tha = tha + self._tpa = tpa def __len__(self) -> int: """ Length of the packet. """ - return ARP_HEADER_LEN - def __str__(self) -> str: - """ - Packet log string. - """ - if self._oper == ARP_OP_REQUEST: - return ( - f"ARP request {self._spa} / {self._sha}" - f" > {self._tpa} / {self._tha}" - ) - if self._oper == ARP_OP_REPLY: - return ( - f"ARP reply {self._spa} / {self._sha}" - f" > {self._tpa} / {self._tha}" - ) - return f"ARP request unknown operation {self._oper}" + return ARP_HEADER_LEN @property def tracker(self) -> Tracker: """ Getter for the '_tracker' property. """ + return self._tracker - def assemble(self, frame: memoryview) -> None: + def assemble(self, /, frame: memoryview) -> None: """ Assemble packet into the raw form. """ + struct.pack_into( "!HH BBH 6s 4s 6s 4s", frame, 0, - self._hrtype, - self._prtype, - self._hrlen, - self._prlen, - self._oper, + int(self._hrtype), + int(self._prtype), + int(self._hrlen), + int(self._prlen), + int(self._oper), bytes(self._sha), bytes(self._spa), bytes(self._tha), diff --git a/pytcp/protocols/arp/fpp.py b/pytcp/protocols/arp/fpp.py index a572c225..46709b73 100755 --- a/pytcp/protocols/arp/fpp.py +++ b/pytcp/protocols/arp/fpp.py @@ -38,180 +38,137 @@ from __future__ import annotations -import struct from typing import TYPE_CHECKING -from pytcp import config +from pytcp.lib.errors import PacketIntegrityError, PacketSanityError from pytcp.lib.ip4_address import Ip4Address from pytcp.lib.mac_address import MacAddress -from pytcp.protocols.arp.ps import ARP_HEADER_LEN, ARP_OP_REPLY, ARP_OP_REQUEST +from pytcp.protocols.arp.ps import ( + ARP_HEADER_LEN, + Arp, + ArpHardwareLength, + ArpHardwareType, + ArpOperation, + ArpProtocolLength, + ArpProtocolType, +) if TYPE_CHECKING: from pytcp.lib.packet import PacketRx -class ArpParser: +class ArpIntegrityError(PacketIntegrityError): """ - ARP packet parser class. + Exception raised when ARP packet integrity check fails. """ - def __init__(self, packet_rx: PacketRx) -> None: - """ - Class constructor. - """ - - packet_rx.arp = self - - self._frame = packet_rx.frame + def __init__(self, message: str): + super().__init__("[ARP] " + message) - packet_rx.parse_failed = ( - self._packet_integrity_check() or self._packet_sanity_check() - ) - def __len__(self) -> int: - """ - Number of bytes remaining in the frame. - """ - return len(self._frame) +class ArpSanityError(PacketSanityError): + """ + Exception raised when ARP packet sanity check fails. + """ - def __str__(self) -> str: - """ - Packet log string. - """ - if self.oper == ARP_OP_REQUEST: - return ( - f"ARP request {self.spa} / {self.sha}" - f" > {self.tpa} / {self.tha}" - ) - if self.oper == ARP_OP_REPLY: - return ( - f"ARP reply {self.spa} / {self.sha}" - f" > {self.tpa} / {self.tha}" - ) - return f"ARP request unknown operation {self.oper}" + def __init__(self, message: str): + super().__init__("[ARP] " + message) - @property - def hrtype(self) -> int: - """ - Read the 'Hardware address type' field. - """ - if "_cache__hrtype" not in self.__dict__: - self._cache__hrtype: int = struct.unpack("!H", self._frame[0:2])[0] - return self._cache__hrtype - @property - def prtype(self) -> int: - """ - Read the 'Protocol address type' field. - """ - if "_cache__prtype" not in self.__dict__: - self._cache__prtype: int = struct.unpack("!H", self._frame[2:4])[0] - return self._cache__prtype +class ArpParser(Arp): + """ + ARP packet parser class. + """ - @property - def hrlen(self) -> int: + def __init__(self, /, packet_rx: PacketRx) -> None: """ - Read the 'Hardware address length' field. + Class constructor. """ - return self._frame[4] - @property - def prlen(self) -> int: - """ - Read the 'Protocol address length' field. - """ - return self._frame[5] + packet_rx.arp = self - @property - def oper(self) -> int: - """ - Read the 'Operation' field. - """ - if "_cache__oper" not in self.__dict__: - self._cache__oper: int = struct.unpack("!H", self._frame[6:8])[0] - return self._cache__oper + self._frame = packet_rx.frame - @property - def sha(self) -> MacAddress: - """ - Read the 'Sender hardware address' field. - """ - if "_cache__sha" not in self.__dict__: - self._cache__sha = MacAddress(self._frame[8:14]) - return self._cache__sha + self._packet_integrity_check() + self._packet_sanity_check() - @property - def spa(self) -> Ip4Address: - """ - Read the 'Sender protocol address' field. - """ - if "_cache__spa" not in self.__dict__: - self._cache__spa = Ip4Address(self._frame[14:18]) - return self._cache__spa + self._hrtype = ArpHardwareType.from_frame(self._frame) + self._prtype = ArpProtocolType.from_frame(self._frame) + self._hrlen = ArpHardwareLength.from_frame(self._frame) + self._prlen = ArpProtocolLength.from_frame(self._frame) + self._oper = ArpOperation.from_frame(self._frame) + self._sha = MacAddress(self._frame[8:14]) + self._spa = Ip4Address(self._frame[14:18]) + self._tha = MacAddress(self._frame[18:24]) + self._tpa = Ip4Address(self._frame[24:28]) - @property - def tha(self) -> MacAddress: + def __len__(self) -> int: """ - Read the 'Target hardware address' field. + Number of bytes remaining in the frame. """ - if "_cache__tha" not in self.__dict__: - self._cache__tha = MacAddress(self._frame[18:24]) - return self._cache__tha - @property - def tpa(self) -> Ip4Address: - """ - Read the 'Target protocol address' field. - """ - if "_cache__tpa" not in self.__dict__: - self._cache__tpa = Ip4Address(self._frame[24:28]) - return self._cache__tpa + return len(self._frame) @property def packet_copy(self) -> bytes: """ Read the whole packet. """ + if "_cache__packet_copy" not in self.__dict__: self._cache__packet_copy = bytes(self._frame[:ARP_HEADER_LEN]) + return self._cache__packet_copy - def _packet_integrity_check(self) -> str: + def _packet_integrity_check(self) -> None: """ Packet integrity check to be run on raw packet prior to parsing - to make sure parsing is safe + to make sure parsing is safe. """ - if not config.PACKET_INTEGRITY_CHECK: - return "" - if len(self) < ARP_HEADER_LEN: - return "ARP integrity - wrong packet length (I)" - - return "" + raise ArpIntegrityError( + "The minimum packet length must be " + f"'{ARP_HEADER_LEN}' bytes, got {len(self)} bytes." + ) - def _packet_sanity_check(self) -> str: + def _packet_sanity_check(self) -> None: """ - Packet sanity check to be run on parsed packet to make sure packet's - fields contain sane values + Packet sanity check to be run on parsed packet to make sure + packet's fields contain sane values. """ - if not config.PACKET_SANITY_CHECK: - return "" - - if self.hrtype != 1: - return "ARP sanity - 'arp_hrtype' must be 1" - - if self.prtype != 0x0800: - return "ARP sanity - 'arp_prtype' must be 0x0800" + if not ArpHardwareType.sanity_check(self._frame): + raise ArpSanityError( + "The 'arp_hrtype' field value must be one of " + f"{[hrtype.value for hrtype in ArpHardwareType]}, " + f"got '{self.hrtype}'." + ) - if self.hrlen != 6: - return "ARP sanity - 'arp_hrlen' must be 6" + if not ArpProtocolType.sanity_check(self._frame): + raise ArpSanityError( + "The 'arp_prtype' field value must be one of " + f"{[prtype.value for prtype in ArpHardwareType]}, " + f"got '{self.prtype}'." + ) - if self.prlen != 4: - return "ARP sanity - 'arp_prlen' must be 4" + if not ArpHardwareLength.sanity_check(self._frame): + raise ArpSanityError( + "The 'arp_hrlen' field value must be one of " + f"{[hrlen.value for hrlen in ArpHardwareLength]}, " + f"got '{self.hrlen}'." + ) - if self.oper not in {1, 2}: - return "ARP sanity - 'oper' must be [1-2]" + if not ArpProtocolLength.sanity_check(self._frame): + raise ArpSanityError( + "The 'arp_prlen' field value must be one of " + f"{[prlen.value for prlen in ArpProtocolLength]}, " + f"got '{self.prlen}'." + ) - return "" + if not ArpOperation.sanity_check(self._frame): + raise ArpSanityError( + "The 'oper' field value must be one of " + f"{[oper.value for oper in ArpOperation]}, " + f"got '{self.oper}'." + ) diff --git a/pytcp/protocols/arp/phrx.py b/pytcp/protocols/arp/phrx.py index 9ed0e8fc..e986f6a7 100755 --- a/pytcp/protocols/arp/phrx.py +++ b/pytcp/protocols/arp/phrx.py @@ -44,9 +44,10 @@ from pytcp import config from pytcp.lib import stack +from pytcp.lib.errors import PacketValidationError from pytcp.lib.logger import log from pytcp.protocols.arp.fpp import ArpParser -from pytcp.protocols.arp.ps import ARP_OP_REPLY, ARP_OP_REQUEST +from pytcp.protocols.arp.ps import ArpOperation if TYPE_CHECKING: from pytcp.lib.packet import PacketRx @@ -60,110 +61,124 @@ def _phrx_arp(self: PacketHandler, packet_rx: PacketRx) -> None: self.packet_stats_rx.arp__pre_parse += 1 - ArpParser(packet_rx) - - if packet_rx.parse_failed: + try: + ArpParser(packet_rx) + except PacketValidationError as error: self.packet_stats_rx.arp__failed_parse__drop += 1 __debug__ and log( "arp", - f"{packet_rx.tracker} - {packet_rx.parse_failed}", + f"{packet_rx.tracker} - {error}", ) return __debug__ and log("arp", f"{packet_rx.tracker} - {packet_rx.arp}") - if packet_rx.arp.oper == ARP_OP_REQUEST: - self.packet_stats_rx.arp__op_request += 1 - # Check if request contains our IP address in SPA field, - # this indicates IP address conflict - if packet_rx.arp.spa in self.ip4_unicast: - self.packet_stats_rx.arp__op_request__ip_conflict += 1 - __debug__ and log( - "arp", - f"{packet_rx.tracker} - IP ({packet_rx.arp.spa}) " - f"conflict detected with host at {packet_rx.arp.sha}", - ) - return + match packet_rx.arp.oper: + case ArpOperation.REQUEST: + _phrx_arp__request(self, packet_rx) + case ArpOperation.REPLY: + _phrx_arp__reply(self, packet_rx) - # Check if the request is for one of our IP addresses, - # if so the craft ARP reply packet and send it out - if packet_rx.arp.tpa in self.ip4_unicast: - self.packet_stats_rx.arp__op_request__tpa_stack__respond += 1 - self._phtx_arp( - ether_src=self.mac_unicast, - ether_dst=packet_rx.arp.sha, - arp_oper=ARP_OP_REPLY, - arp_sha=self.mac_unicast, - arp_spa=packet_rx.arp.tpa, - arp_tha=packet_rx.arp.sha, - arp_tpa=packet_rx.arp.spa, - echo_tracker=packet_rx.tracker, - ) - # Update ARP cache with the mapping learned from the received - # ARP request that was destined to this stack - if config.ARP_CACHE_UPDATE_FROM_DIRECT_REQUEST: - self.packet_stats_rx.arp__op_request__update_arp_cache += 1 - __debug__ and log( - "arp", - f"{packet_rx.tracker} - Adding/refreshing " - "ARP cache entry from direct request " - f"- {packet_rx.arp.spa} -> {packet_rx.arp.sha}", - ) - stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha) - return +def _phrx_arp__request(self: PacketHandler, packet_rx: PacketRx) -> None: + """ + Handle inbound ARP request packets. + """ - else: - # Drop packet if TPA does not match one of our IP addresses - self.packet_stats_rx.arp__op_request__tpa_unknown__drop += 1 - return + self.packet_stats_rx.arp__op_request += 1 + # Check if request contains our IP address in SPA field, + # this indicates IP address conflict + if packet_rx.arp.spa in self.ip4_unicast: + self.packet_stats_rx.arp__op_request__ip_conflict += 1 + __debug__ and log( + "arp", + f"{packet_rx.tracker} - IP ({packet_rx.arp.spa}) " + f"conflict detected with host at {packet_rx.arp.sha}", + ) + return - # Handle ARP reply - elif packet_rx.arp.oper == ARP_OP_REPLY: - self.packet_stats_rx.arp__op_reply += 1 - # Check for ARP reply that is response to our ARP probe, this indicates - # the IP address we trying to claim is in use - if packet_rx.ether.dst == self.mac_unicast: - if ( - packet_rx.arp.spa - in [_.address for _ in self.ip4_host_candidate] - and packet_rx.arp.tha == self.mac_unicast - and packet_rx.arp.tpa.is_unspecified - ): - self.packet_stats_rx.arp__op_reply__ip_conflict += 1 - __debug__ and log( - "arp", - f"{packet_rx.tracker} - ARP Probe detected " - f"conflict for IP {packet_rx.arp.spa} with host at " - f"{packet_rx.arp.sha}", - ) - stack.arp_probe_unicast_conflict.add(packet_rx.arp.spa) - return - - # Update ARP cache with mapping received as direct ARP reply - if packet_rx.ether.dst == self.mac_unicast: - self.packet_stats_rx.arp__op_reply__update_arp_cache += 1 + # Check if the request is for one of our IP addresses, + # if so the craft ARP reply packet and send it out + if packet_rx.arp.tpa in self.ip4_unicast: + self.packet_stats_rx.arp__op_request__tpa_stack__respond += 1 + self._phtx_arp( + ether_src=self.mac_unicast, + ether_dst=packet_rx.arp.sha, + arp_oper=ArpOperation.REPLY, + arp_sha=self.mac_unicast, + arp_spa=packet_rx.arp.tpa, + arp_tha=packet_rx.arp.sha, + arp_tpa=packet_rx.arp.spa, + echo_tracker=packet_rx.tracker, + ) + + # Update ARP cache with the mapping learned from the received + # ARP request that was destined to this stack + if config.ARP_CACHE_UPDATE_FROM_DIRECT_REQUEST: + self.packet_stats_rx.arp__op_request__update_arp_cache += 1 __debug__ and log( "arp", - f"{packet_rx.tracker} - Adding/refreshing ARP cache entry " - f"from direct reply - {packet_rx.arp.spa} " - f"-> {packet_rx.arp.sha}", + f"{packet_rx.tracker} - Adding/refreshing " + "ARP cache entry from direct request " + f"- {packet_rx.arp.spa} -> {packet_rx.arp.sha}", ) stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha) - return + return - # Update ARP cache with mapping received as gratuitous ARP reply + else: + # Drop packet if TPA does not match one of our IP addresses + self.packet_stats_rx.arp__op_request__tpa_unknown__drop += 1 + return + + +def _phrx_arp__reply(self: PacketHandler, packet_rx: PacketRx) -> None: + """ + Handle inbound ARP reply packets. + """ + + self.packet_stats_rx.arp__op_reply += 1 + # Check for ARP reply that is response to our ARP probe, this indicates + # the IP address we trying to claim is in use + if packet_rx.ether.dst == self.mac_unicast: if ( - packet_rx.ether.dst.is_broadcast - and packet_rx.arp.spa == packet_rx.arp.tpa - and config.ARP_CACHE_UPDATE_FROM_GRATUITIOUS_REPLY + packet_rx.arp.spa in [_.address for _ in self.ip4_host_candidate] + and packet_rx.arp.tha == self.mac_unicast + and packet_rx.arp.tpa.is_unspecified ): - self.packet_stats_rx.arp__op_reply__update_arp_cache_gratuitous += 1 + self.packet_stats_rx.arp__op_reply__ip_conflict += 1 __debug__ and log( "arp", - f"{packet_rx.tracker} - Adding/refreshing ARP cache entry " - f"from gratuitous reply - {packet_rx.arp.spa} " - f"-> {packet_rx.arp.sha}", + f"{packet_rx.tracker} - ARP Probe detected " + f"conflict for IP {packet_rx.arp.spa} with host at " + f"{packet_rx.arp.sha}", ) - stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha) + stack.arp_probe_unicast_conflict.add(packet_rx.arp.spa) return + + # Update ARP cache with mapping received as direct ARP reply + if packet_rx.ether.dst == self.mac_unicast: + self.packet_stats_rx.arp__op_reply__update_arp_cache += 1 + __debug__ and log( + "arp", + f"{packet_rx.tracker} - Adding/refreshing ARP cache entry " + f"from direct reply - {packet_rx.arp.spa} " + f"-> {packet_rx.arp.sha}", + ) + stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha) + return + + # Update ARP cache with mapping received as gratuitous ARP reply + if ( + packet_rx.ether.dst.is_broadcast + and packet_rx.arp.spa == packet_rx.arp.tpa + and config.ARP_CACHE_UPDATE_FROM_GRATUITIOUS_REPLY + ): + self.packet_stats_rx.arp__op_reply__update_arp_cache_gratuitous += 1 + __debug__ and log( + "arp", + f"{packet_rx.tracker} - Adding/refreshing ARP cache entry " + f"from gratuitous reply - {packet_rx.arp.spa} " + f"-> {packet_rx.arp.sha}", + ) + stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha) + return diff --git a/pytcp/protocols/arp/phtx.py b/pytcp/protocols/arp/phtx.py index a2dc435d..a472345f 100755 --- a/pytcp/protocols/arp/phtx.py +++ b/pytcp/protocols/arp/phtx.py @@ -45,7 +45,7 @@ from pytcp.lib.tracker import Tracker from pytcp.lib.tx_status import TxStatus from pytcp.protocols.arp.fpa import ArpAssembler -from pytcp.protocols.arp.ps import ARP_OP_REPLY, ARP_OP_REQUEST +from pytcp.protocols.arp.ps import ArpOperation if TYPE_CHECKING: from pytcp.lib.ip4_address import Ip4Address @@ -57,7 +57,7 @@ def _phtx_arp( *, ether_src: MacAddress, ether_dst: MacAddress, - arp_oper: int, + arp_oper: ArpOperation, arp_sha: MacAddress, arp_spa: Ip4Address, arp_tha: MacAddress, @@ -76,11 +76,11 @@ def _phtx_arp( self.packet_stats_tx.arp__no_proto_support__drop += 1 return TxStatus.DROPED__ARP__NO_PROTOCOL_SUPPORT - if arp_oper == ARP_OP_REQUEST: - self.packet_stats_tx.arp__op_request__send += 1 - - if arp_oper == ARP_OP_REPLY: - self.packet_stats_tx.arp__op_reply__send += 1 + match arp_oper: + case ArpOperation.REQUEST: + self.packet_stats_tx.arp__op_request__send += 1 + case ArpOperation.REPLY: + self.packet_stats_tx.arp__op_reply__send += 1 arp_packet_tx = ArpAssembler( oper=arp_oper, diff --git a/pytcp/protocols/arp/ps.py b/pytcp/protocols/arp/ps.py index c268cd3e..aca19ed2 100755 --- a/pytcp/protocols/arp/ps.py +++ b/pytcp/protocols/arp/ps.py @@ -33,7 +33,13 @@ """ -from __future__ import annotations +import struct +from abc import ABC, abstractmethod + +from pytcp.lib.ip4_address import Ip4Address +from pytcp.lib.mac_address import MacAddress +from pytcp.lib.protocol_enum import ProtocolEnum +from pytcp.protocols.ether.ps import EtherType # ARP packet header - IPv4 stack version only @@ -53,8 +59,168 @@ # | Target IP address | # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - ARP_HEADER_LEN = 28 -ARP_OP_REQUEST = 1 -ARP_OP_REPLY = 2 + +class ArpHardwareType(ProtocolEnum): + ETHERNET = 1 + + @staticmethod + def _extract(frame: bytes) -> int: + return int(struct.unpack("!H", frame[0:2])[0]) + + +class ArpProtocolType(ProtocolEnum): + IP4 = 0x0800 + + @staticmethod + def _extract(frame: bytes) -> int: + return int(struct.unpack("!H", frame[2:4])[0]) + + +class ArpHardwareLength(ProtocolEnum): + ETHERNET = 6 + + @staticmethod + def _extract(frame: bytes) -> int: + return int(frame[4]) + + +class ArpProtocolLength(ProtocolEnum): + IP4 = 4 + + @staticmethod + def _extract(frame: bytes) -> int: + return int(frame[5]) + + +class ArpOperation(ProtocolEnum): + REQUEST = 1 + REPLY = 2 + + @staticmethod + def _extract(frame: bytes) -> int: + return int(struct.unpack("!H", frame[6:8])[0]) + + +class Arp(ABC): + """ + Base class for ARP packet parser and assembler classes. + """ + + _ether_type = EtherType.ARP + + _hrtype: ArpHardwareType + _prtype: ArpProtocolType + _hrlen: ArpHardwareLength + _prlen: ArpProtocolLength + _oper: ArpOperation + _sha: MacAddress + _spa: Ip4Address + _tha: MacAddress + _tpa: Ip4Address + + def __str__(self) -> str: + """ + Packet log string. + """ + + match self._oper: + case ArpOperation.REQUEST: + return ( + f"ARP request {self._spa} / {self._sha}" + f" > {self.tpa} / {self.tha}" + ) + case ArpOperation.REPLY: + return ( + f"ARP reply {self._spa} / {self._sha}" + f" > {self._tpa} / {self._tha}" + ) + + @abstractmethod + def __len__(self) -> int: + """ + Length of the packet. + """ + + raise NotImplementedError + + @property + def ether_type(self) -> EtherType: + """ + Getter for the '_ether_type' property. + """ + + return self._ether_type + + @property + def hrtype(self) -> ArpHardwareType: + """ + Getter for the '_hrtype' property. + """ + + return self._hrtype + + @property + def prtype(self) -> ArpProtocolType: + """ + Getter for the '_prtype' property. + """ + + return self._prtype + + @property + def hrlen(self) -> ArpHardwareLength: + """ + Getter for the '_hrlen' property. + """ + + return self._hrlen + + @property + def prlen(self) -> ArpProtocolLength: + """ + Getter for the '_prlen' property. + """ + + return self._prlen + + @property + def oper(self) -> ArpOperation: + """ + Getter for the '_oper' property. + """ + + return self._oper + + @property + def sha(self) -> MacAddress: + """ + Getter for the '_sha' property. + """ + + return self._sha + + @property + def spa(self) -> Ip4Address: + """ + Getter for the '_spa' property. + """ + + return self._spa + + @property + def tha(self) -> MacAddress: + """ + Getter for the '_tha' property. + """ + + return self._tha + + @property + def tpa(self) -> Ip4Address: + """ + Getter for the '_tpa' property. + """ + + return self._tpa diff --git a/pytcp/protocols/ether/fpa.py b/pytcp/protocols/ether/fpa.py index 5490f08d..21f35251 100755 --- a/pytcp/protocols/ether/fpa.py +++ b/pytcp/protocols/ether/fpa.py @@ -32,21 +32,13 @@ ver 2.7 """ - from __future__ import annotations import struct from typing import TYPE_CHECKING from pytcp.lib.mac_address import MacAddress -from pytcp.protocols.ether.ps import ( - ETHER_HEADER_LEN, - ETHER_TYPE_ARP, - ETHER_TYPE_IP4, - ETHER_TYPE_IP6, - ETHER_TYPE_RAW, - ETHER_TYPE_TABLE, -) +from pytcp.protocols.ether.ps import ETHER_HEADER_LEN, Ethernet from pytcp.protocols.raw.fpa import RawAssembler if TYPE_CHECKING: @@ -56,7 +48,7 @@ from pytcp.protocols.ip6.fpa import Ip6Assembler -class EtherAssembler: +class EtherAssembler(Ethernet): """ Ethernet packet assembler support class. """ @@ -76,13 +68,6 @@ def __init__( Class constructor. """ - assert carried_packet.ether_type in { - ETHER_TYPE_ARP, - ETHER_TYPE_IP4, - ETHER_TYPE_IP6, - ETHER_TYPE_RAW, - }, f"{carried_packet.ether_type=}" - self._carried_packet: ( ArpAssembler | Ip4Assembler @@ -90,25 +75,17 @@ def __init__( | Ip6Assembler | RawAssembler ) = carried_packet - self._tracker: Tracker = self._carried_packet.tracker - self._dst: MacAddress = dst - self._src: MacAddress = src - self._type: int = self._carried_packet.ether_type + self._tracker = self._carried_packet.tracker + self._dst = dst + self._src = src + self._type = self._carried_packet.ether_type def __len__(self) -> int: """ Length of the packet. """ - return ETHER_HEADER_LEN + len(self._carried_packet) - def __str__(self) -> str: - """ - Packet log string. - """ - return ( - f"ETHER {self._src} > {self._dst}, 0x{self._type:0>4x} " - f"({ETHER_TYPE_TABLE.get(self._type, '???')}), plen {len(self)}" - ) + return ETHER_HEADER_LEN + len(self._carried_packet) @property def tracker(self) -> Tracker: @@ -117,44 +94,18 @@ def tracker(self) -> Tracker: """ return self._tracker - @property - def dst(self) -> MacAddress: - """ - Getter for the '_dst' attribute. - """ - return self._dst - - @dst.setter - def dst(self, mac_address: MacAddress) -> None: - """ - Setter for the '_dst' attribute. - """ - self._dst = mac_address - - @property - def src(self) -> MacAddress: - """ - Getter for the '_src' attribute. - """ - return self._src - - @src.setter - def src(self, mac_address: MacAddress) -> None: - """ - Setter for the '_src' attribute. - """ - self._src = mac_address - def assemble(self, frame: memoryview) -> None: """ Assemble packet into the raw form. """ + struct.pack_into( "! 6s 6s H", frame, 0, bytes(self._dst), bytes(self._src), - self._type, + int(self._type), ) + self._carried_packet.assemble(frame[ETHER_HEADER_LEN:]) diff --git a/pytcp/protocols/ether/fpp.py b/pytcp/protocols/ether/fpp.py index 0d68a32c..1950a044 100755 --- a/pytcp/protocols/ether/fpp.py +++ b/pytcp/protocols/ether/fpp.py @@ -34,25 +34,37 @@ ver 2.7 """ - from __future__ import annotations -import struct from typing import TYPE_CHECKING -from pytcp import config +from pytcp.lib.errors import PacketIntegrityError, PacketSanityError from pytcp.lib.mac_address import MacAddress -from pytcp.protocols.ether.ps import ( - ETHER_HEADER_LEN, - ETHER_TYPE_MIN, - ETHER_TYPE_TABLE, -) +from pytcp.protocols.ether.ps import ETHER_HEADER_LEN, Ethernet, EtherType if TYPE_CHECKING: from pytcp.lib.packet import PacketRx -class EtherParser: +class EtherIntegrityError(PacketIntegrityError): + """ + Exception raised when Ethernet packet integrity check fails. + """ + + def __init__(self, message: str): + super().__init__("[ETHER] " + message) + + +class EtherSanityError(PacketSanityError): + """ + Exception raised when Ethernet packet sanity check fails. + """ + + def __init__(self, message: str): + super().__init__("[ETHER] " + message) + + +class EtherParser(Ethernet): """ Ethernet packet parser class. """ @@ -66,109 +78,74 @@ def __init__(self, packet_rx: PacketRx) -> None: self._frame = packet_rx.frame - packet_rx.parse_failed = ( - self._packet_integrity_check() or self._packet_sanity_check() - ) + self._packet_integrity_check() + self._packet_sanity_check() - if not packet_rx.parse_failed: - packet_rx.frame = packet_rx.frame[ETHER_HEADER_LEN:] + packet_rx.frame = packet_rx.frame[ETHER_HEADER_LEN:] + + self._dst = MacAddress(self._frame[0:6]) + self._src = MacAddress(self._frame[6:12]) + self._type = EtherType.from_frame(self._frame) def __len__(self) -> int: """ - Number of bytes remaining in the frame. + Get number of bytes remaining in the frame. """ return len(self._frame) - def __str__(self) -> str: - """ - Packet log string. - """ - return ( - f"ETHER {self.src} > {self.dst}, 0x{self.type:0>4x} " - f"({ETHER_TYPE_TABLE.get(self.type, '???')})" - ) - - @property - def dst(self) -> MacAddress: - """ - Read the 'Destination MAC address' field. - """ - if "_cache__dst" not in self.__dict__: - self._cache__dst = MacAddress(self._frame[0:6]) - return self._cache__dst - - @property - def src(self) -> MacAddress: - """ - Read the 'Source MAC address' field. - """ - if "_cache__src" not in self.__dict__: - self._cache__src = MacAddress(self._frame[6:12]) - return self._cache__src - - @property - def type(self) -> int: - """ - Read the 'EtherType' field. - """ - if "_cache__type" not in self.__dict__: - self._cache__type: int = struct.unpack("!H", self._frame[12:14])[0] - return self._cache__type - @property def header_copy(self) -> bytes: """ - Return copy of packet header. + Get copy of packet header. """ - if "_cache__header_copy" not in self.__dict__: - self._cache__header_copy = bytes(self._frame[:ETHER_HEADER_LEN]) - return self._cache__header_copy + + return bytes(self._frame[:ETHER_HEADER_LEN]) @property def data_copy(self) -> bytes: """ - Return copy of packet data. + Get copy of packet data. """ - if "_cache__data_copy" not in self.__dict__: - self._cache__data_copy = bytes(self._frame[ETHER_HEADER_LEN:]) - return self._cache__data_copy + + return bytes(self._frame[ETHER_HEADER_LEN:]) @property def packet_copy(self) -> bytes: """ - Return copy of whole packet. + Get copy of whole packet. """ - if "_cache__packet_copy" not in self.__dict__: - self._cache__packet_copy = bytes(self._frame[:]) - return self._cache__packet_copy + + return bytes(self._frame[:]) @property def plen(self) -> int: """ - Calculate packet length. + Get packet length. """ - if "_cache__plen" not in self.__dict__: - self._cache__plen = len(self) - return self._cache__plen - def _packet_integrity_check(self) -> str: + return len(self) + + def _packet_integrity_check(self) -> None: """ Packet integrity check to be run on raw packet prior to parsing to make sure parsing is safe. """ - if not config.PACKET_INTEGRITY_CHECK: - return "" + if len(self) < ETHER_HEADER_LEN: - return "ETHER integrity - wrong packet length (I)" - return "" + raise EtherIntegrityError( + "The minimum packet length must be " + f"'{ETHER_HEADER_LEN}' bytes, got {len(self)} bytes." + ) - def _packet_sanity_check(self) -> str: + def _packet_sanity_check(self) -> None: """ Packet sanity check to be run on parsed packet to make sure packet's fields contain sane values. """ - if not config.PACKET_SANITY_CHECK: - return "" - if self.type < ETHER_TYPE_MIN: - return "ETHER sanity - 'ether_type' must be greater than 0x0600" - return "" + + if not EtherType.sanity_check(self._frame): + raise EtherSanityError( + "The 'type' field value must be one of " + f"{[type.value for type in EtherType]}, " + f"got '{self.type}'." + ) diff --git a/pytcp/protocols/ether/phrx.py b/pytcp/protocols/ether/phrx.py index 6500dc1e..269a7e3f 100755 --- a/pytcp/protocols/ether/phrx.py +++ b/pytcp/protocols/ether/phrx.py @@ -34,19 +34,15 @@ ver 2.7 """ - from __future__ import annotations from typing import TYPE_CHECKING from pytcp import config +from pytcp.lib.errors import PacketValidationError from pytcp.lib.logger import log from pytcp.protocols.ether.fpp import EtherParser -from pytcp.protocols.ether.ps import ( - ETHER_TYPE_ARP, - ETHER_TYPE_IP4, - ETHER_TYPE_IP6, -) +from pytcp.protocols.ether.ps import EtherType if TYPE_CHECKING: from pytcp.lib.packet import PacketRx @@ -60,13 +56,13 @@ def _phrx_ether(self: PacketHandler, packet_rx: PacketRx) -> None: self.packet_stats_rx.ether__pre_parse += 1 - EtherParser(packet_rx) - - if packet_rx.parse_failed: + try: + EtherParser(packet_rx) + except PacketValidationError as error: self.packet_stats_rx.ether__failed_parse__drop += 1 __debug__ and log( "ether", - f"{packet_rx.tracker} - {packet_rx.parse_failed}", + f"{packet_rx.tracker} - {error}", ) return @@ -95,14 +91,14 @@ def _phrx_ether(self: PacketHandler, packet_rx: PacketRx) -> None: if packet_rx.ether.dst == self.mac_broadcast: self.packet_stats_rx.ether__dst_broadcast += 1 - if packet_rx.ether.type == ETHER_TYPE_ARP and config.IP4_SUPPORT: + if packet_rx.ether.type == EtherType.ARP and config.IP4_SUPPORT: self._phrx_arp(packet_rx) return - if packet_rx.ether.type == ETHER_TYPE_IP4 and config.IP4_SUPPORT: + if packet_rx.ether.type == EtherType.IP4 and config.IP4_SUPPORT: self._phrx_ip4(packet_rx) return - if packet_rx.ether.type == ETHER_TYPE_IP6 and config.IP6_SUPPORT: + if packet_rx.ether.type == EtherType.IP6 and config.IP6_SUPPORT: self._phrx_ip6(packet_rx) return diff --git a/pytcp/protocols/ether/ps.py b/pytcp/protocols/ether/ps.py index 307f0d4b..27dc283d 100755 --- a/pytcp/protocols/ether/ps.py +++ b/pytcp/protocols/ether/ps.py @@ -32,8 +32,11 @@ ver 2.7 """ +import struct +from abc import ABC, abstractmethod -from __future__ import annotations +from pytcp.lib.mac_address import MacAddress +from pytcp.lib.protocol_enum import ProtocolEnum # Ethernet packet header @@ -50,15 +53,96 @@ ETHER_HEADER_LEN = 14 -ETHER_TYPE_MIN = 0x0600 -ETHER_TYPE_ARP = 0x0806 -ETHER_TYPE_IP4 = 0x0800 -ETHER_TYPE_IP6 = 0x86DD -ETHER_TYPE_RAW = 0xFFFF - -ETHER_TYPE_TABLE = { - ETHER_TYPE_ARP: "ARP", - ETHER_TYPE_IP4: "IPv4", - ETHER_TYPE_IP6: "IPv6", - ETHER_TYPE_RAW: "raw_data", -} + +class EtherType(ProtocolEnum): + ARP = 0x0806 + IP4 = 0x0800 + IP6 = 0x86DD + RAW = 0xFFFF + + @staticmethod + def _extract(frame: bytes) -> int: + return int(struct.unpack("!H", frame[12:14])[0]) + + def __str__(self) -> str: + """ + Get string representation of this enum. + """ + + match self: + case EtherType.ARP: + return "ARP" + case EtherType.IP4: + return "IPv4" + case EtherType.IP6: + return "IPv6" + case EtherType.RAW: + return "raw_data" + + +class Ethernet(ABC): + """ + Base class for ARP packet parser and assembler classes. + """ + + _dst: MacAddress + _src: MacAddress + _type: EtherType + + @abstractmethod + def __len__(self) -> int: + """ + Length of the packet. + """ + + raise NotImplementedError + + def __str__(self) -> str: + """ + Packet log string. + """ + + return ( + f"ETHER {self._src} > {self._dst}, 0x{int(self._type):0>4x} " + f"({self._type}), plen {len(self)}" + ) + + @property + def dst(self) -> MacAddress: + """ + Getter for '_dst' property. + """ + + return self._dst + + @dst.setter + def dst(self, mac_address: MacAddress) -> None: + """ + Setter for the '_dst' attribute. + """ + + self._dst = mac_address + + @property + def src(self) -> MacAddress: + """ + Getter for '_src' property. + """ + + return self._src + + @src.setter + def src(self, mac_address: MacAddress) -> None: + """ + Setter for the '_src' attribute. + """ + + self._src = mac_address + + @property + def type(self) -> EtherType: + """ + Getter for '_type' property. + """ + + return self._type diff --git a/pytcp/protocols/icmp4/fpa.py b/pytcp/protocols/icmp4/fpa.py index ade2cee7..369fd28c 100755 --- a/pytcp/protocols/icmp4/fpa.py +++ b/pytcp/protocols/icmp4/fpa.py @@ -41,18 +41,18 @@ from pytcp.lib.ip_helper import inet_cksum from pytcp.lib.tracker import Tracker from pytcp.protocols.icmp4.ps import ( - ICMP4_ECHO_REPLY, - ICMP4_ECHO_REPLY_LEN, - ICMP4_ECHO_REQUEST, - ICMP4_ECHO_REQUEST_LEN, - ICMP4_UNREACHABLE, - ICMP4_UNREACHABLE__PORT, - ICMP4_UNREACHABLE_LEN, + ICMP4_HEADER_LEN, + Icmp4, + Icmp4Code, + Icmp4EchoReplyMessage, + Icmp4EchoRequestMessage, + Icmp4Type, + Icmp4UnreachablePortMessage, ) from pytcp.protocols.ip4.ps import IP4_PROTO_ICMP4 -class Icmp4Assembler: +class Icmp4Assembler(Icmp4): """ ICMPv4 packet assembler support class. """ @@ -62,90 +62,36 @@ class Icmp4Assembler: def __init__( self, *, - type: int = 0, - code: int = 0, - ec_id: int | None = None, - ec_seq: int | None = None, - ec_data: bytes | None = None, - un_data: bytes | None = None, + type: Icmp4Type, + code: Icmp4Code, + message: Icmp4EchoReplyMessage + | Icmp4UnreachablePortMessage + | Icmp4EchoRequestMessage, echo_tracker: Tracker | None = None, ) -> None: - """Class constructor""" + """ + Class constructor. + """ self._tracker: Tracker = Tracker(prefix="TX", echo_tracker=echo_tracker) - assert ec_id is None or 0 <= ec_id <= 0xFFFF - assert ec_seq is None or 0 <= ec_seq <= 0xFFFF - - self._type: int = type - self._code: int = code - - if self._type == ICMP4_ECHO_REPLY and self._code == 0: - self._ec_id = 0 if ec_id is None else ec_id - self._ec_seq = 0 if ec_seq is None else ec_seq - self._ec_data = b"" if ec_data is None else ec_data - return - - if ( - self._type == ICMP4_UNREACHABLE - and self._code == ICMP4_UNREACHABLE__PORT - ): - self._un_data = b"" if un_data is None else un_data[:520] - return - - if self._type == ICMP4_ECHO_REQUEST and self._code == 0: - self._ec_id = 0 if ec_id is None else ec_id - self._ec_seq = 0 if ec_seq is None else ec_seq - self._ec_data = b"" if ec_data is None else ec_data - return - - assert False, "Unknown ICMPv4 Type/Code" + self._type = type + self._code = code + self._message = message def __len__(self) -> int: """ Length of the packet. """ - if self._type == ICMP4_ECHO_REPLY: - return ICMP4_ECHO_REPLY_LEN + len(self._ec_data) - - if ( - self._type == ICMP4_UNREACHABLE - and self._code == ICMP4_UNREACHABLE__PORT - ): - return ICMP4_UNREACHABLE_LEN + len(self._un_data) - - if self._type == ICMP4_ECHO_REQUEST: - return ICMP4_ECHO_REQUEST_LEN + len(self._ec_data) - - assert False, "Unknown ICMPv4 Type/Code" + return ICMP4_HEADER_LEN + len(self._message) def __str__(self) -> str: """ Packet log string. """ - header = f"ICMPv4 {self._type}/{self._code}" - - if self._type == ICMP4_ECHO_REPLY and self._code == 0: - return ( - f"{header} (echo_reply), id {self._ec_id}, " - f"seq {self._ec_seq}, dlen {len(self._ec_data)}" - ) - - if ( - self._type == ICMP4_UNREACHABLE - and self._code == ICMP4_UNREACHABLE__PORT - ): - return f"{header} (unreachable_port), dlen {len(self._un_data)}" - - if self._type == ICMP4_ECHO_REQUEST and self._code == 0: - return ( - f"{header} (echo_request), id {self._ec_id}, " - f"seq {self._ec_seq}, dlen {len(self._ec_data)}" - ) - - assert False, "Unknown ICMPv4 Type/Code" + return f"ICMPv4 {self._type}/{self._code} {self._message}" @property def tracker(self) -> Tracker: @@ -159,51 +105,82 @@ def assemble(self, frame: memoryview, _: int = 0) -> None: Assemble packet into the raw form. """ - if self._type == ICMP4_ECHO_REPLY and self._code == 0: - struct.pack_into( - f"! BBH HH {len(self._ec_data)}s", - frame, - 0, - self._type, - self._code, - 0, - self._ec_id, - self._ec_seq, - bytes(self._ec_data), - ) - struct.pack_into("! H", frame, 2, inet_cksum(frame)) - return - - if ( - self._type == ICMP4_UNREACHABLE - and self._code == ICMP4_UNREACHABLE__PORT - ): - struct.pack_into( - f"! BBH L {len(self._un_data)}s", - frame, - 0, - self._type, - self._code, - 0, - 0, - bytes(self._un_data), - ) - struct.pack_into("! H", frame, 2, inet_cksum(frame)) - return - - if self._type == ICMP4_ECHO_REQUEST and self._code == 0: - struct.pack_into( - f"! BBH HH {len(self._ec_data)}s", - frame, - 0, - self._type, - self._code, - 0, - self._ec_id, - self._ec_seq, - bytes(self._ec_data), - ) - struct.pack_into("! H", frame, 2, inet_cksum(frame)) - return - - assert False, "Unknown ICMPv4 Type/Code" + struct.pack_into( + f"! BBH {len(self._message)}s", + frame, + 0, + int(self._type), + int(self._code), + 0, + bytes(self._message), + ) + struct.pack_into("! H", frame, 2, inet_cksum(frame)) + + +class Icmp4EchoReplyMessageAssembler(Icmp4EchoReplyMessage): + """ + Message assembler class for ICMPv4 Echo Reply packet. + """ + + def __init__( + self, + *, + id: int = 0, + seq: int = 0, + data: bytes = b"", + ) -> None: + """ + Class constructor. + """ + + assert 0 <= id <= 0xFFFF + assert 0 <= seq <= 0xFFFF + assert len(data) <= 65507 + + self._id = id + self._seq = seq + self._data = data + + +class Icmp4UnreachablePortMessageAssembler(Icmp4UnreachablePortMessage): + """ + Message assembler class for ICMPv4 Unreachable Port packet. + """ + + def __init__( + self, + *, + data: bytes = b"", + ) -> None: + """ + Class constructor. + """ + + assert len(data) <= 65507 + + self._data = data[:520] + + +class Icmp4EchoRequestMessageAssembler(Icmp4EchoRequestMessage): + """ + Message assembler class for ICMPv4 Echo Request packet. + """ + + def __init__( + self, + *, + id: int = 0, + seq: int = 0, + data: bytes = b"", + ) -> None: + """ + Class constructor. + """ + + assert 0 <= id <= 0xFFFF + assert 0 <= seq <= 0xFFFF + assert len(data) <= 65507 + + self._id = id + self._seq = seq + self._data = data diff --git a/pytcp/protocols/icmp4/fpp.py b/pytcp/protocols/icmp4/fpp.py index 6a904bd1..971fda3d 100755 --- a/pytcp/protocols/icmp4/fpp.py +++ b/pytcp/protocols/icmp4/fpp.py @@ -41,21 +41,43 @@ import struct from typing import TYPE_CHECKING -from pytcp import config +from pytcp.lib.errors import PacketIntegrityError, PacketSanityError from pytcp.lib.ip_helper import inet_cksum from pytcp.protocols.icmp4.ps import ( - ICMP4_ECHO_REPLY, - ICMP4_ECHO_REQUEST, ICMP4_HEADER_LEN, - ICMP4_UNREACHABLE, - ICMP4_UNREACHABLE__PORT, + Icmp4, + Icmp4EchoReplyCode, + Icmp4EchoReplyMessage, + Icmp4EchoRequestCode, + Icmp4EchoRequestMessage, + Icmp4Type, + Icmp4UnreachableCode, + Icmp4UnreachablePortMessage, ) if TYPE_CHECKING: from pytcp.lib.packet import PacketRx -class Icmp4Parser: +class Icmp4IntegrityError(PacketIntegrityError): + """ + Exception raised when ICMPv4 packet integrity check fails. + """ + + def __init__(self, message: str): + super().__init__("[ICMPv4] " + message) + + +class Icmp4SanityError(PacketSanityError): + """ + Exception raised when ICMPv4 packet sanity check fails. + """ + + def __init__(self, message: str): + super().__init__("[ICMPv4] " + message) + + +class Icmp4Parser(Icmp4): """ ICMPv4 packet parser class. """ @@ -65,179 +87,144 @@ def __init__(self, packet_rx: PacketRx) -> None: Class constructor. """ - assert packet_rx.ip4 is not None - packet_rx.icmp4 = self self._frame = packet_rx.frame self._plen = packet_rx.ip4.dlen - packet_rx.parse_failed = ( - self._packet_integrity_check() or self._packet_sanity_check() - ) + self._packet_integrity_check() + self._packet_sanity_check() + + self._type = Icmp4Type.from_frame(self._frame) + match self._type: + case Icmp4Type.ECHO_REPLY: + self._code = Icmp4EchoReplyCode.from_frame(self._frame) + self._message = Icmp4EchoReplyMessageParser(self._frame) + case Icmp4Type.UNREACHABLE: + self._code = Icmp4UnreachableCode.from_frame(self._frame) + match self._code: + case Icmp4UnreachableCode.PORT: + self._message = Icmp4UnreachablePortMessageParser( + self._frame + ) + case Icmp4Type.ECHO_REQUEST: + self._code = Icmp4EchoRequestCode.from_frame(self._frame) + self._message = Icmp4EchoRequestMessageParser(self._frame) + self._cksum: int = struct.unpack("!H", self._frame[2:4])[0] def __len__(self) -> int: """ Number of bytes remaining in the frame. """ - return len(self._frame) - - def __str__(self) -> str: - """ - Packet log string. - """ - header = f"ICMPv4 {self.type}/{self.code}" - - if self.type == ICMP4_ECHO_REPLY: - return f"{header} (echo_reply), id {self.ec_id}, seq {self.ec_seq}, dlen {len(self.ec_data)}" - - if ( - self.type == ICMP4_UNREACHABLE - and self.code == ICMP4_UNREACHABLE__PORT - ): - return f"{header} (unreachable_port), dlen {len(self.un_data)}" - - if self.type == ICMP4_ECHO_REQUEST: - return f"{header} (echo_request), id {self.ec_id}, seq {self.ec_seq}, dlen {len(self.ec_data)}" - return f"{header} (unknown)" + return len(self._frame) @property - def type(self) -> int: + def plen(self) -> int: """ - Read the 'Type' field. + Calculate packet length. """ - return self._frame[0] + return self._plen @property - def code(self) -> int: + def packet_copy(self) -> bytes: """ - Read the 'Code' field. + Read the whole packet. """ - return self._frame[1] - @property - def cksum(self) -> int: - """ - Read the 'Checksum' field. - """ - if "_cache__cksum" not in self.__dict__: - self._cache__cksum: int = struct.unpack("!H", self._frame[2:4])[0] - return self._cache__cksum + return bytes(self._frame[: self.plen]) - @property - def ec_id(self) -> int: + def _packet_integrity_check(self) -> None: """ - Read the Echo 'Id' field. + Packet integrity check to be run on raw packet prior to parsing + to make sure parsing is safe. """ - if "_cache__ec_id" not in self.__dict__: - assert self.type in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY} - self._cache__ec_id: int = struct.unpack("!H", self._frame[4:6])[0] - return self._cache__ec_id - @property - def ec_seq(self) -> int: - """ - Read the Echo 'Seq' field. - """ - if "_cache__ec_seq" not in self.__dict__: - assert self.type in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY} - self._cache__ec_seq: int = struct.unpack("!H", self._frame[6:8])[0] - return self._cache__ec_seq + if inet_cksum(self._frame[: self._plen]): + raise Icmp4IntegrityError( + "Wrong packet checksum.", + ) - @property - def ec_data(self) -> bytes: - """ - Read data carried by the Echo message. - """ - if "_cache__ec_data" not in self.__dict__: - assert self.type in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY} - self._cache__ec_data = self._frame[8 : self.plen] - return self._cache__ec_data + if not ICMP4_HEADER_LEN <= self._plen <= len(self): + raise Icmp4IntegrityError( + "Wrong packet length (I).", + ) - @property - def un_data(self) -> bytes: + def _packet_sanity_check(self) -> None: """ - Read the data carried by Uneachable message. + Packet sanity check to be run on parsed packet to make sure packets's + fields contain sane values. """ - if "_cache__un_data" not in self.__dict__: - assert self.type == ICMP4_UNREACHABLE - self._cache__un_data = self._frame[8 : self.plen] - return self._cache__un_data - @property - def plen(self) -> int: - """ - Calculate packet length. - """ - return self._plen + if not Icmp4Type.sanity_check(self._frame): + raise Icmp4SanityError( + "The 'type' field value must be one of " + f"{[type.value for type in Icmp4Type]}, " + f"got '{self.type}'." + ) + + match Icmp4Type.from_frame(self._frame): + case Icmp4Type.ECHO_REPLY: + if not Icmp4EchoReplyCode.sanity_check(self._frame): + raise Icmp4SanityError( + "The 'code' field value for Echo Reply message must be one of " + f"{[code.value for code in Icmp4EchoReplyCode]}, " + f"got '{self.code}'." + ) + case Icmp4Type.UNREACHABLE: + if not Icmp4UnreachableCode.sanity_check(self._frame): + raise Icmp4SanityError( + "The 'code' field value Unreachable message must be one of " + f"{[code.value for code in Icmp4UnreachableCode]}, " + f"got '{self.code}'." + ) + case Icmp4Type.ECHO_REQUEST: + if not Icmp4EchoRequestCode.sanity_check(self._frame): + raise Icmp4SanityError( + "The 'code' field value for Echo Request message must be one of " + f"{[code.value for code in Icmp4EchoRequestCode]}, " + f"got '{self.code}'." + ) + + +class Icmp4EchoReplyMessageParser(Icmp4EchoReplyMessage): + """ + Message parser class for ICMPv4 Echo Reply packet. + """ - @property - def packet_copy(self) -> bytes: + def __init__(self, frame: bytes) -> None: """ - Read the whole packet. + Class constructor. """ - if "_cache__packet_copy" not in self.__dict__: - self._cache__packet_copy = bytes(self._frame[: self.plen]) - return self._cache__packet_copy - def _packet_integrity_check(self) -> str: - """ - Packet integrity check to be run on raw packet prior to parsing - to make sure parsing is safe. - """ + self._id: int = struct.unpack("!H", frame[4:6])[0] + self._seq: int = struct.unpack("!H", frame[6:8])[0] + self._data: bytes = frame[8:] - if not config.PACKET_INTEGRITY_CHECK: - return "" - if inet_cksum(self._frame[: self._plen]): - return "ICMPv4 integrity - wrong packet checksum" +class Icmp4UnreachablePortMessageParser(Icmp4UnreachablePortMessage): + """ + Message parser class for ICMPv4 Unreachable Port packet. + """ - if not ICMP4_HEADER_LEN <= self._plen <= len(self): - return "ICMPv4 integrity - wrong packet length (I)" + def __init__(self, frame: bytes) -> None: + """ + Class constructor. + """ - if self._frame[0] in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY}: - if not 8 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" + self._data: bytes = frame[8:] - elif self._frame[0] == ICMP4_UNREACHABLE: - if not 12 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" - return "" +class Icmp4EchoRequestMessageParser(Icmp4EchoRequestMessage): + """ + Message parser class for ICMPv4 Echo Request packet. + """ - def _packet_sanity_check(self) -> str: + def __init__(self, frame: bytes) -> None: """ - Packet sanity check to be run on parsed packet to make sure packets's - fields contain sane values. + Class constructor. """ - if not config.PACKET_SANITY_CHECK: - return "" - - if self.type in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY}: - if not self.code == 0: - return "ICMPv4 sanity - 'code' should be set to 0 (RFC 792)" - - if self.type == ICMP4_UNREACHABLE: - if self.code not in { - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - }: - return "ICMPv4 sanity - 'code' must be set to [0-15] (RFC 792)" - - return "" + self._id: int = struct.unpack("!H", frame[4:6])[0] + self._seq: int = struct.unpack("!H", frame[6:8])[0] + self._data: bytes = frame[8:] diff --git a/pytcp/protocols/icmp4/phrx.py b/pytcp/protocols/icmp4/phrx.py index c05e6822..31ca9fdb 100755 --- a/pytcp/protocols/icmp4/phrx.py +++ b/pytcp/protocols/icmp4/phrx.py @@ -41,13 +41,20 @@ from typing import TYPE_CHECKING from pytcp.lib import stack +from pytcp.lib.errors import PacketValidationError from pytcp.lib.ip4_address import Ip4Address from pytcp.lib.logger import log -from pytcp.protocols.icmp4.fpp import Icmp4Parser +from pytcp.protocols.icmp4.fpa import Icmp4EchoReplyMessageAssembler +from pytcp.protocols.icmp4.fpp import ( + Icmp4EchoReplyMessage, + Icmp4EchoRequestMessage, + Icmp4Parser, + Icmp4UnreachablePortMessage, +) from pytcp.protocols.icmp4.ps import ( - ICMP4_ECHO_REPLY, - ICMP4_ECHO_REQUEST, - ICMP4_UNREACHABLE, + Icmp4EchoReplyCode, + Icmp4Type, + Icmp4UnreachableCode, ) from pytcp.protocols.ip4.ps import IP4_HEADER_LEN, IP4_PROTO_UDP from pytcp.protocols.udp.metadata import UdpMetadata @@ -59,96 +66,151 @@ def _phrx_icmp4(self: PacketHandler, packet_rx: PacketRx) -> None: - """Handle inbound ICMPv4 packets""" + """ + Handle inbound ICMPv4 packets. + """ self.packet_stats_rx.icmp4__pre_parse += 1 - Icmp4Parser(packet_rx) - - if packet_rx.parse_failed: + try: + Icmp4Parser(packet_rx) + except PacketValidationError as error: __debug__ and log( "icmp4", - f"{packet_rx.tracker} - {packet_rx.parse_failed}", + f"{packet_rx.tracker} - {error}", ) self.packet_stats_rx.icmp4__failed_parse__drop += 1 return __debug__ and log("icmp4", f"{packet_rx.tracker} - {packet_rx.icmp4}") - # ICMPv4 Echo Request packet - if packet_rx.icmp4.type == ICMP4_ECHO_REQUEST: - __debug__ and log( - "icmp4", - f"{packet_rx.tracker} - Received ICMPv4 Echo Request " - f"packet from {packet_rx.ip4.src}, sending reply", - ) - self.packet_stats_rx.icmp4__echo_request__respond_echo_reply += 1 - - self._phtx_icmp4( - ip4_src=packet_rx.ip4.dst, - ip4_dst=packet_rx.ip4.src, - icmp4_type=ICMP4_ECHO_REPLY, - icmp4_ec_id=packet_rx.icmp4.ec_id, - icmp4_ec_seq=packet_rx.icmp4.ec_seq, - icmp4_ec_data=packet_rx.icmp4.ec_data, - echo_tracker=packet_rx.tracker, + match packet_rx.icmp4.type: + case Icmp4Type.ECHO_REPLY: + _phrx_icmp4__echo_reply(self, packet_rx) + case Icmp4Type.ECHO_REQUEST: + _phrx_icmp4__echo_request(self, packet_rx) + case Icmp4Type.UNREACHABLE: + match packet_rx.icmp4.code: + case Icmp4UnreachableCode.PORT: + _phrx_icmp4__unreachable__port(self, packet_rx) + + +def _phrx_icmp4__echo_reply(self: PacketHandler, packet_rx: PacketRx) -> None: + """ + Handle inbound ICMPv4 Echo Reply packets. + """ + + assert isinstance(packet_rx.icmp4.message, Icmp4EchoReplyMessage) + + __debug__ and log( + "icmp4", + f"{packet_rx.tracker} - Received ICMPv4 Echo Request " + f"packet from {packet_rx.ip4.src}, sending reply", + ) + self.packet_stats_rx.icmp4__echo_request__respond_echo_reply += 1 + + self._phtx_icmp4( + ip4_src=packet_rx.ip4.dst, + ip4_dst=packet_rx.ip4.src, + icmp4_type=Icmp4Type.ECHO_REPLY, + icmp4_code=Icmp4EchoReplyCode.DEFAULT, + icmp4_message=Icmp4EchoReplyMessageAssembler( + id=packet_rx.icmp4.message.id, + seq=packet_rx.icmp4.message.seq, + data=packet_rx.icmp4.message.data, + ), + echo_tracker=packet_rx.tracker, + ) + + +def _phrx_icmp4__echo_request(self: PacketHandler, packet_rx: PacketRx) -> None: + """ + Handle inbound ICMPv4 Echo Reply packets. + """ + + assert isinstance(packet_rx.icmp4.message, Icmp4EchoRequestMessage) + + __debug__ and log( + "icmp4", + f"{packet_rx.tracker} - Received ICMPv4 Echo Request " + f"packet from {packet_rx.ip4.src}, sending reply", + ) + self.packet_stats_rx.icmp4__echo_request__respond_echo_reply += 1 + + self._phtx_icmp4( + ip4_src=packet_rx.ip4.dst, + ip4_dst=packet_rx.ip4.src, + icmp4_type=Icmp4Type.ECHO_REPLY, + icmp4_code=Icmp4EchoReplyCode.DEFAULT, + icmp4_message=Icmp4EchoReplyMessageAssembler( + id=packet_rx.icmp4.message.id, + seq=packet_rx.icmp4.message.seq, + data=packet_rx.icmp4.message.data, + ), + echo_tracker=packet_rx.tracker, + ) + + +def _phrx_icmp4__unreachable__port( + self: PacketHandler, packet_rx: PacketRx +) -> None: + """ + Handle inbound ICMPv4 Unreachable Port packets. + """ + + assert isinstance(packet_rx.icmp4.message, Icmp4UnreachablePortMessage) + + __debug__ and log( + "icmp4", + f"{packet_rx.tracker} - Received ICMPv4 Unreachable packet " + f"from {packet_rx.ip4.src}, will try to match UDP socket", + ) + self.packet_stats_rx.icmp4__unreachable += 1 + + # Quick and dirty way to validate received data and pull useful + # information from it. + frame = packet_rx.icmp4.message.data + if ( + len(frame) >= IP4_HEADER_LEN + and frame[0] >> 4 == 4 + and len(frame) >= ((frame[0] & 0b00001111) << 2) + and frame[9] == IP4_PROTO_UDP + and len(frame) >= ((frame[0] & 0b00001111) << 2) + UDP_HEADER_LEN + ): + # Create UdpMetadata object and try to find matching UDP socket. + udp_offset = (frame[0] & 0b00001111) << 2 + packet = UdpMetadata( + local_ip_address=Ip4Address(frame[12:16]), + remote_ip_address=Ip4Address(frame[16:20]), + local_port=struct.unpack( + "!H", frame[udp_offset + 0 : udp_offset + 2] + )[0], + remote_port=struct.unpack( + "!H", frame[udp_offset + 2 : udp_offset + 4] + )[0], ) - return - # ICMPv4 Unreachable packet - if packet_rx.icmp4.type == ICMP4_UNREACHABLE: - __debug__ and log( - "icmp4", - f"{packet_rx.tracker} - Received ICMPv4 Unreachable packet " - f"from {packet_rx.ip4.src}, will try to match UDP socket", - ) - self.packet_stats_rx.icmp4__unreachable += 1 - - # Quick and dirty way to validate received data and pull useful - # information from it - frame = packet_rx.icmp4.un_data - if ( - len(frame) >= IP4_HEADER_LEN - and frame[0] >> 4 == 4 - and len(frame) >= ((frame[0] & 0b00001111) << 2) - and frame[9] == IP4_PROTO_UDP - and len(frame) >= ((frame[0] & 0b00001111) << 2) + UDP_HEADER_LEN - ): - # Create UdpMetadata object and try to find matching UDP socket - udp_offset = (frame[0] & 0b00001111) << 2 - packet = UdpMetadata( - local_ip_address=Ip4Address(frame[12:16]), - remote_ip_address=Ip4Address(frame[16:20]), - local_port=struct.unpack( - "!H", frame[udp_offset + 0 : udp_offset + 2] - )[0], - remote_port=struct.unpack( - "!H", frame[udp_offset + 2 : udp_offset + 4] - )[0], - ) - - for socket_pattern in packet.socket_patterns: - socket = stack.sockets.get(socket_pattern, None) - if socket: - __debug__ and log( - "icmp4", - f"{packet_rx.tracker} - Found matching " - f"listening socket {socket}, for Unreachable " - f"packet from {packet_rx.ip4.src}", - ) - socket.notify_unreachable() - return - - __debug__ and log( - "icmp4", - f"{packet_rx.tracker} - Unreachable data doesn't match " - "any UDP socket", - ) - return + for socket_pattern in packet.socket_patterns: + socket = stack.sockets.get(socket_pattern, None) + if socket: + __debug__ and log( + "icmp4", + f"{packet_rx.tracker} - Found matching " + f"listening socket {socket}, for Unreachable " + f"packet from {packet_rx.ip4.src}", + ) + socket.notify_unreachable() + return __debug__ and log( "icmp4", - f"{packet_rx.tracker} - Unreachable data doesn't pass basic " - "IPv4/UDP integrity check", + f"{packet_rx.tracker} - Unreachable data doesn't match " + "any UDP socket", ) return + + __debug__ and log( + "icmp4", + f"{packet_rx.tracker} - Unreachable data doesn't pass basic " + "IPv4/UDP integrity check", + ) diff --git a/pytcp/protocols/icmp4/phtx.py b/pytcp/protocols/icmp4/phtx.py index 4c05d75e..f3a4cbcf 100755 --- a/pytcp/protocols/icmp4/phtx.py +++ b/pytcp/protocols/icmp4/phtx.py @@ -40,15 +40,20 @@ from typing import TYPE_CHECKING +from pytcp.lib.errors import UnsupportedCaseError from pytcp.lib.logger import log from pytcp.lib.tracker import Tracker from pytcp.lib.tx_status import TxStatus from pytcp.protocols.icmp4.fpa import Icmp4Assembler from pytcp.protocols.icmp4.ps import ( - ICMP4_ECHO_REPLY, - ICMP4_ECHO_REQUEST, - ICMP4_UNREACHABLE, - ICMP4_UNREACHABLE__PORT, + Icmp4Code, + Icmp4EchoReplyCode, + Icmp4EchoReplyMessage, + Icmp4EchoRequestCode, + Icmp4EchoRequestMessage, + Icmp4Type, + Icmp4UnreachableCode, + Icmp4UnreachablePortMessage, ) if TYPE_CHECKING: @@ -61,12 +66,11 @@ def _phtx_icmp4( *, ip4_src: Ip4Address, ip4_dst: Ip4Address, - icmp4_type: int, - icmp4_code: int = 0, - icmp4_ec_id: int | None = None, - icmp4_ec_seq: int | None = None, - icmp4_ec_data: bytes | None = None, - icmp4_un_data: bytes | None = None, + icmp4_type: Icmp4Type, + icmp4_code: Icmp4Code, + icmp4_message: Icmp4EchoReplyMessage + | Icmp4UnreachablePortMessage + | Icmp4EchoRequestMessage, echo_tracker: Tracker | None = None, ) -> TxStatus: """ @@ -78,37 +82,30 @@ def _phtx_icmp4( icmp4_packet_tx = Icmp4Assembler( type=icmp4_type, code=icmp4_code, - ec_id=icmp4_ec_id, - ec_seq=icmp4_ec_seq, - ec_data=icmp4_ec_data, - un_data=icmp4_un_data, + message=icmp4_message, echo_tracker=echo_tracker, ) __debug__ and log("icmp4", f"{icmp4_packet_tx.tracker} - {icmp4_packet_tx}") - if icmp4_type == ICMP4_ECHO_REPLY and icmp4_code == 0: - self.packet_stats_tx.icmp4__echo_reply__send += 1 - return self._phtx_ip4( - ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx - ) - - if icmp4_type == ICMP4_ECHO_REQUEST and icmp4_code == 0: - self.packet_stats_tx.icmp4__echo_request__send += 1 - return self._phtx_ip4( - ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx - ) - - if ( - icmp4_type == ICMP4_UNREACHABLE - and icmp4_code == ICMP4_UNREACHABLE__PORT - ): - self.packet_stats_tx.icmp4__unreachable_port__send += 1 - return self._phtx_ip4( - ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx - ) - - # This code will never be executed in debug mode due to assertions present - # in Packet Assembler - self.packet_stats_tx.icmp4__unknown__drop += 1 - return TxStatus.DROPED__ICMP4__UNKNOWN + match (icmp4_type, icmp4_code): + case (Icmp4Type.ECHO_REPLY, Icmp4EchoReplyCode.DEFAULT): + self.packet_stats_tx.icmp4__echo_reply__send += 1 + return self._phtx_ip4( + ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx + ) + + case (Icmp4Type.UNREACHABLE, Icmp4UnreachableCode.PORT): + self.packet_stats_tx.icmp4__unreachable_port__send += 1 + return self._phtx_ip4( + ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx + ) + + case (Icmp4Type.ECHO_REQUEST, Icmp4EchoRequestCode.DEFAULT): + self.packet_stats_tx.icmp4__echo_request__send += 1 + return self._phtx_ip4( + ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx + ) + + case _: + raise UnsupportedCaseError diff --git a/pytcp/protocols/icmp4/ps.py b/pytcp/protocols/icmp4/ps.py index afb721c7..6b8a475d 100755 --- a/pytcp/protocols/icmp4/ps.py +++ b/pytcp/protocols/icmp4/ps.py @@ -35,6 +35,11 @@ from __future__ import annotations +import struct +from abc import ABC, ABCMeta, abstractmethod + +from pytcp.lib.protocol_enum import ProtocolEnum + # Echo reply message (0/0) # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -79,16 +84,288 @@ # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ ICMP4_HEADER_LEN = 4 +ICMP4_ECHO_REPLY_MESSAGE_LEN = 4 +ICMP4_UNRECHABLE_MESSAGE_LEN = 4 +ICMP4_ECHO_REQUEST_MESSAGE_LEN = 4 + + +class Icmp4Type(ProtocolEnum): + ECHO_REPLY = 0 + UNREACHABLE = 3 + ECHO_REQUEST = 8 + + @staticmethod + def _extract(frame: bytes) -> int: + return int(frame[0]) + + +class Icmp4Code(ProtocolEnum): + @staticmethod + def _extract(frame: bytes) -> int: + return int(frame[1]) + + +class Icmp4EchoReplyCode(Icmp4Code): + DEFAULT = 0 + + +class Icmp4UnreachableCode(Icmp4Code): + PORT = 3 + + +class Icmp4EchoRequestCode(Icmp4Code): + DEFAULT = 0 + + +class Icmp4(ABC): + """ + Base class for ICMPv4 packet parser and assembler classes. + """ + + _type: Icmp4Type + _code: Icmp4Code + _cksum: int + _message: Icmp4EchoReplyMessage | Icmp4UnreachablePortMessage | Icmp4EchoRequestMessage + + @abstractmethod + def __len__(self) -> int: + """ + Length of the packet. + """ + + raise NotImplementedError + + def __str__(self) -> str: + """ + Packet log string. + """ + + return f"ICMPv4 {self._type}/{self._code} {self._message}" + + @property + def type(self) -> Icmp4Type: + """ + Getter for the '_type' property. + """ + + return self._type + + @property + def code(self) -> Icmp4Code: + """ + Getter for the '_code' property. + """ + + return self._code + + @property + def cksum(self) -> int: + """ + Getter for the '_cksum' property. + """ + + return self._cksum + + @property + def message( + self, + ) -> ( + Icmp4EchoReplyMessage + | Icmp4UnreachablePortMessage + | Icmp4EchoRequestMessage + ): + """ + Getter for the '_message' property. + """ + + return self._message + + +class Icmp4Message(ABC): + """ + Message base class for ICMPv4 packet. + """ + + def __init__(self) -> None: + """ + Ensure that the class is not instantiated. + """ + + raise NotImplementedError( + "The 'Icmp4[*]Message' classes are not instantiable. " + "Please us appropriate 'Icmp4[*]Message[Assembler|Parser]' " + "class instead." + ) + + @abstractmethod + def __str__(self) -> str: + """ + Packet log string. + """ + + raise NotImplementedError + + @abstractmethod + def __len__(self) -> int: + """ + Length of the ICMPv4 message. + """ + + raise NotImplementedError + + +class Icmp4EchoReplyMessage(Icmp4Message): + """ + Message base class for ICMPv4 Echo Reply packet. + """ + + _id: int + _seq: int + _data: bytes + + def __str__(self) -> str: + """ + Packet log string. + """ + + return ( + f"(echo reply), id {self._id}, seq {self._seq}, " + f"dlen {len(self._data)}" + ) + + def __bytes__(self) -> bytes: + """ + Get the message in raw form. + """ + + return struct.pack( + f"! HH {len(self._data)}s", self._id, self._seq, bytes(self._data) + ) + + def __len__(self) -> int: + """ + Length of the ICMPv4 Echo Reply message. + """ + + return ICMP4_ECHO_REPLY_MESSAGE_LEN + len(self._data) + + @property + def id(self) -> int: + """ + Getter for the '_id' property. + """ + + return self._id + + @property + def seq(self) -> int: + """ + Getter for the '_seq' property. + """ + + return self._seq + + @property + def data(self) -> bytes: + """ + Getter for the '_data' property. + """ + + return self._data + + +class Icmp4UnreachablePortMessage(Icmp4Message): + """ + Message base class for ICMPv4 Unreachable Port packet. + """ + + _data: bytes + + def __str__(self) -> str: + """ + Packet log string. + """ + + return f"(port unreachable), dlen {len(self._data)}" + + def __bytes__(self) -> bytes: + """ + Get the message in raw form. + """ + + return struct.pack(f"! L {len(self._data)}s", 0, bytes(self._data)) + + def __len__(self) -> int: + """ + Length of the ICMPv4 Unreachable Port message. + """ + + return ICMP4_UNRECHABLE_MESSAGE_LEN + len(self._data) + + @property + def data(self) -> bytes: + """ + Getter for the '_data' property. + """ + + return self._data + + +class Icmp4EchoRequestMessage(Icmp4Message): + """ + Message base class for ICMPv4 Echo Request packet. + """ + + _id: int + _seq: int + _data: bytes + + def __str__(self) -> str: + """ + Packet log string. + """ + + return ( + f"(echo request), id {self._id}, seq {self._seq}, " + f"dlen {len(self._data)}" + ) + + def __bytes__(self) -> bytes: + """ + Get the message in raw form. + """ + + return struct.pack( + f"! HH {len(self._data)}s", self._id, self._seq, bytes(self._data) + ) + + def __len__(self) -> int: + """ + Length of the ICMPv4 Echo Reply message. + """ + + return ICMP4_ECHO_REPLY_MESSAGE_LEN + len(self._data) + + @property + def id(self) -> int: + """ + Getter for the '_id' property. + """ + + return self._id + + @property + def seq(self) -> int: + """ + Getter for the '_seq' property. + """ + + return self._seq + + @property + def data(self) -> bytes: + """ + Getter for the '_data' property. + """ -ICMP4_ECHO_REPLY = 0 -ICMP4_ECHO_REPLY_LEN = 8 -ICMP4_UNREACHABLE = 3 -ICMP4_UNREACHABLE_LEN = 8 -ICMP4_UNREACHABLE__NET = 0 -ICMP4_UNREACHABLE__HOST = 1 -ICMP4_UNREACHABLE__PROTOCOL = 2 -ICMP4_UNREACHABLE__PORT = 3 -ICMP4_UNREACHABLE__FRAGMENTATION = 4 -ICMP4_UNREACHABLE__SOURCE_ROUTE_FAILED = 5 -ICMP4_ECHO_REQUEST = 8 -ICMP4_ECHO_REQUEST_LEN = 8 + return self._data diff --git a/pytcp/protocols/icmp6/fpp.py b/pytcp/protocols/icmp6/fpp.py index a23c7beb..38812abd 100755 --- a/pytcp/protocols/icmp6/fpp.py +++ b/pytcp/protocols/icmp6/fpp.py @@ -43,7 +43,7 @@ import struct from typing import TYPE_CHECKING -from pytcp import config +from pytcp.lib.errors import PacketIntegrityError, PacketSanityError from pytcp.lib.ip6_address import Ip6Address, Ip6Mask, Ip6Network from pytcp.lib.ip_helper import inet_cksum from pytcp.lib.mac_address import MacAddress @@ -71,6 +71,24 @@ from pytcp.lib.packet import PacketRx +class Icmp6IntegrityError(PacketIntegrityError): + """ + Exception raised when ICMPv6 packet integrity check fails. + """ + + def __init__(self, message: str): + super().__init__("[ICMPv6] " + message) + + +class Icmp6SanityError(PacketSanityError): + """ + Exception raised when ICMPv6 packet sanity check fails. + """ + + def __init__(self, message: str): + super().__init__("[ICMPv6] " + message) + + class Icmp6Parser: """ ICMPv6 packet parser class. @@ -88,10 +106,13 @@ def __init__(self, packet_rx: PacketRx) -> None: self._frame = packet_rx.frame self._plen = packet_rx.ip6.dlen - packet_rx.parse_failed = self._packet_integrity_check( - packet_rx.ip6.pshdr_sum - ) or self._packet_sanity_check( - packet_rx.ip6.src, packet_rx.ip6.dst, packet_rx.ip6.hop + self._packet_integrity_check( + pshdr_sum=packet_rx.ip6.pshdr_sum, + ) + self._packet_sanity_check( + ip6_src=packet_rx.ip6.src, + ip6_dst=packet_rx.ip6.dst, + ip6_hop=packet_rx.ip6.hop, ) def __len__(self) -> int: @@ -498,217 +519,271 @@ def packet_copy(self) -> bytes: self._cache__packet_copy = self._frame[: self.plen] return self._cache__packet_copy - def _nd_option_integrity_check(self, optr: int) -> str: + def _nd_option_integrity_check(self, optr: int) -> None: """ Check integrity of ICMPv6 ND options. """ + while optr < len(self._frame): if optr + 1 > len(self._frame): - return "ICMPv6 sanity check fail - wrong option length (I)" + raise Icmp6IntegrityError( + "Wrong option length (I)", + ) if self._frame[optr + 1] == 0: - return "ICMPv6 sanity check fail - wrong option length (II)" + raise Icmp6IntegrityError( + "Wrong option length (II)", + ) optr += self._frame[optr + 1] << 3 if optr > len(self._frame): - return "ICMPv6 sanity check fail - wrong option length (III)" - return "" + raise Icmp6IntegrityError( + "Wrong option length (III)", + ) - def _packet_integrity_check(self, pshdr_sum: int) -> str: + def _packet_integrity_check(self, *, pshdr_sum: int) -> None: """ Packet integrity check to be run on raw packet prior to parsing to make sure parsing is safe. """ - if not config.PACKET_INTEGRITY_CHECK: - return "" - if inet_cksum(self._frame[: self._plen], pshdr_sum): - return "ICMPv6 integrity - wrong packet checksum" + raise Icmp6IntegrityError( + "Wrong packet checksum.", + ) if not ICMP6_HEADER_LEN <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (I)" + raise Icmp6IntegrityError( + "Wrong packet length (I)", + ) if self._frame[0] == ICMP6_UNREACHABLE: if not 12 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" + raise Icmp6IntegrityError( + "Wrong packet length (II)", + ) elif self._frame[0] in {ICMP6_ECHO_REQUEST, ICMP6_ECHO_REPLY}: if not 8 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" + raise Icmp6IntegrityError( + "Wrong packet length (II)", + ) elif self._frame[0] == ICMP6_MLD2_QUERY: if not 28 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" + raise Icmp6IntegrityError( + "Wrong packet length (II)", + ) if ( self._plen != 28 + struct.unpack("!H", self._frame[26:28])[0] * 16 ): - return "ICMPv6 integrity - wrong packet length (III)" + raise Icmp6IntegrityError( + "Wrong packet length (III)", + ) elif self._frame[0] == ICMP6_ND_ROUTER_SOLICITATION: if not 8 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" - if fail := self._nd_option_integrity_check(8): - return fail + raise Icmp6IntegrityError( + "Wrong packet length (II)", + ) + self._nd_option_integrity_check(8) elif self._frame[0] == ICMP6_ND_ROUTER_ADVERTISEMENT: if not 16 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" - if fail := self._nd_option_integrity_check(16): - return fail + raise Icmp6IntegrityError( + "Wrong packet length (II)", + ) + self._nd_option_integrity_check(16) elif self._frame[0] == ICMP6_ND_NEIGHBOR_SOLICITATION: if not 24 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" - if fail := self._nd_option_integrity_check(24): - return fail + raise Icmp6IntegrityError( + "Wrong packet length (II)", + ) + self._nd_option_integrity_check(24) elif self._frame[0] == ICMP6_ND_NEIGHBOR_ADVERTISEMENT: if not 24 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" - if fail := self._nd_option_integrity_check(24): - return fail + raise Icmp6IntegrityError( + "Wrong packet length (II)", + ) + self._nd_option_integrity_check(24) elif self._frame[0] == ICMP6_MLD2_REPORT: if not 8 <= self._plen <= len(self): - return "ICMPv6 integrity - wrong packet length (II)" + raise Icmp6IntegrityError( + "Wrong packet length (II)", + ) optr = 8 for _ in range(struct.unpack("!H", self._frame[6:8])[0]): if optr + 20 > self._plen: - return "ICMPv6 integrity - wrong packet length (III)" + raise Icmp6IntegrityError( + "Wrong packet length (III)", + ) optr += ( 20 + self._frame[optr + 1] + struct.unpack_from("! H", self._frame, optr + 2)[0] * 16 ) if optr != self._plen: - return "ICMPv6 integrity - wrong packet length (IV)" - - return "" + raise Icmp6IntegrityError( + "Wrong packet length (IV)", + ) def _packet_sanity_check( - self, ip6_src: Ip6Address, ip6_dst: Ip6Address, ip6_hop: int - ) -> str: + self, *, ip6_src: Ip6Address, ip6_dst: Ip6Address, ip6_hop: int + ) -> None: """ Packet sanity check to be run on parsed packet to make sure packet's fields contain sane values. """ - if not config.PACKET_SANITY_CHECK: - return "" - if self.type == ICMP6_UNREACHABLE: if self.code not in {0, 1, 2, 3, 4, 5, 6}: - return "ICMPv6 sanity - 'code' must be [0-6] (RFC 4861)" + raise Icmp6SanityError( + "The 'code' must be [0-6] (RFC 4861)", + ) elif self.type == ICMP6_PACKET_TOO_BIG: if not self.code == 0: - return "ICMPv6 sanity - 'code' should be 0 (RFC 4861)" + raise Icmp6SanityError( + "The 'code' should be 0 (RFC 4861)", + ) elif self.type == ICMP6_TIME_EXCEEDED: if self.code not in {0, 1}: - return "ICMPv6 sanity - 'code' must be [0-1] (RFC 4861)" + raise Icmp6SanityError( + "The 'code' must be [0-1] (RFC 4861)", + ) elif self.type == ICMP6_PARAMETER_PROBLEM: if self.code not in {0, 1, 2}: - return "ICMPv6 sanity - 'code' must be [0-2] (RFC 4861)" + raise Icmp6SanityError( + "The 'code' must be [0-2] (RFC 4861)", + ) elif self.type in {ICMP6_ECHO_REQUEST, ICMP6_ECHO_REPLY}: if not self.code == 0: - return "ICMPv6 sanity - 'code' should be 0 (RFC 4861)" + raise Icmp6SanityError( + "The 'code' should be 0 (RFC 4861)", + ) elif self.type == ICMP6_MLD2_QUERY: if not self.code == 0: - return "ICMPv6 sanity - 'code' must be 0 (RFC 3810)" + raise Icmp6SanityError( + "The 'code' must be 0 (RFC 3810)", + ) if not ip6_hop == 1: - return "ICMPv6 sanity - 'hop' must be 255 (RFC 3810)" + raise Icmp6SanityError( + "The 'hop' must be 255 (RFC 3810)", + ) elif self.type == ICMP6_ND_ROUTER_SOLICITATION: if not self.code == 0: - return "ICMPv6 sanity - 'code' must be 0 (RFC 4861)" + raise Icmp6SanityError( + "The 'code' must be 0 (RFC 4861)", + ) if not ip6_hop == 255: - return "ICMPv6 sanity - 'hop' must be 255 (RFC 4861)" + raise Icmp6SanityError( + "The 'hop' must be 255 (RFC 4861)", + ) if not (ip6_src.is_unicast or ip6_src.is_unspecified): - return ( - "ICMPv6 sanity - 'src' must be unicast or unspecified " - "(RFC 4861)" + raise Icmp6SanityError( + "The 'src' must be unicast or unspecified (RFC 4861)", ) if not ip6_dst == Ip6Address("ff02::2"): - return "ICMPv6 sanity - 'dst' must be all-routers (RFC 4861)" + raise Icmp6SanityError( + "The 'dst' must be all-routers (RFC 4861)", + ) if ip6_src.is_unspecified and self.nd_opt_slla: - return ( - "ICMPv6 sanity - 'nd_opt_slla' must not be included if " - "'src' is unspecified (RFC 4861)" + raise Icmp6SanityError( + "The 'nd_opt_slla' must not be included if " + "'src' is unspecified (RFC 4861)", ) elif self.type == ICMP6_ND_ROUTER_ADVERTISEMENT: if not self.code == 0: - return "ICMPv6 sanity - 'code' must be 0 (RFC 4861)" + raise Icmp6SanityError( + "The 'code' must be 0 (RFC 4861)", + ) if not ip6_hop == 255: - return "ICMPv6 sanity - 'hop' must be 255 (RFC 4861)" + raise Icmp6SanityError( + "The 'hop' must be 255 (RFC 4861)", + ) if not ip6_src.is_link_local: - return "ICMPv6 sanity - 'src' must be link local (RFC 4861)" + raise Icmp6SanityError( + "The 'src' must be link local (RFC 4861)", + ) if not (ip6_dst.is_unicast or ip6_dst == Ip6Address("ff02::1")): - return ( - "ICMPv6 sanity - 'dst' must be unicast or all-nodes " - "(RFC 4861)" + raise Icmp6SanityError( + "The 'dst' must be unicast or all-nodes (RFC 4861)", ) elif self.type == ICMP6_ND_NEIGHBOR_SOLICITATION: if not self.code == 0: - return "ICMPv6 sanity - 'code' must be 0 (RFC 4861)" + raise Icmp6SanityError( + "The 'code' must be 0 (RFC 4861)", + ) if not ip6_hop == 255: - return "ICMPv6 sanity - 'hop' must be 255 (RFC 4861)" + raise Icmp6SanityError( + "The 'hop' must be 255 (RFC 4861)", + ) if not (ip6_src.is_unicast or ip6_src.is_unspecified): - return ( - "ICMPv6 sanity - 'src' must be unicast or unspecified " - "(RFC 4861)" + raise Icmp6SanityError( + "The 'src' must be unicast or unspecified (RFC 4861)", ) if ip6_dst not in { self.ns_target_address, self.ns_target_address.solicited_node_multicast, }: - return ( - "ICMPv6 sanity - 'dst' must be 'ns_target_address' or it's " - "solicited-node multicast (RFC 4861)" + raise Icmp6SanityError( + "The 'dst' must be 'ns_target_address' or it's " + "solicited-node multicast (RFC 4861)", ) if not self.ns_target_address.is_unicast: - return ( - "ICMPv6 sanity - 'ns_target_address' must be unicast " - "(RFC 4861)" + raise Icmp6SanityError( + "The 'ns_target_address' must be unicast (RFC 4861)", ) if ip6_src.is_unspecified and self.nd_opt_slla is not None: - return ( - "ICMPv6 sanity - 'nd_opt_slla' must not be included if " - "'src' is unspecified" + raise Icmp6SanityError( + "The 'nd_opt_slla' must not be included if " + "'src' is unspecified", ) elif self.type == ICMP6_ND_NEIGHBOR_ADVERTISEMENT: if not self.code == 0: - return "ICMPv6 sanity - 'code' must be 0 (RFC 4861)" + raise Icmp6SanityError( + "The 'code' must be 0 (RFC 4861)", + ) if not ip6_hop == 255: - return "ICMPv6 sanity - 'hop' must be 255 (RFC 4861)" + raise Icmp6SanityError( + "The 'hop' must be 255 (RFC 4861)", + ) if not ip6_src.is_unicast: - return "ICMPv6 sanity - 'src' must be unicast (RFC 4861)" + raise Icmp6SanityError( + "The 'src' must be unicast (RFC 4861)", + ) if self.na_flag_s is True and not ( ip6_dst.is_unicast or ip6_dst == Ip6Address("ff02::1") ): - return ( - "ICMPv6 sanity - if 'na_flag_s' is set then 'dst' must be " - "unicast or all-nodes (RFC 4861)" + raise Icmp6SanityError( + "If 'na_flag_s' is set then 'dst' must be " + "unicast or all-nodes (RFC 4861)", ) if self.na_flag_s is False and not ip6_dst == Ip6Address("ff02::1"): - return ( - "ICMPv6 sanity - if 'na_flag_s' is not set then 'dst' must " - "be all-nodes (RFC 4861)" + raise Icmp6SanityError( + "If 'na_flag_s' is not set then 'dst' must " + "be all-nodes (RFC 4861)", ) elif self.type == ICMP6_MLD2_REPORT: if not self.code == 0: - return "ICMPv6 sanity - 'code' must be 0 (RFC 3810)" + raise Icmp6SanityError( + "The 'code' must be 0 (RFC 3810)", + ) if not ip6_hop == 1: - return "ICMPv6 sanity - 'hop' must be 1 (RFC 3810)" - - return "" + raise Icmp6SanityError( + "The 'hop' must be 1 (RFC 3810)", + ) # diff --git a/pytcp/protocols/icmp6/phrx.py b/pytcp/protocols/icmp6/phrx.py index a26d58ea..2da671d2 100755 --- a/pytcp/protocols/icmp6/phrx.py +++ b/pytcp/protocols/icmp6/phrx.py @@ -45,6 +45,7 @@ from typing import TYPE_CHECKING from pytcp.lib import stack +from pytcp.lib.errors import PacketValidationError from pytcp.lib.ip6_address import Ip6Address from pytcp.lib.logger import log from pytcp.protocols.icmp6.fpa import Icmp6NdOptTLLA @@ -74,12 +75,12 @@ def _phrx_icmp6(self: PacketHandler, packet_rx: PacketRx) -> None: self.packet_stats_rx.icmp6__pre_parse += 1 - Icmp6Parser(packet_rx) - - if packet_rx.parse_failed: + try: + Icmp6Parser(packet_rx) + except PacketValidationError as error: __debug__ and log( "icmp6", - f"{packet_rx.tracker} - {packet_rx.parse_failed}", + f"{packet_rx.tracker} - {error}", ) self.packet_stats_rx.icmp6__failed_parse__drop += 1 return diff --git a/pytcp/protocols/ip4/fpa.py b/pytcp/protocols/ip4/fpa.py index 93f168ad..1ed29986 100755 --- a/pytcp/protocols/ip4/fpa.py +++ b/pytcp/protocols/ip4/fpa.py @@ -44,7 +44,7 @@ from pytcp.lib.ip4_address import Ip4Address from pytcp.lib.ip_helper import inet_cksum from pytcp.lib.tracker import Tracker -from pytcp.protocols.ether.ps import ETHER_TYPE_IP4 +from pytcp.protocols.ether.ps import EtherType from pytcp.protocols.ip4.ps import ( IP4_HEADER_LEN, IP4_OPT_EOL, @@ -70,7 +70,7 @@ class Ip4Assembler: IPv4 packet assembler support class. """ - ether_type = ETHER_TYPE_IP4 + ether_type = EtherType.IP4 def __init__( self, @@ -131,7 +131,7 @@ def __len__(self) -> int: return ( IP4_HEADER_LEN - + sum(len(_) for _ in self._options) + + sum(len(option) for option in self._options) + len(self._carried_packet) ) @@ -158,6 +158,7 @@ def tracker(self) -> Tracker: """ Getter for the '_tracker' attribute. """ + return self._tracker @property @@ -165,6 +166,7 @@ def dst(self) -> Ip4Address: """ Getter for the '_dst' attribute. """ + return self._dst @property @@ -172,6 +174,7 @@ def src(self) -> Ip4Address: """ Getter for the '_src' attribute. """ + return self._src @property @@ -179,6 +182,7 @@ def hlen(self) -> int: """ Getter for the '_hlen' attribute. """ + return self._hlen @property @@ -186,6 +190,7 @@ def proto(self) -> int: """ Getter for the '_proto' attribute. """ + return self._proto @property @@ -193,6 +198,7 @@ def dlen(self) -> int: """ Calculate data length. """ + return self._plen - self._hlen @property @@ -201,6 +207,7 @@ def pshdr_sum(self) -> int: Create IPv4 pseudo header used by TCP and UDP to compute their checksums. """ + pseudo_header = struct.pack( "! 4s 4s BBH", bytes(self._src), @@ -209,6 +216,7 @@ def pshdr_sum(self) -> int: self._proto, self._plen - self._hlen, ) + return sum(struct.unpack("! 3L", pseudo_header)) @property @@ -216,12 +224,14 @@ def _raw_options(self) -> bytes: """ Packet options in raw format. """ + return b"".join(bytes(option) for option in self._options) def assemble(self, frame: memoryview) -> None: """ Assemble packet into the raw form. """ + struct.pack_into( f"! BBH HH BBH 4s 4s {len(self._raw_options)}s", frame, @@ -247,7 +257,7 @@ class Ip4FragAssembler: IPv4 packet fragment assembler support class. """ - ether_type = ETHER_TYPE_IP4 + ether_type = EtherType.IP4 def __init__( self, diff --git a/pytcp/protocols/ip4/fpp.py b/pytcp/protocols/ip4/fpp.py index 534c11a3..df4535ba 100755 --- a/pytcp/protocols/ip4/fpp.py +++ b/pytcp/protocols/ip4/fpp.py @@ -44,6 +44,7 @@ from typing import TYPE_CHECKING from pytcp import config +from pytcp.lib.errors import PacketIntegrityError, PacketSanityError from pytcp.lib.ip4_address import Ip4Address from pytcp.lib.ip_helper import inet_cksum from pytcp.protocols.ip4.ps import ( @@ -59,6 +60,24 @@ from pytcp.lib.packet import PacketRx +class Ip4IntegrityError(PacketIntegrityError): + """ + Exception raised when IPv4 packet integrity check fails. + """ + + def __init__(self, message: str): + super().__init__("[IPv4] " + message) + + +class Ip4SanityError(PacketSanityError): + """ + Exception raised when IPv4 packet sanity check fails. + """ + + def __init__(self, message: str): + super().__init__("[IPv4] " + message) + + class Ip4Parser: """ IPv4 packet parser class. @@ -74,12 +93,10 @@ def __init__(self, packet_rx: PacketRx) -> None: self._frame = packet_rx.frame - packet_rx.parse_failed = ( - self._packet_integrity_check() or self._packet_sanity_check() - ) + self._packet_integrity_check() + self._packet_sanity_check() - if not packet_rx.parse_failed: - packet_rx.frame = packet_rx.frame[self.hlen :] + packet_rx.frame = packet_rx.frame[self.hlen :] def __len__(self) -> int: """ @@ -324,25 +341,28 @@ def pshdr_sum(self) -> int: ) return self._cache__pshdr_sum - def _packet_integrity_check(self) -> str: + def _packet_integrity_check(self) -> None: """ Packet integrity check to be run on raw packet prior to parsing to make sure parsing is safe. """ - if not config.PACKET_INTEGRITY_CHECK: - return "" - if len(self) < IP4_HEADER_LEN: - return "IPv4 integrity - wrong packet length (I)" + raise Ip4IntegrityError( + "The wrong packet length (I)", + ) if not IP4_HEADER_LEN <= self.hlen <= self.plen <= len(self): - return "IPv4 integrity - wrong packet length (II)" + raise Ip4IntegrityError( + "The wrong packet length (II)", + ) # Cannot compute checksum earlier because it depends # on sanity of hlen field if inet_cksum(self._frame[: self.hlen]): - return "IPv4 integriy - wrong packet checksum" + raise Ip4IntegrityError( + "The wrong packet checksum", + ) optr = IP4_HEADER_LEN while optr < self.hlen: @@ -351,55 +371,69 @@ def _packet_integrity_check(self) -> str: if self._frame[optr] == IP4_OPT_NOP: optr += 1 if optr > self.hlen: - return "IPv4 integrity - wrong option length (I)" + raise Ip4IntegrityError( + "The integrity - wrong option length (I)", + ) continue if optr + 1 > self.hlen: - return "IPv4 integrity - wrong option length (II)" + raise Ip4IntegrityError( + "The wrong option length (II)", + ) if self._frame[optr + 1] == 0: - return "IPv4 integrity - wrong option length (III)" + raise Ip4IntegrityError( + "The wrong option length (III)", + ) optr += self._frame[optr + 1] if optr > self.hlen: - return "IPv4 integrity - wrong option length (IV)" - - return "" + raise Ip4IntegrityError( + "The wrong option length (IV)", + ) - def _packet_sanity_check(self) -> str: + def _packet_sanity_check(self) -> None: """ Packet sanity check to be run on parsed packet to make sure packet's fields contain sane values. """ - if not config.PACKET_SANITY_CHECK: - return "" - if self.ver != 4: - return "IP sanityi - 'ver' must be 4" + raise Ip4SanityError( + "The 'ver' must be 4", + ) if self.ver == 0: - return "IP sanity - 'ttl' must be greater than 0" + raise Ip4SanityError( + "The 'ttl' must be greater than 0", + ) if self.src.is_multicast: - return "IP sanity - 'src' must not be multicast" + raise Ip4SanityError( + "The 'src' must not be multicast", + ) if self.src.is_reserved: - return "IP sanity - 'src' must not be reserved" + raise Ip4SanityError( + "The 'src' must not be reserved", + ) if self.src.is_limited_broadcast: - return "IP sanity - 'src' must not be limited broadcast" + raise Ip4SanityError( + "The 'src' must not be limited broadcast", + ) if self.flag_df and self.flag_mf: - return ( - "IP sanity - 'flag_df' and 'flag_mf' must not be set " - "simultaneously" + raise Ip4SanityError( + "The 'flag_df' and 'flag_mf' must not be set simultaneously", ) if self.offset and self.flag_df: - return "IP sanity - 'offset' must be 0 when 'df_flag' is set" + raise Ip4SanityError( + "The 'offset' must be 0 when 'df_flag' is set", + ) if self.options and config.IP4_OPTION_PACKET_DROP: - return "IP sanity - packet must not contain options" - - return "" + raise Ip4SanityError( + "The packet must not contain options", + ) # diff --git a/pytcp/protocols/ip4/phrx.py b/pytcp/protocols/ip4/phrx.py index b660ebdc..f25accf9 100755 --- a/pytcp/protocols/ip4/phrx.py +++ b/pytcp/protocols/ip4/phrx.py @@ -42,6 +42,7 @@ from typing import TYPE_CHECKING from pytcp import config +from pytcp.lib.errors import PacketValidationError from pytcp.lib.ip_helper import inet_cksum from pytcp.lib.logger import log from pytcp.lib.packet import PacketRx @@ -135,13 +136,13 @@ def _phrx_ip4(self: PacketHandler, packet_rx: PacketRx) -> None: self.packet_stats_rx.ip4__pre_parse += 1 - Ip4Parser(packet_rx) - - if packet_rx.parse_failed: + try: + Ip4Parser(packet_rx) + except PacketValidationError as error: self.packet_stats_rx.ip4__failed_parse__drop += 1 __debug__ and log( "ip4", - f"{packet_rx.tracker} - {packet_rx.parse_failed}", + f"{packet_rx.tracker} - {error}", ) return diff --git a/pytcp/protocols/ip6/fpa.py b/pytcp/protocols/ip6/fpa.py index bcad1b77..de30d994 100755 --- a/pytcp/protocols/ip6/fpa.py +++ b/pytcp/protocols/ip6/fpa.py @@ -41,7 +41,7 @@ from pytcp import config from pytcp.lib.ip6_address import Ip6Address -from pytcp.protocols.ether.ps import ETHER_TYPE_IP6 +from pytcp.protocols.ether.ps import EtherType from pytcp.protocols.ip6.ps import ( IP6_HEADER_LEN, IP6_NEXT_EXT_FRAG, @@ -66,7 +66,7 @@ class Ip6Assembler: IPv6 packet assembler support class. """ - ether_type = ETHER_TYPE_IP6 + ether_type = EtherType.IP6 def __init__( self, diff --git a/pytcp/protocols/ip6/fpp.py b/pytcp/protocols/ip6/fpp.py index 178bee0d..26037b79 100755 --- a/pytcp/protocols/ip6/fpp.py +++ b/pytcp/protocols/ip6/fpp.py @@ -40,7 +40,7 @@ import struct from typing import TYPE_CHECKING -from pytcp import config +from pytcp.lib.errors import PacketIntegrityError, PacketSanityError from pytcp.lib.ip6_address import Ip6Address from pytcp.protocols.ip6.ps import IP6_HEADER_LEN, IP6_NEXT_TABLE @@ -48,6 +48,24 @@ from pytcp.lib.packet import PacketRx +class Ip6IntegrityError(PacketIntegrityError): + """ + Exception raised when IPv6 packet integrity check fails. + """ + + def __init__(self, message: str): + super().__init__("[IPv6] " + message) + + +class Ip6SanityError(PacketSanityError): + """ + Exception raised when IPv6 packet sanity check fails. + """ + + def __init__(self, message: str): + super().__init__("[IPv6] " + message) + + class Ip6Parser: """ IPv6 packet parser class. @@ -63,12 +81,10 @@ def __init__(self, packet_rx: PacketRx) -> None: self._frame = packet_rx.frame - packet_rx.parse_failed = ( - self._packet_integrity_check() or self._packet_sanity_check() - ) + self._packet_integrity_check() + self._packet_sanity_check() - if not packet_rx.parse_failed: - packet_rx.frame = packet_rx.frame[IP6_HEADER_LEN:] + packet_rx.frame = packet_rx.frame[IP6_HEADER_LEN:] def __len__(self) -> int: """ @@ -235,41 +251,42 @@ def pshdr_sum(self) -> int: return self._cache__pshdr_sum - def _packet_integrity_check(self) -> str: + def _packet_integrity_check(self) -> None: """ Packet integrity check to be run on raw packet prior to parsing to make sure parsing is safe. """ - if not config.PACKET_INTEGRITY_CHECK: - return "" if len(self) < IP6_HEADER_LEN: - return "IPv6 integrity - wrong packet length (I)" + raise Ip6IntegrityError( + "The wrong packet length (I)", + ) if ( struct.unpack("!H", self._frame[4:6])[0] != len(self) - IP6_HEADER_LEN ): - return "IPv6 integrity - wrong packet length (II)" - - return "" + raise Ip6IntegrityError( + "The wrong packet length (II)", + ) - def _packet_sanity_check(self) -> str: + def _packet_sanity_check(self) -> None: """ Packet sanity check to be run on parsed packet to make sure packet's fields contain sane values. """ - if not config.PACKET_SANITY_CHECK: - return "" - if self.ver != 6: - return "IPv6 sanity - 'ver' must be 6" + raise Ip6SanityError( + "The 'ver' must be 6", + ) if self.hop == 0: - return "IPv6 sanity - 'hop' must not be 0" + raise Ip6SanityError( + "The 'hop' must not be 0", + ) if self.src.is_multicast: - return "IPv6 sanity - 'src' must not be multicast" - - return "" + raise Ip6SanityError( + "The 'src' must not be multicast", + ) diff --git a/pytcp/protocols/ip6/phrx.py b/pytcp/protocols/ip6/phrx.py index 06add1fb..161fad57 100755 --- a/pytcp/protocols/ip6/phrx.py +++ b/pytcp/protocols/ip6/phrx.py @@ -39,6 +39,7 @@ from typing import TYPE_CHECKING +from pytcp.lib.errors import PacketValidationError from pytcp.lib.logger import log from pytcp.lib.packet import PacketRx from pytcp.protocols.ip6.fpp import Ip6Parser @@ -60,13 +61,11 @@ def _phrx_ip6(self: PacketHandler, packet_rx: PacketRx) -> None: self.packet_stats_rx.ip6__pre_parse += 1 - Ip6Parser(packet_rx) - - if packet_rx.parse_failed: + try: + Ip6Parser(packet_rx) + except PacketValidationError as error: self.packet_stats_rx.ip6__failed_parse__drop += 1 - __debug__ and log( - "ip6", f"{packet_rx.tracker} - {packet_rx.parse_failed}" - ) + __debug__ and log("ip6", f"{packet_rx.tracker} - {error}") return __debug__ and log("ip6", f"{packet_rx.tracker} - {packet_rx.ip6}") diff --git a/pytcp/protocols/ip6_ext_frag/fpp.py b/pytcp/protocols/ip6_ext_frag/fpp.py index 910a38fa..74ea9a11 100755 --- a/pytcp/protocols/ip6_ext_frag/fpp.py +++ b/pytcp/protocols/ip6_ext_frag/fpp.py @@ -41,7 +41,7 @@ import struct from typing import TYPE_CHECKING -from pytcp import config +from pytcp.lib.errors import PacketIntegrityError, PacketSanityError from pytcp.protocols.ip6_ext_frag.ps import ( IP6_EXT_FRAG_HEADER_LEN, IP6_EXT_FRAG_NEXT_HEADER_TABLE, @@ -51,6 +51,24 @@ from pytcp.lib.packet import PacketRx +class Ip6ExtFragIntegrityError(PacketIntegrityError): + """ + Exception raised when IPv6 Ext Frag packet integrity check fails. + """ + + def __init__(self, message: str): + super().__init__("[IPv6 Ext Frag] " + message) + + +class Ip6ExtFragSanityError(PacketSanityError): + """ + Exception raised when IPv6 Ext Frag packet sanity check fails. + """ + + def __init__(self, message: str): + super().__init__("[IPv6 Ext Frag] " + message) + + class Ip6ExtFragParser: """ IPv6 fragmentation extension header parser class. @@ -68,12 +86,10 @@ def __init__(self, packet_rx: PacketRx) -> None: self._frame = packet_rx.frame self._plen = packet_rx.ip6.dlen - packet_rx.parse_failed = ( - self._packet_integrity_check() or self._packet_sanity_check() - ) + self._packet_integrity_check() + self._packet_sanity_check() - if not packet_rx.parse_failed: - packet_rx.frame = packet_rx.frame[IP6_EXT_FRAG_HEADER_LEN:] + packet_rx.frame = packet_rx.frame[IP6_EXT_FRAG_HEADER_LEN:] def __len__(self) -> int: """ @@ -177,27 +193,21 @@ def packet_copy(self) -> bytes: self._cache__packet_copy = bytes(self._frame[: self.plen]) return self._cache__packet_copy - def _packet_integrity_check(self) -> str: + def _packet_integrity_check(self) -> None: """ Packet integrity check to be run on raw packet prior to parsing to make sure parsing is safe. """ - if not config.PACKET_INTEGRITY_CHECK: - return "" - if len(self) < IP6_EXT_FRAG_HEADER_LEN: - return "IPv4 integrity - wrong packet length (I)" - - return "" + raise Ip6ExtFragIntegrityError( + "The wrong packet length (I)", + ) - def _packet_sanity_check(self) -> str: + def _packet_sanity_check(self) -> None: """ Packet sanity check to be run on parsed packet to make sure packet's fields contain sane values. """ - if not config.PACKET_SANITY_CHECK: - return "" - - return "" + pass diff --git a/pytcp/protocols/raw/fpa.py b/pytcp/protocols/raw/fpa.py index 7fa5bb98..71f02a15 100755 --- a/pytcp/protocols/raw/fpa.py +++ b/pytcp/protocols/raw/fpa.py @@ -38,7 +38,7 @@ import struct from pytcp.lib.tracker import Tracker -from pytcp.protocols.ether.ps import ETHER_TYPE_RAW +from pytcp.protocols.ether.ps import EtherType from pytcp.protocols.ip4.ps import IP4_PROTO_RAW from pytcp.protocols.ip6.ps import IP6_NEXT_RAW @@ -50,7 +50,7 @@ class RawAssembler: ip4_proto = IP4_PROTO_RAW ip6_next = IP6_NEXT_RAW - ether_type = ETHER_TYPE_RAW + ether_type = EtherType.RAW def __init__( self, *, data: bytes | None = None, echo_tracker: Tracker | None = None diff --git a/pytcp/protocols/tcp/fpp.py b/pytcp/protocols/tcp/fpp.py index f5aa84ab..5e0a3c0c 100755 --- a/pytcp/protocols/tcp/fpp.py +++ b/pytcp/protocols/tcp/fpp.py @@ -42,7 +42,7 @@ import struct from typing import TYPE_CHECKING -from pytcp import config +from pytcp.lib.errors import PacketIntegrityError, PacketSanityError from pytcp.lib.ip_helper import inet_cksum from pytcp.protocols.tcp.ps import ( TCP_HEADER_LEN, @@ -60,6 +60,24 @@ from pytcp.lib.packet import PacketRx +class TcpIntegrityError(PacketIntegrityError): + """ + Exception raised when TCP packet integrity check fails. + """ + + def __init__(self, message: str): + super().__init__("[TCP] " + message) + + +class TcpSanityError(PacketSanityError): + """ + Exception raised when TCP packet sanity check fails. + """ + + def __init__(self, message: str): + super().__init__("[TCP] " + message) + + class TcpParser: """ TCP packet parser class. @@ -77,10 +95,10 @@ def __init__(self, packet_rx: PacketRx) -> None: self._frame = packet_rx.frame self._plen = packet_rx.ip.dlen - packet_rx.parse_failed = ( - self._packet_integrity_check(packet_rx.ip.pshdr_sum) - or self._packet_sanity_check() + self._packet_integrity_check( + pshdr_sum=packet_rx.ip.pshdr_sum, ) + self._packet_sanity_check() if packet_rx.parse_failed: packet_rx.frame = packet_rx.frame[self.hlen :] @@ -432,24 +450,27 @@ def timestamp(self) -> tuple[int, int] | None: self._cache__timestamp = None return self._cache__timestamp - def _packet_integrity_check(self, pshdr_sum: int) -> str: + def _packet_integrity_check(self, *, pshdr_sum: int) -> None: """ Packet integrity check to be run on raw frame prior to parsing to make sure parsing is safe. """ - if not config.PACKET_INTEGRITY_CHECK: - return "" - if inet_cksum(self._frame[: self._plen], pshdr_sum): - return "TCP integrity - wrong packet checksum" + raise TcpIntegrityError( + "The wrong packet checksum", + ) if not TCP_HEADER_LEN <= self._plen <= len(self): - return "TCP integrity - wrong packet length (I)" + raise TcpIntegrityError( + "The wrong packet length (I)", + ) hlen = (self._frame[12] & 0b11110000) >> 2 if not TCP_HEADER_LEN <= hlen <= self._plen <= len(self): - return "TCP integrity - wrong packet length (II)" + raise TcpIntegrityError( + "The wrong packet length (II)", + ) optr = TCP_HEADER_LEN while optr < hlen: @@ -458,61 +479,69 @@ def _packet_integrity_check(self, pshdr_sum: int) -> str: if self._frame[optr] == TCP_OPT_NOP: optr += 1 if optr > hlen: - return "TCP integrity - wrong option length (I)" + raise TcpIntegrityError( + "The wrong option length (I)", + ) continue if optr + 1 > hlen: - return "TCP integrity - wrong option length (II)" + raise TcpIntegrityError( + "The wrong option length (II)", + ) if self._frame[optr + 1] == 0: - return "TCP integrity - wrong option length (III)" + raise TcpIntegrityError( + "The wrong option length (III)", + ) optr += self._frame[optr + 1] if optr > hlen: - return "TCP integrity - wrong option length (IV)" - - return "" + raise TcpIntegrityError( + "The wrong option length (IV)", + ) - def _packet_sanity_check(self) -> str: + def _packet_sanity_check(self) -> None: """ Packet sanity check to be run on parsed packet to make sure packets's fields contain sane values. """ - if not config.PACKET_SANITY_CHECK: - return "" - if self.sport == 0: - return "TCP sanity - 'sport' must be greater than 0" + raise TcpSanityError( + "The 'sport' must be greater than 0", + ) if self.dport == 0: - return "TCP sanity - 'dport' must be greater than 0" + raise TcpSanityError( + "The 'dport' must be greater than 0", + ) if self.flag_syn and self.flag_fin: - return ( - "TCP sanity - 'flag_syn' and 'flag_fin' must not be set " - "simultaneously" + raise TcpSanityError( + "The 'flag_syn' and 'flag_fin' must not be set simultaneously", ) if self.flag_syn and self.flag_rst: - return ( - "TCP sanity - 'flag_syn' and 'flag_rst' must not set " - "simultaneously" + raise TcpSanityError( + "The 'flag_syn' and 'flag_rst' must not set simultaneously", ) if self.flag_fin and self.flag_rst: - return ( - "TCP sanity - 'flag_fin' and 'flag_rst' must not be set " - "simultaneously" + raise TcpSanityError( + "The 'flag_fin' and 'flag_rst' must not be set simultaneously", ) if self.flag_fin and not self.flag_ack: - return "TCP sanity - 'flag_ack' must be set when 'flag_fin' is set" + raise TcpSanityError( + "The 'flag_ack' must be set when 'flag_fin' is set", + ) if self.ack and not self.flag_ack: - return "TCP sanity - 'flag_ack' must be set when 'ack' is not 0" + raise TcpSanityError( + "The 'flag_ack' must be set when 'ack' is not 0", + ) if self.urg and not self.flag_urg: - return "TCP sanity - 'flag_urg' must be set when 'urg' is not 0" - - return "" + raise TcpSanityError( + "The 'flag_urg' must be set when 'urg' is not 0", + ) # diff --git a/pytcp/protocols/tcp/phrx.py b/pytcp/protocols/tcp/phrx.py index 3770a2c1..522a2a75 100755 --- a/pytcp/protocols/tcp/phrx.py +++ b/pytcp/protocols/tcp/phrx.py @@ -40,6 +40,7 @@ from typing import TYPE_CHECKING from pytcp.lib import stack +from pytcp.lib.errors import PacketValidationError from pytcp.lib.logger import log from pytcp.lib.packet import PacketRx from pytcp.protocols.tcp.fpp import TcpParser @@ -56,13 +57,13 @@ def _phrx_tcp(self: PacketHandler, packet_rx: PacketRx) -> None: self.packet_stats_rx.tcp__pre_parse += 1 - TcpParser(packet_rx) - - if packet_rx.parse_failed: + try: + TcpParser(packet_rx) + except PacketValidationError as error: self.packet_stats_rx.tcp__failed_parse__drop += 1 __debug__ and log( "tcp", - f"{packet_rx.tracker} - {packet_rx.parse_failed}", + f"{packet_rx.tracker} - {error}", ) return diff --git a/pytcp/protocols/udp/fpp.py b/pytcp/protocols/udp/fpp.py index a3fd51e0..865f7bc4 100755 --- a/pytcp/protocols/udp/fpp.py +++ b/pytcp/protocols/udp/fpp.py @@ -40,7 +40,7 @@ import struct from typing import TYPE_CHECKING -from pytcp import config +from pytcp.lib.errors import PacketIntegrityError, PacketSanityError from pytcp.lib.ip_helper import inet_cksum from pytcp.protocols.udp.ps import UDP_HEADER_LEN @@ -48,6 +48,24 @@ from pytcp.lib.packet import PacketRx +class UdpIntegrityError(PacketIntegrityError): + """ + Exception raised when UDP packet integrity check fails. + """ + + def __init__(self, message: str): + super().__init__("[UDP] " + message) + + +class UdpSanityError(PacketSanityError): + """ + Exception raised when UDP packet sanity check fails. + """ + + def __init__(self, message: str): + super().__init__("[UDP] " + message) + + class UdpParser: """ UDP packet parser class. @@ -65,10 +83,8 @@ def __init__(self, packet_rx: PacketRx) -> None: self._frame = packet_rx.frame self._plen = packet_rx.ip.dlen - packet_rx.parse_failed = ( - self._packet_integrity_check(packet_rx.ip.pshdr_sum) - or self._packet_sanity_check() - ) + self._packet_integrity_check(pshdr_sum=packet_rx.ip.pshdr_sum) + self._packet_sanity_check() if not packet_rx.parse_failed: packet_rx.frame = packet_rx.frame[UDP_HEADER_LEN:] @@ -166,40 +182,40 @@ def packet_copy(self) -> bytes: self._cache__packet_copy = bytes(self._frame[: self.plen]) return self._cache__packet_copy - def _packet_integrity_check(self, pshdr_sum: int) -> str: + def _packet_integrity_check(self, pshdr_sum: int) -> None: """ Packet integrity check to be run on raw frame prior to parsing to make sure parsing is safe. """ - if not config.PACKET_INTEGRITY_CHECK: - return "" - if inet_cksum(self._frame[: self._plen], pshdr_sum): - return "UDP integrity - wrong packet checksum" + raise UdpIntegrityError( + "The wrong packet checksum", + ) if not UDP_HEADER_LEN <= self._plen <= len(self): - return "UDP integrity - wrong packet length (I)" + raise UdpIntegrityError( + "The wrong packet length (I)", + ) plen = struct.unpack("!H", self._frame[4:6])[0] if not UDP_HEADER_LEN <= plen == self._plen <= len(self): - return "UDP integrity - wrong packet length (II)" - - return "" + raise UdpIntegrityError( + "The wrong packet length (II)", + ) - def _packet_sanity_check(self) -> str: + def _packet_sanity_check(self) -> None: """ Packet sanity check to be run on parsed packet to make sure packet's fields contain sane values. """ - if not config.PACKET_SANITY_CHECK: - return "" - if self.sport == 0: - return "UDP sanity - 'udp_sport' must be greater than 0" + raise UdpSanityError( + "The 'udp_sport' must be greater than 0", + ) if self.dport == 0: - return "UDP sanity - 'udp_dport' must be greater then 0" - - return "" + raise UdpSanityError( + "The 'udp_dport' must be greater then 0", + ) diff --git a/pytcp/protocols/udp/phrx.py b/pytcp/protocols/udp/phrx.py index 091a70e7..eab5fcbd 100755 --- a/pytcp/protocols/udp/phrx.py +++ b/pytcp/protocols/udp/phrx.py @@ -41,9 +41,11 @@ from pytcp import config from pytcp.lib import stack +from pytcp.lib.errors import PacketValidationError from pytcp.lib.logger import log from pytcp.lib.packet import PacketRx -from pytcp.protocols.icmp4.ps import ICMP4_UNREACHABLE, ICMP4_UNREACHABLE__PORT +from pytcp.protocols.icmp4.fpa import Icmp4UnreachablePortMessageAssembler +from pytcp.protocols.icmp4.ps import Icmp4Type, Icmp4UnreachableCode from pytcp.protocols.icmp6.ps import ICMP6_UNREACHABLE, ICMP6_UNREACHABLE__PORT from pytcp.protocols.udp.fpp import UdpParser from pytcp.protocols.udp.metadata import UdpMetadata @@ -59,13 +61,13 @@ def _phrx_udp(self: PacketHandler, packet_rx: PacketRx) -> None: self.packet_stats_rx.udp__pre_parse += 1 - UdpParser(packet_rx) - - if packet_rx.parse_failed: + try: + UdpParser(packet_rx) + except PacketValidationError as error: self.packet_stats_rx.udp__failed_parse__drop += 1 __debug__ and log( "udp", - f"{packet_rx.tracker} - {packet_rx.parse_failed}", + f"{packet_rx.tracker} - {error}", ) return @@ -158,9 +160,11 @@ def _phrx_udp(self: PacketHandler, packet_rx: PacketRx) -> None: self._phtx_icmp4( ip4_src=packet_rx.ip4.dst, ip4_dst=packet_rx.ip4.src, - icmp4_type=ICMP4_UNREACHABLE, - icmp4_code=ICMP4_UNREACHABLE__PORT, - icmp4_un_data=packet_rx.ip.packet_copy, + icmp4_type=Icmp4Type.UNREACHABLE, + icmp4_code=Icmp4UnreachableCode.PORT, + icmp4_message=Icmp4UnreachablePortMessageAssembler( + data=packet_rx.ip.packet_copy + ), echo_tracker=packet_rx.tracker, ) diff --git a/pytcp/subsystems/arp_cache.py b/pytcp/subsystems/arp_cache.py index 8973c46e..6d33954b 100755 --- a/pytcp/subsystems/arp_cache.py +++ b/pytcp/subsystems/arp_cache.py @@ -46,7 +46,7 @@ from pytcp.lib.ip4_address import Ip4Address from pytcp.lib.logger import log from pytcp.lib.mac_address import MacAddress -from pytcp.protocols.arp.ps import ARP_OP_REQUEST +from pytcp.protocols.arp.ps import ArpOperation class ArpCache: @@ -181,7 +181,7 @@ def _send_arp_request(self, arp_tpa: Ip4Address) -> None: stack.packet_handler._phtx_arp( ether_src=stack.packet_handler.mac_unicast, ether_dst=MacAddress(0xFFFFFFFFFFFF), - arp_oper=ARP_OP_REQUEST, + arp_oper=ArpOperation.REQUEST, arp_sha=stack.packet_handler.mac_unicast, arp_spa=stack.packet_handler.ip4_unicast[0] if stack.packet_handler.ip4_unicast diff --git a/pytcp/subsystems/packet_handler.py b/pytcp/subsystems/packet_handler.py index 742ffc55..f9572a3f 100755 --- a/pytcp/subsystems/packet_handler.py +++ b/pytcp/subsystems/packet_handler.py @@ -56,12 +56,19 @@ from pytcp.lib.packet_stats import PacketStatsRx, PacketStatsTx from pytcp.protocols.arp.phrx import _phrx_arp from pytcp.protocols.arp.phtx import _phtx_arp -from pytcp.protocols.arp.ps import ARP_OP_REPLY, ARP_OP_REQUEST +from pytcp.protocols.arp.ps import ArpOperation from pytcp.protocols.dhcp4.client import Dhcp4Client from pytcp.protocols.ether.phrx import _phrx_ether from pytcp.protocols.ether.phtx import _phtx_ether from pytcp.protocols.icmp4.phrx import _phrx_icmp4 from pytcp.protocols.icmp4.phtx import _phtx_icmp4 +from pytcp.protocols.icmp4.ps import ( + Icmp4Code, + Icmp4EchoReplyMessage, + Icmp4EchoRequestMessage, + Icmp4Type, + Icmp4UnreachablePortMessage, +) from pytcp.protocols.icmp6.fpa import ( Icmp6MulticastAddressRecord, Icmp6NdOptPI, @@ -473,7 +480,7 @@ def _send_arp_probe(self, ip4_unicast: Ip4Address) -> None: self._phtx_arp( ether_src=self.mac_unicast, ether_dst=MacAddress(0xFFFFFFFFFFFF), - arp_oper=ARP_OP_REQUEST, + arp_oper=ArpOperation.REQUEST, arp_sha=self.mac_unicast, arp_spa=Ip4Address(0), arp_tha=MacAddress(0), @@ -488,7 +495,7 @@ def _send_arp_announcement(self, ip4_unicast: Ip4Address) -> None: self._phtx_arp( ether_src=self.mac_unicast, ether_dst=MacAddress(0xFFFFFFFFFFFF), - arp_oper=ARP_OP_REQUEST, + arp_oper=ArpOperation.REQUEST, arp_sha=self.mac_unicast, arp_spa=ip4_unicast, arp_tha=MacAddress(0), @@ -505,7 +512,7 @@ def _send_gratitous_arp(self, ip4_unicast: Ip4Address) -> None: self._phtx_arp( ether_src=self.mac_unicast, ether_dst=MacAddress(0xFFFFFFFFFFFF), - arp_oper=ARP_OP_REPLY, + arp_oper=ArpOperation.REPLY, arp_sha=self.mac_unicast, arp_spa=ip4_unicast, arp_tha=MacAddress(0), @@ -684,12 +691,11 @@ def send_icmp4_packet( self, local_ip_address: Ip4Address, remote_ip_address: Ip4Address, - type: int, - code: int = 0, - ec_id: int | None = None, - ec_seq: int | None = None, - ec_data: bytes | None = None, - un_data: bytes | None = None, + type: Icmp4Type, + code: Icmp4Code, + message: Icmp4EchoReplyMessage + | Icmp4UnreachablePortMessage + | Icmp4EchoRequestMessage, ) -> TxStatus: """ Interface method for ICMPv4 Socket -> FPA communication. @@ -699,10 +705,7 @@ def send_icmp4_packet( ip4_dst=remote_ip_address, icmp4_type=type, icmp4_code=code, - icmp4_ec_id=ec_id, - icmp4_ec_seq=ec_seq, - icmp4_ec_data=ec_data, - icmp4_un_data=un_data, + icmp4_message=message, ) def send_icmp6_packet( diff --git a/tests/integration/packet_flows_rx.py b/tests/integration/packet_flows_rx.py index 41d0e6e6..cb28956b 100755 --- a/tests/integration/packet_flows_rx.py +++ b/tests/integration/packet_flows_rx.py @@ -69,8 +69,6 @@ "LOG_CHANEL": set(), "IP6_SUPPORT": True, "IP4_SUPPORT": True, - "PACKET_INTEGRITY_CHECK": True, - "PACKET_SANITY_CHECK": True, "TAP_MTU": 1500, "UDP_ECHO_NATIVE_DISABLE": False, } diff --git a/tests/integration/packet_flows_rx_tx.py b/tests/integration/packet_flows_rx_tx.py index 9c8d5cf6..4afe5588 100755 --- a/tests/integration/packet_flows_rx_tx.py +++ b/tests/integration/packet_flows_rx_tx.py @@ -74,8 +74,6 @@ "LOG_CHANEL": set(), "IP6_SUPPORT": True, "IP4_SUPPORT": True, - "PACKET_INTEGRITY_CHECK": True, - "PACKET_SANITY_CHECK": True, "TAP_MTU": 1500, "UDP_ECHO_NATIVE_DISABLE": False, } diff --git a/tests/unit/mock_network.py b/tests/unit/mock_network.py index 10df3d55..92fc43b3 100755 --- a/tests/unit/mock_network.py +++ b/tests/unit/mock_network.py @@ -142,8 +142,6 @@ def __init__(self) -> None: "LOG_CHANEL": set(), "IP6_SUPPORT": True, "IP4_SUPPORT": True, - "PACKET_INTEGRITY_CHECK": True, - "PACKET_SANITY_CHECK": True, "TAP_MTU": 1500, "UDP_ECHO_NATIVE_DISABLE": False, "IP4_DEFAULT_TTL": 64, diff --git a/tests/unit/protocols__arp__fpa.py b/tests/unit/protocols__arp__fpa.py index 32b9f053..bfa71c0d 100755 --- a/tests/unit/protocols__arp__fpa.py +++ b/tests/unit/protocols__arp__fpa.py @@ -36,8 +36,8 @@ from pytcp.lib.ip4_address import Ip4Address from pytcp.lib.mac_address import MacAddress from pytcp.protocols.arp.fpa import ArpAssembler -from pytcp.protocols.arp.ps import ARP_HEADER_LEN, ARP_OP_REPLY, ARP_OP_REQUEST -from pytcp.protocols.ether.ps import ETHER_TYPE_ARP +from pytcp.protocols.arp.ps import ARP_HEADER_LEN, ArpOperation +from pytcp.protocols.ether.ps import EtherType class TestArpAssembler(TestCase): @@ -50,7 +50,8 @@ def test_arp_fpa__ethertype(self) -> None: Make sure the 'ArpAssembler' class has the proper 'ethertype' value assigned. """ - self.assertEqual(ArpAssembler.ether_type, ETHER_TYPE_ARP) + + self.assertEqual(ArpAssembler._ether_type, EtherType.ARP) def test_arp_fpa____init__(self) -> None: """ @@ -61,13 +62,13 @@ def test_arp_fpa____init__(self) -> None: spa=Ip4Address("1.2.3.4"), tha=MacAddress("66:77:88:99:AA:BB"), tpa=Ip4Address("5.6.7.8"), - oper=ARP_OP_REPLY, + oper=ArpOperation.REPLY, ) self.assertEqual(packet._sha, MacAddress("00:11:22:33:44:55")) self.assertEqual(packet._spa, Ip4Address("1.2.3.4")) self.assertEqual(packet._tha, MacAddress("66:77:88:99:AA:BB")) self.assertEqual(packet._tpa, Ip4Address("5.6.7.8")) - self.assertEqual(packet._oper, ARP_OP_REPLY) + self.assertEqual(packet._oper, ArpOperation.REPLY) def test_arp_fpa____init____defaults(self) -> None: """ @@ -78,26 +79,19 @@ def test_arp_fpa____init____defaults(self) -> None: self.assertEqual(packet._spa, Ip4Address("0.0.0.0")) self.assertEqual(packet._tha, MacAddress("00:00:00:00:00:00")) self.assertEqual(packet._tpa, Ip4Address("0.0.0.0")) - self.assertEqual(packet._oper, ARP_OP_REQUEST) + self.assertEqual(packet._oper, ArpOperation.REQUEST) def test_arp_fpa____init____assert_oper_request(self) -> None: """ Test assertion for the request operation. """ - ArpAssembler(oper=ARP_OP_REQUEST) + ArpAssembler(oper=ArpOperation.REQUEST) def test_arp_fpa____init____assert_oper_reply(self) -> None: """ Test assertion for the request operation. """ - ArpAssembler(oper=ARP_OP_REPLY) - - def test_arp_fpa____init____assert_oper_unknown(self) -> None: - """ - Test assertion for the unknown operation. - """ - with self.assertRaises(AssertionError): - ArpAssembler(oper=-1) + ArpAssembler(oper=ArpOperation.REPLY) def test_arp_fpa____len__(self) -> None: """ @@ -115,7 +109,7 @@ def test_arp_fpa____str____request(self) -> None: spa=Ip4Address("1.2.3.4"), tha=MacAddress("66:77:88:99:AA:BB"), tpa=Ip4Address("5.6.7.8"), - oper=ARP_OP_REQUEST, + oper=ArpOperation.REQUEST, ) self.assertEqual( str(packet), @@ -132,7 +126,7 @@ def test_arp_fpa____str____reply(self) -> None: spa=Ip4Address("1.2.3.4"), tha=MacAddress("66:77:88:99:AA:BB"), tpa=Ip4Address("5.6.7.8"), - oper=ARP_OP_REPLY, + oper=ArpOperation.REPLY, ) self.assertEqual( str(packet), @@ -158,7 +152,7 @@ def test_ether_fpa__assemble(self) -> None: spa=Ip4Address("1.2.3.4"), tha=MacAddress("66:77:88:99:AA:BB"), tpa=Ip4Address("5.6.7.8"), - oper=ARP_OP_REPLY, + oper=ArpOperation.REPLY, ) frame = memoryview(bytearray(len(packet))) packet.assemble(frame) diff --git a/tests/unit/protocols__arp__phtx.py b/tests/unit/protocols__arp__phtx.py index d2c7d487..b6c88071 100755 --- a/tests/unit/protocols__arp__phtx.py +++ b/tests/unit/protocols__arp__phtx.py @@ -36,7 +36,7 @@ from pytcp.lib.packet_stats import PacketStatsTx from pytcp.lib.tx_status import TxStatus -from pytcp.protocols.arp.ps import ARP_OP_REPLY, ARP_OP_REQUEST +from pytcp.protocols.arp.ps import ArpOperation from pytcp.subsystems.packet_handler import PacketHandler from tests.unit.mock_network import ( MockNetworkSettings, @@ -72,7 +72,7 @@ def test_arp_phtx__arp_request(self) -> None: tx_status = self.packet_handler._phtx_arp( ether_src=self.mns.stack_mac_address, ether_dst=self.mns.mac_broadcast, - arp_oper=ARP_OP_REQUEST, + arp_oper=ArpOperation.REQUEST, arp_sha=self.mns.stack_mac_address, arp_spa=self.mns.stack_ip4_host.address, arp_tha=self.mns.mac_unspecified, @@ -100,7 +100,7 @@ def test_arp_phtx__arp_reply(self) -> None: tx_status = self.packet_handler._phtx_arp( ether_src=self.mns.stack_mac_address, ether_dst=self.mns.host_a_mac_address, - arp_oper=ARP_OP_REPLY, + arp_oper=ArpOperation.REPLY, arp_sha=self.mns.stack_mac_address, arp_spa=self.mns.stack_ip4_host.address, arp_tha=self.mns.host_a_mac_address, diff --git a/tests/unit/protocols__ether__fpa.py b/tests/unit/protocols__ether__fpa.py index 3208e2f2..f65b6353 100755 --- a/tests/unit/protocols__ether__fpa.py +++ b/tests/unit/protocols__ether__fpa.py @@ -37,7 +37,7 @@ from pytcp.lib.tracker import Tracker from pytcp.protocols.arp.fpa import ArpAssembler from pytcp.protocols.ether.fpa import EtherAssembler -from pytcp.protocols.ether.ps import ETHER_HEADER_LEN, ETHER_TYPE_RAW +from pytcp.protocols.ether.ps import ETHER_HEADER_LEN, EtherType from pytcp.protocols.ip4.fpa import Ip4Assembler from pytcp.protocols.ip6.fpa import Ip6Assembler from pytcp.protocols.raw.fpa import RawAssembler @@ -61,7 +61,7 @@ def test_ether_fpa____init__(self) -> None: self.assertEqual(packet._tracker, packet._carried_packet._tracker) self.assertEqual(packet._src, MacAddress("00:11:22:33:44:55")) self.assertEqual(packet._dst, MacAddress("66:77:88:99:AA:BB")) - self.assertEqual(packet._type, ETHER_TYPE_RAW) + self.assertEqual(packet._type, EtherType.RAW) def test_ether_fpa____init____defaults(self) -> None: """ @@ -72,7 +72,7 @@ def test_ether_fpa____init____defaults(self) -> None: self.assertEqual(packet._tracker, packet._carried_packet._tracker) self.assertEqual(packet._src, MacAddress("00:00:00:00:00:00")) self.assertEqual(packet._dst, MacAddress("00:00:00:00:00:00")) - self.assertEqual(packet._type, ETHER_TYPE_RAW) + self.assertEqual(packet._type, EtherType.RAW) def test_ether_fpa____init____assert_ethertype_arp(self) -> None: """ @@ -98,16 +98,6 @@ def test_ether_fpa____init____assert_ethertype_raw(self) -> None: """ EtherAssembler(carried_packet=RawAssembler()) - def test_ether_fpa____init____assert_ethertype_unknown(self) -> None: - """ - Test assertion for carried packet 'ether_type' attribute. - """ - with self.assertRaises(AssertionError): - carried_packet_mock = StrictMock() - carried_packet_mock.ether_type = -1 - carried_packet_mock.tracker = StrictMock(Tracker) - EtherAssembler(carried_packet=carried_packet_mock) # type: ignore[arg-type] - def test_ether_fpa____len__(self) -> None: """ Test the '__len__' dunder. @@ -139,7 +129,7 @@ def test_ether_fpa__tracker_getter(self) -> None: """ carried_packet_mock = StrictMock() - carried_packet_mock.ether_type = ETHER_TYPE_RAW + carried_packet_mock.ether_type = EtherType.RAW carried_packet_mock.tracker = StrictMock(Tracker) packet = EtherAssembler(carried_packet=carried_packet_mock) # type: ignore[arg-type] diff --git a/tests/unit/protocols__icmp4__fpa.py b/tests/unit/protocols__icmp4__fpa.py index f4c309d7..fe48f878 100755 --- a/tests/unit/protocols__icmp4__fpa.py +++ b/tests/unit/protocols__icmp4__fpa.py @@ -34,15 +34,24 @@ from testslide import TestCase from pytcp.lib.tracker import Tracker -from pytcp.protocols.icmp4.fpa import Icmp4Assembler +from pytcp.protocols.icmp4.fpa import ( + Icmp4Assembler, + Icmp4EchoReplyMessageAssembler, + Icmp4EchoRequestMessageAssembler, + Icmp4UnreachablePortMessageAssembler, +) from pytcp.protocols.icmp4.ps import ( - ICMP4_ECHO_REPLY, - ICMP4_ECHO_REPLY_LEN, - ICMP4_ECHO_REQUEST, - ICMP4_ECHO_REQUEST_LEN, - ICMP4_UNREACHABLE, - ICMP4_UNREACHABLE__PORT, - ICMP4_UNREACHABLE_LEN, + ICMP4_ECHO_REPLY_MESSAGE_LEN, + ICMP4_ECHO_REQUEST_MESSAGE_LEN, + ICMP4_HEADER_LEN, + ICMP4_UNRECHABLE_MESSAGE_LEN, + Icmp4EchoReplyCode, + Icmp4EchoReplyMessage, + Icmp4EchoRequestCode, + Icmp4EchoRequestMessage, + Icmp4Type, + Icmp4UnreachableCode, + Icmp4UnreachablePortMessage, ) from pytcp.protocols.ip4.ps import IP4_PROTO_ICMP4 @@ -63,94 +72,53 @@ def test_icmp4_fpa____init____echo_request(self) -> None: Test the packet constructor for the 'Echo Request' message. """ packet = Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=0, - ec_id=12345, - ec_seq=54321, - ec_data=b"0123456789ABCDEF", + type=Icmp4Type.ECHO_REQUEST, + code=Icmp4EchoRequestCode.DEFAULT, + message=Icmp4EchoRequestMessageAssembler( + id=12345, + seq=54321, + data=b"0123456789ABCDEF", + ), echo_tracker=Tracker(prefix="TX"), ) - self.assertEqual(packet._ec_id, 12345) - self.assertEqual(packet._ec_seq, 54321) - self.assertEqual(packet._ec_data, b"0123456789ABCDEF") + assert isinstance(packet.message, Icmp4EchoRequestMessage) + self.assertEqual(packet.message.id, 12345) + self.assertEqual(packet.message.seq, 54321) + self.assertEqual(packet.message.data, b"0123456789ABCDEF") self.assertTrue( repr(packet.tracker._echo_tracker).startswith( "Tracker(serial='TX" ) ) - def test_icmp4_fpa____init____echo_request__assert_code__under( - self, - ) -> None: - """ - Test packet constructor for the 'Echo Request' message. - """ - with self.assertRaises(AssertionError): - Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=-1, - ) - - def test_icmp4_fpa____init____echo_request__assert_code__over(self) -> None: - """ - Test packet constructor for the 'Echo Request' message. - """ - with self.assertRaises(AssertionError): - Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=1, - ) - - def test_icmp4_fpa____init____echo_request__assert_ec_id__under( - self, - ) -> None: - """ - Test assertion for the 'ec_id' argument. - """ - with self.assertRaises(AssertionError): - Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=0, - ec_id=-1, - ) - - def test_icmp4_fpa____init____echo_request__assert_ec_id__over( - self, - ) -> None: - """ - Test assertion for the 'ec_id' argument. - """ - with self.assertRaises(AssertionError): - Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=0, - ec_id=0x10000, - ) - def test_icmp4_fpa____init____echo_request__assert_ec_seq__under( self, ) -> None: """ - Test assertion for the 'ec_id' argument. + Test assertion for the 'id' argument. """ with self.assertRaises(AssertionError): Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=0, - ec_seq=-1, + type=Icmp4Type.ECHO_REQUEST, + code=Icmp4EchoRequestCode.DEFAULT, + message=Icmp4EchoRequestMessageAssembler( + seq=-1, + ), ) def test_icmp4_fpa____init____echo_request__assert_ec_seq__over( self, ) -> None: """ - Test assertion for the 'ec_seq' argument. + Test assertion for the 'seq' argument. """ with self.assertRaises(AssertionError): Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=0, - ec_seq=0x10000, + type=Icmp4Type.ECHO_REQUEST, + code=Icmp4EchoRequestCode.DEFAULT, + message=Icmp4EchoRequestMessageAssembler( + seq=0x10000, + ), ) def test_icmp4_fpa____init____unreachable_port(self) -> None: @@ -158,116 +126,84 @@ def test_icmp4_fpa____init____unreachable_port(self) -> None: Test packet constructor for the 'Unreachable Port' message. """ packet = Icmp4Assembler( - type=ICMP4_UNREACHABLE, - code=ICMP4_UNREACHABLE__PORT, - un_data=b"0123456789ABCDEF" * 50, + type=Icmp4Type.UNREACHABLE, + code=Icmp4UnreachableCode.PORT, + message=Icmp4UnreachablePortMessageAssembler( + data=b"0123456789ABCDEF" * 50, + ), echo_tracker=Tracker(prefix="TX"), ) - self.assertEqual(packet._un_data, (b"0123456789ABCDEF" * 50)[:520]) + assert isinstance(packet.message, Icmp4UnreachablePortMessage) + self.assertEqual(packet.message.data, (b"0123456789ABCDEF" * 50)[:520]) self.assertTrue( repr(packet.tracker._echo_tracker).startswith( "Tracker(serial='TX" ) ) - def test_icmp4_fpa____init____unreachable_port__assert_code__under( - self, - ) -> None: - """ - Test packet constructor for the 'Unreachable Port' message. - """ - with self.assertRaises(AssertionError): - Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=ICMP4_UNREACHABLE__PORT - 1, - ) - - def test_icmp4_fpa____init____unreachable_port__assert_code__over( - self, - ) -> None: - """ - Test packet constructor for the 'Unreachable Port' message. - """ - with self.assertRaises(AssertionError): - Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=ICMP4_UNREACHABLE__PORT + 1, - ) - def test_icmp4_fpa____init____echo_reply(self) -> None: """ Test packet constructor for the 'Echo Reply' message. """ packet = Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=0, - ec_id=12345, - ec_seq=54321, - ec_data=b"0123456789ABCDEF", + type=Icmp4Type.ECHO_REPLY, + code=Icmp4EchoReplyCode.DEFAULT, + message=Icmp4EchoReplyMessageAssembler( + id=12345, + seq=54321, + data=b"0123456789ABCDEF", + ), echo_tracker=Tracker(prefix="TX"), ) - self.assertEqual(packet._ec_id, 12345) - self.assertEqual(packet._ec_seq, 54321) - self.assertEqual(packet._ec_data, b"0123456789ABCDEF") + assert isinstance(packet.message, Icmp4EchoReplyMessage) + self.assertEqual(packet.message.id, 12345) + self.assertEqual(packet.message.seq, 54321) + self.assertEqual(packet.message.data, b"0123456789ABCDEF") self.assertTrue( repr(packet.tracker._echo_tracker).startswith( "Tracker(serial='TX" ) ) - def test_icmp4_fpa____init____echo_reply__assert_code__under(self) -> None: - """ - Test packet constructor for the 'Echo Reply' message. - """ - with self.assertRaises(AssertionError): - Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=-1, - ) - - def test_icmp4_fpa____init____echo_reply__assert_code__over(self) -> None: - """ - Test packet constructor for the 'Echo Reply' message. - """ - with self.assertRaises(AssertionError): - Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=1, - ) - def test_icmp4_fpa____init____echo_reply__assert_ec_id__under(self) -> None: """ - Test assertion for the 'ec_id' argument. + Test assertion for the 'id' argument. """ with self.assertRaises(AssertionError): Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=0, - ec_id=-1, + type=Icmp4Type.ECHO_REPLY, + code=Icmp4EchoReplyCode.DEFAULT, + message=Icmp4EchoReplyMessageAssembler( + id=-1, + ), ) def test_icmp4_fpa____init____echo_reply__assert_ec_id__over(self) -> None: """ - Test assertion for the 'ec_id' argument. + Test assertion for the 'id' argument. """ with self.assertRaises(AssertionError): Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=0, - ec_id=0x10000, + type=Icmp4Type.ECHO_REPLY, + code=Icmp4EchoReplyCode.DEFAULT, + message=Icmp4EchoReplyMessageAssembler( + id=0x10000, + ), ) def test_icmp4_fpa____init____echo_reply__assert_ec_seq__under( self, ) -> None: """ - Test assertion for the 'ec_id' argument. + Test assertion for the 'id' argument. """ with self.assertRaises(AssertionError): Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=0, - ec_seq=-1, + type=Icmp4Type.ECHO_REPLY, + code=Icmp4EchoReplyCode.DEFAULT, + message=Icmp4EchoReplyMessageAssembler( + seq=-1, + ), ) def test_icmp4_fpa____init____echo_reply__assert_ec_seq__over(self) -> None: @@ -276,18 +212,11 @@ def test_icmp4_fpa____init____echo_reply__assert_ec_seq__over(self) -> None: """ with self.assertRaises(AssertionError): Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=0, - ec_seq=0x10000, - ) - - def test_icmp4_fpa____init____unknown(self) -> None: - """ - Test packet constructor for the message with unknown type. - """ - with self.assertRaises(AssertionError): - Icmp4Assembler( - type=255, + type=Icmp4Type.ECHO_REPLY, + code=Icmp4EchoReplyCode.DEFAULT, + message=Icmp4EchoReplyMessageAssembler( + seq=0x10000, + ), ) def test_icmp4_fpa____len____echo_reply(self) -> None: @@ -295,47 +224,61 @@ def test_icmp4_fpa____len____echo_reply(self) -> None: Test the '__len__()' dunder for the 'Echo Reply' message. """ packet = Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=0, - ec_data=b"0123456789ABCDEF", + type=Icmp4Type.ECHO_REPLY, + code=Icmp4EchoReplyCode.DEFAULT, + message=Icmp4EchoReplyMessageAssembler( + data=b"0123456789ABCDEF", + ), + ) + self.assertEqual( + len(packet), ICMP4_HEADER_LEN + ICMP4_ECHO_REPLY_MESSAGE_LEN + 16 ) - self.assertEqual(len(packet), ICMP4_ECHO_REPLY_LEN + 16) def test_icmp4_fpa____len____unreachable_port(self) -> None: """ Test the '__len__()' dunder for the 'Unreachable Port' message. """ packet = Icmp4Assembler( - type=ICMP4_UNREACHABLE, - code=ICMP4_UNREACHABLE__PORT, - un_data=b"0123456789ABCDEF", + type=Icmp4Type.UNREACHABLE, + code=Icmp4UnreachableCode.PORT, + message=Icmp4UnreachablePortMessageAssembler( + data=b"0123456789ABCDEF", + ), + ) + self.assertEqual( + len(packet), ICMP4_HEADER_LEN + ICMP4_UNRECHABLE_MESSAGE_LEN + 16 ) - self.assertEqual(len(packet), ICMP4_UNREACHABLE_LEN + 16) def test_icmp4_fpa____len____echo_request(self) -> None: """ Test the '__len__() dudner for the 'Echo Request' message. """ packet = Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=0, - ec_data=b"0123456789ABCDEF", + type=Icmp4Type.ECHO_REQUEST, + code=Icmp4EchoRequestCode.DEFAULT, + message=Icmp4EchoRequestMessageAssembler( + data=b"0123456789ABCDEF", + ), + ) + self.assertEqual( + len(packet), ICMP4_HEADER_LEN + ICMP4_ECHO_REQUEST_MESSAGE_LEN + 16 ) - self.assertEqual(len(packet), ICMP4_ECHO_REQUEST_LEN + 16) def test_icmp4_fpa____str____echo_reply(self) -> None: """ Test the '__str__()' dunder for the 'Echo Reply' message. """ packet = Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=0, - ec_id=12345, - ec_seq=54321, - ec_data=b"0123456789ABCDEF", + type=Icmp4Type.ECHO_REPLY, + code=Icmp4EchoReplyCode.DEFAULT, + message=Icmp4EchoReplyMessageAssembler( + id=12345, + seq=54321, + data=b"0123456789ABCDEF", + ), ) self.assertEqual( - str(packet), "ICMPv4 0/0 (echo_reply), id 12345, seq 54321, dlen 16" + str(packet), "ICMPv4 0/0 (echo reply), id 12345, seq 54321, dlen 16" ) def test_icmp4_fpa____str____unreachable_port(self) -> None: @@ -343,33 +286,41 @@ def test_icmp4_fpa____str____unreachable_port(self) -> None: Test the '__str__() dunder for the 'Unreachable Port' message. """ packet = Icmp4Assembler( - type=ICMP4_UNREACHABLE, - code=ICMP4_UNREACHABLE__PORT, - un_data=b"0123456789ABCDEF", + type=Icmp4Type.UNREACHABLE, + code=Icmp4UnreachableCode.PORT, + message=Icmp4UnreachablePortMessageAssembler( + data=b"0123456789ABCDEF", + ), ) - self.assertEqual(str(packet), "ICMPv4 3/3 (unreachable_port), dlen 16") + self.assertEqual(str(packet), "ICMPv4 3/3 (port unreachable), dlen 16") def test_icmp4_fpa____str____echo_request(self) -> None: """ Test the '__str__()' dunder for the 'Echo Request' message. """ packet = Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=0, - ec_id=12345, - ec_seq=54321, - ec_data=b"0123456789ABCDEF", + type=Icmp4Type.ECHO_REQUEST, + code=Icmp4EchoRequestCode.DEFAULT, + message=Icmp4EchoRequestMessageAssembler( + id=12345, + seq=54321, + data=b"0123456789ABCDEF", + ), ) self.assertEqual( str(packet), - "ICMPv4 8/0 (echo_request), id 12345, seq 54321, dlen 16", + "ICMPv4 8/0 (echo request), id 12345, seq 54321, dlen 16", ) def test_icmp4_fpa__tracker_getter(self) -> None: """ Test the '_tracker' attribute getter. """ - packet = Icmp4Assembler() + packet = Icmp4Assembler( + type=Icmp4Type.ECHO_REQUEST, + code=Icmp4EchoRequestCode.DEFAULT, + message=Icmp4EchoRequestMessageAssembler(), + ) self.assertTrue( repr(packet.tracker).startswith("Tracker(serial='TX") ) @@ -379,11 +330,13 @@ def test_icmp4_fpa__assemble__echo_reply(self) -> None: Test the 'assemble()' method for the 'Echo Reply' message. """ packet = Icmp4Assembler( - type=ICMP4_ECHO_REPLY, - code=0, - ec_id=12345, - ec_seq=54321, - ec_data=b"0123456789ABCDEF", + type=Icmp4Type.ECHO_REPLY, + code=Icmp4EchoReplyCode.DEFAULT, + message=Icmp4EchoReplyMessageAssembler( + id=12345, + seq=54321, + data=b"0123456789ABCDEF", + ), ) frame = memoryview(bytearray(len(packet))) packet.assemble(frame) @@ -394,9 +347,11 @@ def test_icmp4_fpa__asssemble__unreachable_port(self) -> None: Test the 'assemble()' method for the 'Unreachable Port' message. """ packet = Icmp4Assembler( - type=ICMP4_UNREACHABLE, - code=ICMP4_UNREACHABLE__PORT, - un_data=b"0123456789ABCDEF", + type=Icmp4Type.UNREACHABLE, + code=Icmp4UnreachableCode.PORT, + message=Icmp4UnreachablePortMessageAssembler( + data=b"0123456789ABCDEF", + ), ) frame = memoryview(bytearray(len(packet))) packet.assemble(frame) @@ -409,11 +364,13 @@ def test_icmp4_fpa__assemble__echo_request(self) -> None: Test the 'assemble()' method for the 'Echo Request' message. """ packet = Icmp4Assembler( - type=ICMP4_ECHO_REQUEST, - code=0, - ec_id=12345, - ec_seq=54321, - ec_data=b"0123456789ABCDEF", + type=Icmp4Type.ECHO_REQUEST, + code=Icmp4EchoRequestCode.DEFAULT, + message=Icmp4EchoRequestMessageAssembler( + id=12345, + seq=54321, + data=b"0123456789ABCDEF", + ), ) frame = memoryview(bytearray(len(packet))) packet.assemble(frame) diff --git a/tests/unit/protocols__icmp4__phtx.py b/tests/unit/protocols__icmp4__phtx.py index 06792888..c1b5d31d 100755 --- a/tests/unit/protocols__icmp4__phtx.py +++ b/tests/unit/protocols__icmp4__phtx.py @@ -36,11 +36,16 @@ from pytcp.lib.packet_stats import PacketStatsTx from pytcp.lib.tx_status import TxStatus +from pytcp.protocols.icmp4.fpa import ( + Icmp4EchoReplyMessageAssembler, + Icmp4EchoRequestMessageAssembler, + Icmp4UnreachablePortMessageAssembler, +) from pytcp.protocols.icmp4.ps import ( - ICMP4_ECHO_REPLY, - ICMP4_ECHO_REQUEST, - ICMP4_UNREACHABLE, - ICMP4_UNREACHABLE__PORT, + Icmp4EchoReplyCode, + Icmp4EchoRequestCode, + Icmp4Type, + Icmp4UnreachableCode, ) from pytcp.subsystems.packet_handler import PacketHandler from tests.unit.mock_network import ( @@ -78,10 +83,13 @@ def test_icmp4_phtx__ip4_icmp4_echo_request(self) -> None: tx_status = self.packet_handler._phtx_icmp4( ip4_src=self.mns.stack_ip4_host.address, ip4_dst=self.mns.host_a_ip4_address, - icmp4_type=ICMP4_ECHO_REQUEST, - icmp4_ec_id=12345, - icmp4_ec_seq=54320, - icmp4_ec_data=b"0123456789ABCDEF" * 20, + icmp4_type=Icmp4Type.ECHO_REQUEST, + icmp4_code=Icmp4EchoRequestCode.DEFAULT, + icmp4_message=Icmp4EchoRequestMessageAssembler( + id=12345, + seq=54320, + data=b"0123456789ABCDEF" * 20, + ), ) self.assertEqual(tx_status, TxStatus.PASSED__ETHER__TO_TX_RING) self.assertEqual( @@ -108,10 +116,13 @@ def test_icmp4_phtx__ip4_icmp4_echo_reply(self) -> None: tx_status = self.packet_handler._phtx_icmp4( ip4_src=self.mns.stack_ip4_host.address, ip4_dst=self.mns.host_a_ip4_address, - icmp4_type=ICMP4_ECHO_REPLY, - icmp4_ec_id=12345, - icmp4_ec_seq=54320, - icmp4_ec_data=b"0123456789ABCDEF" * 20, + icmp4_type=Icmp4Type.ECHO_REPLY, + icmp4_code=Icmp4EchoReplyCode.DEFAULT, + icmp4_message=Icmp4EchoReplyMessageAssembler( + id=12345, + seq=54320, + data=b"0123456789ABCDEF" * 20, + ), ) self.assertEqual(tx_status, TxStatus.PASSED__ETHER__TO_TX_RING) self.assertEqual( @@ -138,9 +149,11 @@ def test_icmp4_phtx__ip4_icmp4_unreachable_port(self) -> None: tx_status = self.packet_handler._phtx_icmp4( ip4_src=self.mns.stack_ip4_host.address, ip4_dst=self.mns.host_a_ip4_address, - icmp4_type=ICMP4_UNREACHABLE, - icmp4_code=ICMP4_UNREACHABLE__PORT, - icmp4_un_data=b"0123456789ABCDEF" * 100, + icmp4_type=Icmp4Type.UNREACHABLE, + icmp4_code=Icmp4UnreachableCode.PORT, + icmp4_message=Icmp4UnreachablePortMessageAssembler( + data=b"0123456789ABCDEF" * 100, + ), ) self.assertEqual(tx_status, TxStatus.PASSED__ETHER__TO_TX_RING) self.assertEqual( diff --git a/tests/unit/protocols__ip4__fpa.py b/tests/unit/protocols__ip4__fpa.py index 137eb220..7bb00289 100755 --- a/tests/unit/protocols__ip4__fpa.py +++ b/tests/unit/protocols__ip4__fpa.py @@ -36,8 +36,12 @@ from pytcp.config import IP4_DEFAULT_TTL from pytcp.lib.ip4_address import Ip4Address from pytcp.lib.tracker import Tracker -from pytcp.protocols.ether.ps import ETHER_TYPE_IP4 -from pytcp.protocols.icmp4.fpa import Icmp4Assembler +from pytcp.protocols.ether.ps import EtherType +from pytcp.protocols.icmp4.fpa import ( + Icmp4Assembler, + Icmp4EchoReplyMessageAssembler, +) +from pytcp.protocols.icmp4.ps import Icmp4EchoReplyCode, Icmp4Type from pytcp.protocols.ip4.fpa import ( Ip4Assembler, Ip4FragAssembler, @@ -68,7 +72,7 @@ def test_ip4_fpa__ethertype(self) -> None: Make sure the 'Ip4Assembler' class has the proper 'ethertype' value assigned. """ - self.assertEqual(Ip4Assembler.ether_type, ETHER_TYPE_IP4) + self.assertEqual(Ip4Assembler.ether_type, EtherType.IP4) def test_ip4_fpa____init__(self) -> None: """ @@ -203,7 +207,13 @@ def test_ip4_fpa____init____assert_proto_icmp4(self) -> None: """ Test assertion for the carried packet 'ip4_proto' attribute. """ - Ip4Assembler(carried_packet=Icmp4Assembler()) + Ip4Assembler( + carried_packet=Icmp4Assembler( + type=Icmp4Type.ECHO_REPLY, + code=Icmp4EchoReplyCode.DEFAULT, + message=Icmp4EchoReplyMessageAssembler(), + ) + ) def test_ip4_fpa____init____assert_proto_raw(self) -> None: """ @@ -417,7 +427,7 @@ def test_ip4_frag_fpa__ethertype(self) -> None: """ Test the 'ethertype' property of the 'Ip4FragAssembler' class. """ - self.assertEqual(Ip4Assembler.ether_type, ETHER_TYPE_IP4) + self.assertEqual(Ip4Assembler.ether_type, EtherType.IP4) def test_ip4_frag_fpa____init__(self) -> None: """ diff --git a/tests/unit/protocols__ip6__fpa.py b/tests/unit/protocols__ip6__fpa.py index 2b7f482a..41170a09 100755 --- a/tests/unit/protocols__ip6__fpa.py +++ b/tests/unit/protocols__ip6__fpa.py @@ -36,7 +36,7 @@ from pytcp.config import IP6_DEFAULT_HOP from pytcp.lib.ip6_address import Ip6Address from pytcp.lib.tracker import Tracker -from pytcp.protocols.ether.ps import ETHER_TYPE_IP6 +from pytcp.protocols.ether.ps import EtherType from pytcp.protocols.icmp6.fpa import Icmp6Assembler from pytcp.protocols.ip6.fpa import Ip6Assembler from pytcp.protocols.ip6.ps import IP6_HEADER_LEN, IP6_NEXT_RAW @@ -55,7 +55,7 @@ def test_ip6_fpa__ethertype(self) -> None: Make sure the 'Ip6Assembler' class has the proper 'ethertype' value assigned. """ - self.assertEqual(Ip6Assembler.ether_type, ETHER_TYPE_IP6) + self.assertEqual(Ip6Assembler.ether_type, EtherType.IP6) def test_ip6_fpa____init__(self) -> None: """