diff --git a/.gitignore b/.gitignore
index 1b057405..92afa22f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,2 @@
-
__pycache__/
+venv/
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 49d4639f..176ca276 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -2,5 +2,8 @@
"python.analysis.extraPaths": [
"./pytcp"
],
- "python.formatting.provider": "black"
+ "python.formatting.provider": "none",
+ "[python]": {
+ "editor.defaultFormatter": "ms-python.black-formatter"
+ }
}
\ No newline at end of file
diff --git a/pytcp/config.py b/pytcp/config.py
index 985ffb0f..2cd4ddeb 100755
--- a/pytcp/config.py
+++ b/pytcp/config.py
@@ -73,22 +73,6 @@
}
LOG_DEBUG = False
-# Packet integrity sanity check, if enabled it protects the protocol parsers
-# from being exposed to malformed or malicious packets that could cause them
-# to crash during packet parsing. It progessively check appropriate length
-# fields and ensure they are set within sane boundaries. It also checks
-# packet's actual header/options/data lengths against above values and default
-# minimum/maximum lengths for given protocol. Also packet options (if any) are
-# checked in similar fashion to ensure they will not exploit or crash parser.
-PACKET_INTEGRITY_CHECK = True
-
-# Packet sanity check, if enabled it validates packet's fields to detect invalid
-# values or invalid combinations of values. For example in TCP/UDP it drops
-# packets with port set to 0, in TCP it drop packet with SYN and FIN flags set
-# simultaneously, for ICMPv6 it provides very detailed check of messages
-# integrity.
-PACKET_SANITY_CHECK = True
-
# Drop IPv4 packets containing options - this seems to be widely adopted
# security feature. Stack parses but doesn't support IPv4 options as they are
# mostly useless anyway.
@@ -144,3 +128,6 @@
# Native support for UDP Echo (used for packet flow unit testing only and should
# always be disabled).
UDP_ECHO_NATIVE_DISABLE = True
+
+# LRU cache size, used by packet parsers to cache parsed field values.
+LRU_CACHE_SIZE = 16
diff --git a/pytcp/lib/errors.py b/pytcp/lib/errors.py
new file mode 100755
index 00000000..358f2ac5
--- /dev/null
+++ b/pytcp/lib/errors.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+
+############################################################################
+# #
+# PyTCP - Python TCP/IP stack #
+# Copyright (C) 2020-present Sebastian Majewski #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see . #
+# #
+# Author's email: ccie18643@gmail.com #
+# Github repository: https://github.com/ccie18643/PyTCP #
+# #
+############################################################################
+
+
+"""
+Module contains methods supporting errors.
+
+pytcp/lib/errors.py
+
+ver 2.7
+"""
+
+
+from __future__ import annotations
+
+
+class PyTcpError(Exception):
+ """
+ Base class for all PyTCP exceptions.
+ """
+
+ ...
+
+
+class UnsupportedCaseError(PyTcpError):
+ """
+ Exception raised when the not supposed to be reached
+ 'match' case is being reached for whatever reason.
+ """
+
+ ...
+
+
+class PacketValidationError(PyTcpError):
+ """
+ Exception raised when packet validation fails.
+ """
+
+ ...
+
+
+class PacketIntegrityError(PacketValidationError):
+ """
+ Exception raised when integrity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[INTEGRITY]" + message)
+
+
+class PacketSanityError(PacketValidationError):
+ """
+ Exception raised when sanity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[SANITY]" + message)
diff --git a/pytcp/lib/protocol_enum.py b/pytcp/lib/protocol_enum.py
new file mode 100755
index 00000000..59507074
--- /dev/null
+++ b/pytcp/lib/protocol_enum.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+
+############################################################################
+# #
+# PyTCP - Python TCP/IP stack #
+# Copyright (C) 2020-present Sebastian Majewski #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see . #
+# #
+# Author's email: ccie18643@gmail.com #
+# Github repository: https://github.com/ccie18643/PyTCP #
+# #
+############################################################################
+
+
+"""
+Module contains the ProtocolEnum class.
+
+pytcp/lib/protocol_enum.py
+
+ver 2.7
+"""
+
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import Self
+
+
+class ProtocolEnum(Enum):
+ def __int__(self) -> int:
+ return int(self.value)
+
+ def __str__(self) -> str:
+ return str(self.value)
+
+ @staticmethod
+ def _extract(frame: bytes) -> int:
+ raise NotImplementedError
+
+ @classmethod
+ def from_frame(cls, /, frame: bytes) -> Self:
+ return cls(cls._extract(frame))
+
+ @classmethod
+ def sanity_check(cls, /, frame: bytes) -> bool:
+ return cls._extract(frame) in cls
diff --git a/pytcp/protocols/arp/fpa.py b/pytcp/protocols/arp/fpa.py
index 5cc0cef5..2423fdd5 100755
--- a/pytcp/protocols/arp/fpa.py
+++ b/pytcp/protocols/arp/fpa.py
@@ -41,17 +41,22 @@
from pytcp.lib.ip4_address import Ip4Address
from pytcp.lib.mac_address import MacAddress
from pytcp.lib.tracker import Tracker
-from pytcp.protocols.arp.ps import ARP_HEADER_LEN, ARP_OP_REPLY, ARP_OP_REQUEST
-from pytcp.protocols.ether.ps import ETHER_TYPE_ARP
-
-
-class ArpAssembler:
+from pytcp.protocols.arp.ps import (
+ ARP_HEADER_LEN,
+ Arp,
+ ArpHardwareLength,
+ ArpHardwareType,
+ ArpOperation,
+ ArpProtocolLength,
+ ArpProtocolType,
+)
+
+
+class ArpAssembler(Arp):
"""
ARP packet assembler support class.
"""
- ether_type = ETHER_TYPE_ARP
-
def __init__(
self,
*,
@@ -59,69 +64,54 @@ def __init__(
spa: Ip4Address = Ip4Address(0),
tha: MacAddress = MacAddress(0),
tpa: Ip4Address = Ip4Address(0),
- oper: int = ARP_OP_REQUEST,
+ oper: ArpOperation = ArpOperation.REQUEST,
echo_tracker: Tracker | None = None,
) -> None:
"""
Class constructor.
"""
- assert oper in (ARP_OP_REQUEST, ARP_OP_REPLY), f"{oper=}"
-
self._tracker = Tracker(prefix="TX", echo_tracker=echo_tracker)
- self._hrtype: int = 1
- self._prtype: int = 0x0800
- self._hrlen: int = 6
- self._prlen: int = 4
- self._oper: int = oper
- self._sha: MacAddress = sha
- self._spa: Ip4Address = spa
- self._tha: MacAddress = tha
- self._tpa: Ip4Address = tpa
+ self._hrtype = ArpHardwareType.ETHERNET
+ self._prtype = ArpProtocolType.IP4
+ self._hrlen = ArpHardwareLength.ETHERNET
+ self._prlen = ArpProtocolLength.IP4
+ self._oper = oper
+ self._sha = sha
+ self._spa = spa
+ self._tha = tha
+ self._tpa = tpa
def __len__(self) -> int:
"""
Length of the packet.
"""
- return ARP_HEADER_LEN
- def __str__(self) -> str:
- """
- Packet log string.
- """
- if self._oper == ARP_OP_REQUEST:
- return (
- f"ARP request {self._spa} / {self._sha}"
- f" > {self._tpa} / {self._tha}"
- )
- if self._oper == ARP_OP_REPLY:
- return (
- f"ARP reply {self._spa} / {self._sha}"
- f" > {self._tpa} / {self._tha}"
- )
- return f"ARP request unknown operation {self._oper}"
+ return ARP_HEADER_LEN
@property
def tracker(self) -> Tracker:
"""
Getter for the '_tracker' property.
"""
+
return self._tracker
- def assemble(self, frame: memoryview) -> None:
+ def assemble(self, /, frame: memoryview) -> None:
"""
Assemble packet into the raw form.
"""
+
struct.pack_into(
"!HH BBH 6s 4s 6s 4s",
frame,
0,
- self._hrtype,
- self._prtype,
- self._hrlen,
- self._prlen,
- self._oper,
+ int(self._hrtype),
+ int(self._prtype),
+ int(self._hrlen),
+ int(self._prlen),
+ int(self._oper),
bytes(self._sha),
bytes(self._spa),
bytes(self._tha),
diff --git a/pytcp/protocols/arp/fpp.py b/pytcp/protocols/arp/fpp.py
index a572c225..46709b73 100755
--- a/pytcp/protocols/arp/fpp.py
+++ b/pytcp/protocols/arp/fpp.py
@@ -38,180 +38,137 @@
from __future__ import annotations
-import struct
from typing import TYPE_CHECKING
-from pytcp import config
+from pytcp.lib.errors import PacketIntegrityError, PacketSanityError
from pytcp.lib.ip4_address import Ip4Address
from pytcp.lib.mac_address import MacAddress
-from pytcp.protocols.arp.ps import ARP_HEADER_LEN, ARP_OP_REPLY, ARP_OP_REQUEST
+from pytcp.protocols.arp.ps import (
+ ARP_HEADER_LEN,
+ Arp,
+ ArpHardwareLength,
+ ArpHardwareType,
+ ArpOperation,
+ ArpProtocolLength,
+ ArpProtocolType,
+)
if TYPE_CHECKING:
from pytcp.lib.packet import PacketRx
-class ArpParser:
+class ArpIntegrityError(PacketIntegrityError):
"""
- ARP packet parser class.
+ Exception raised when ARP packet integrity check fails.
"""
- def __init__(self, packet_rx: PacketRx) -> None:
- """
- Class constructor.
- """
-
- packet_rx.arp = self
-
- self._frame = packet_rx.frame
+ def __init__(self, message: str):
+ super().__init__("[ARP] " + message)
- packet_rx.parse_failed = (
- self._packet_integrity_check() or self._packet_sanity_check()
- )
- def __len__(self) -> int:
- """
- Number of bytes remaining in the frame.
- """
- return len(self._frame)
+class ArpSanityError(PacketSanityError):
+ """
+ Exception raised when ARP packet sanity check fails.
+ """
- def __str__(self) -> str:
- """
- Packet log string.
- """
- if self.oper == ARP_OP_REQUEST:
- return (
- f"ARP request {self.spa} / {self.sha}"
- f" > {self.tpa} / {self.tha}"
- )
- if self.oper == ARP_OP_REPLY:
- return (
- f"ARP reply {self.spa} / {self.sha}"
- f" > {self.tpa} / {self.tha}"
- )
- return f"ARP request unknown operation {self.oper}"
+ def __init__(self, message: str):
+ super().__init__("[ARP] " + message)
- @property
- def hrtype(self) -> int:
- """
- Read the 'Hardware address type' field.
- """
- if "_cache__hrtype" not in self.__dict__:
- self._cache__hrtype: int = struct.unpack("!H", self._frame[0:2])[0]
- return self._cache__hrtype
- @property
- def prtype(self) -> int:
- """
- Read the 'Protocol address type' field.
- """
- if "_cache__prtype" not in self.__dict__:
- self._cache__prtype: int = struct.unpack("!H", self._frame[2:4])[0]
- return self._cache__prtype
+class ArpParser(Arp):
+ """
+ ARP packet parser class.
+ """
- @property
- def hrlen(self) -> int:
+ def __init__(self, /, packet_rx: PacketRx) -> None:
"""
- Read the 'Hardware address length' field.
+ Class constructor.
"""
- return self._frame[4]
- @property
- def prlen(self) -> int:
- """
- Read the 'Protocol address length' field.
- """
- return self._frame[5]
+ packet_rx.arp = self
- @property
- def oper(self) -> int:
- """
- Read the 'Operation' field.
- """
- if "_cache__oper" not in self.__dict__:
- self._cache__oper: int = struct.unpack("!H", self._frame[6:8])[0]
- return self._cache__oper
+ self._frame = packet_rx.frame
- @property
- def sha(self) -> MacAddress:
- """
- Read the 'Sender hardware address' field.
- """
- if "_cache__sha" not in self.__dict__:
- self._cache__sha = MacAddress(self._frame[8:14])
- return self._cache__sha
+ self._packet_integrity_check()
+ self._packet_sanity_check()
- @property
- def spa(self) -> Ip4Address:
- """
- Read the 'Sender protocol address' field.
- """
- if "_cache__spa" not in self.__dict__:
- self._cache__spa = Ip4Address(self._frame[14:18])
- return self._cache__spa
+ self._hrtype = ArpHardwareType.from_frame(self._frame)
+ self._prtype = ArpProtocolType.from_frame(self._frame)
+ self._hrlen = ArpHardwareLength.from_frame(self._frame)
+ self._prlen = ArpProtocolLength.from_frame(self._frame)
+ self._oper = ArpOperation.from_frame(self._frame)
+ self._sha = MacAddress(self._frame[8:14])
+ self._spa = Ip4Address(self._frame[14:18])
+ self._tha = MacAddress(self._frame[18:24])
+ self._tpa = Ip4Address(self._frame[24:28])
- @property
- def tha(self) -> MacAddress:
+ def __len__(self) -> int:
"""
- Read the 'Target hardware address' field.
+ Number of bytes remaining in the frame.
"""
- if "_cache__tha" not in self.__dict__:
- self._cache__tha = MacAddress(self._frame[18:24])
- return self._cache__tha
- @property
- def tpa(self) -> Ip4Address:
- """
- Read the 'Target protocol address' field.
- """
- if "_cache__tpa" not in self.__dict__:
- self._cache__tpa = Ip4Address(self._frame[24:28])
- return self._cache__tpa
+ return len(self._frame)
@property
def packet_copy(self) -> bytes:
"""
Read the whole packet.
"""
+
if "_cache__packet_copy" not in self.__dict__:
self._cache__packet_copy = bytes(self._frame[:ARP_HEADER_LEN])
+
return self._cache__packet_copy
- def _packet_integrity_check(self) -> str:
+ def _packet_integrity_check(self) -> None:
"""
Packet integrity check to be run on raw packet prior to parsing
- to make sure parsing is safe
+ to make sure parsing is safe.
"""
- if not config.PACKET_INTEGRITY_CHECK:
- return ""
-
if len(self) < ARP_HEADER_LEN:
- return "ARP integrity - wrong packet length (I)"
-
- return ""
+ raise ArpIntegrityError(
+ "The minimum packet length must be "
+ f"'{ARP_HEADER_LEN}' bytes, got {len(self)} bytes."
+ )
- def _packet_sanity_check(self) -> str:
+ def _packet_sanity_check(self) -> None:
"""
- Packet sanity check to be run on parsed packet to make sure packet's
- fields contain sane values
+ Packet sanity check to be run on parsed packet to make sure
+ packet's fields contain sane values.
"""
- if not config.PACKET_SANITY_CHECK:
- return ""
-
- if self.hrtype != 1:
- return "ARP sanity - 'arp_hrtype' must be 1"
-
- if self.prtype != 0x0800:
- return "ARP sanity - 'arp_prtype' must be 0x0800"
+ if not ArpHardwareType.sanity_check(self._frame):
+ raise ArpSanityError(
+ "The 'arp_hrtype' field value must be one of "
+ f"{[hrtype.value for hrtype in ArpHardwareType]}, "
+ f"got '{self.hrtype}'."
+ )
- if self.hrlen != 6:
- return "ARP sanity - 'arp_hrlen' must be 6"
+ if not ArpProtocolType.sanity_check(self._frame):
+ raise ArpSanityError(
+ "The 'arp_prtype' field value must be one of "
+ f"{[prtype.value for prtype in ArpHardwareType]}, "
+ f"got '{self.prtype}'."
+ )
- if self.prlen != 4:
- return "ARP sanity - 'arp_prlen' must be 4"
+ if not ArpHardwareLength.sanity_check(self._frame):
+ raise ArpSanityError(
+ "The 'arp_hrlen' field value must be one of "
+ f"{[hrlen.value for hrlen in ArpHardwareLength]}, "
+ f"got '{self.hrlen}'."
+ )
- if self.oper not in {1, 2}:
- return "ARP sanity - 'oper' must be [1-2]"
+ if not ArpProtocolLength.sanity_check(self._frame):
+ raise ArpSanityError(
+ "The 'arp_prlen' field value must be one of "
+ f"{[prlen.value for prlen in ArpProtocolLength]}, "
+ f"got '{self.prlen}'."
+ )
- return ""
+ if not ArpOperation.sanity_check(self._frame):
+ raise ArpSanityError(
+ "The 'oper' field value must be one of "
+ f"{[oper.value for oper in ArpOperation]}, "
+ f"got '{self.oper}'."
+ )
diff --git a/pytcp/protocols/arp/phrx.py b/pytcp/protocols/arp/phrx.py
index 9ed0e8fc..e986f6a7 100755
--- a/pytcp/protocols/arp/phrx.py
+++ b/pytcp/protocols/arp/phrx.py
@@ -44,9 +44,10 @@
from pytcp import config
from pytcp.lib import stack
+from pytcp.lib.errors import PacketValidationError
from pytcp.lib.logger import log
from pytcp.protocols.arp.fpp import ArpParser
-from pytcp.protocols.arp.ps import ARP_OP_REPLY, ARP_OP_REQUEST
+from pytcp.protocols.arp.ps import ArpOperation
if TYPE_CHECKING:
from pytcp.lib.packet import PacketRx
@@ -60,110 +61,124 @@ def _phrx_arp(self: PacketHandler, packet_rx: PacketRx) -> None:
self.packet_stats_rx.arp__pre_parse += 1
- ArpParser(packet_rx)
-
- if packet_rx.parse_failed:
+ try:
+ ArpParser(packet_rx)
+ except PacketValidationError as error:
self.packet_stats_rx.arp__failed_parse__drop += 1
__debug__ and log(
"arp",
- f"{packet_rx.tracker} - {packet_rx.parse_failed}>",
+ f"{packet_rx.tracker} - {error}>",
)
return
__debug__ and log("arp", f"{packet_rx.tracker} - {packet_rx.arp}")
- if packet_rx.arp.oper == ARP_OP_REQUEST:
- self.packet_stats_rx.arp__op_request += 1
- # Check if request contains our IP address in SPA field,
- # this indicates IP address conflict
- if packet_rx.arp.spa in self.ip4_unicast:
- self.packet_stats_rx.arp__op_request__ip_conflict += 1
- __debug__ and log(
- "arp",
- f"{packet_rx.tracker} - IP ({packet_rx.arp.spa}) "
- f"conflict detected with host at {packet_rx.arp.sha}>",
- )
- return
+ match packet_rx.arp.oper:
+ case ArpOperation.REQUEST:
+ _phrx_arp__request(self, packet_rx)
+ case ArpOperation.REPLY:
+ _phrx_arp__reply(self, packet_rx)
- # Check if the request is for one of our IP addresses,
- # if so the craft ARP reply packet and send it out
- if packet_rx.arp.tpa in self.ip4_unicast:
- self.packet_stats_rx.arp__op_request__tpa_stack__respond += 1
- self._phtx_arp(
- ether_src=self.mac_unicast,
- ether_dst=packet_rx.arp.sha,
- arp_oper=ARP_OP_REPLY,
- arp_sha=self.mac_unicast,
- arp_spa=packet_rx.arp.tpa,
- arp_tha=packet_rx.arp.sha,
- arp_tpa=packet_rx.arp.spa,
- echo_tracker=packet_rx.tracker,
- )
- # Update ARP cache with the mapping learned from the received
- # ARP request that was destined to this stack
- if config.ARP_CACHE_UPDATE_FROM_DIRECT_REQUEST:
- self.packet_stats_rx.arp__op_request__update_arp_cache += 1
- __debug__ and log(
- "arp",
- f"{packet_rx.tracker} - Adding/refreshing "
- "ARP cache entry from direct request "
- f"- {packet_rx.arp.spa} -> {packet_rx.arp.sha}>",
- )
- stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha)
- return
+def _phrx_arp__request(self: PacketHandler, packet_rx: PacketRx) -> None:
+ """
+ Handle inbound ARP request packets.
+ """
- else:
- # Drop packet if TPA does not match one of our IP addresses
- self.packet_stats_rx.arp__op_request__tpa_unknown__drop += 1
- return
+ self.packet_stats_rx.arp__op_request += 1
+ # Check if request contains our IP address in SPA field,
+ # this indicates IP address conflict
+ if packet_rx.arp.spa in self.ip4_unicast:
+ self.packet_stats_rx.arp__op_request__ip_conflict += 1
+ __debug__ and log(
+ "arp",
+ f"{packet_rx.tracker} - IP ({packet_rx.arp.spa}) "
+ f"conflict detected with host at {packet_rx.arp.sha}>",
+ )
+ return
- # Handle ARP reply
- elif packet_rx.arp.oper == ARP_OP_REPLY:
- self.packet_stats_rx.arp__op_reply += 1
- # Check for ARP reply that is response to our ARP probe, this indicates
- # the IP address we trying to claim is in use
- if packet_rx.ether.dst == self.mac_unicast:
- if (
- packet_rx.arp.spa
- in [_.address for _ in self.ip4_host_candidate]
- and packet_rx.arp.tha == self.mac_unicast
- and packet_rx.arp.tpa.is_unspecified
- ):
- self.packet_stats_rx.arp__op_reply__ip_conflict += 1
- __debug__ and log(
- "arp",
- f"{packet_rx.tracker} - ARP Probe detected "
- f"conflict for IP {packet_rx.arp.spa} with host at "
- f"{packet_rx.arp.sha}>",
- )
- stack.arp_probe_unicast_conflict.add(packet_rx.arp.spa)
- return
-
- # Update ARP cache with mapping received as direct ARP reply
- if packet_rx.ether.dst == self.mac_unicast:
- self.packet_stats_rx.arp__op_reply__update_arp_cache += 1
+ # Check if the request is for one of our IP addresses,
+ # if so the craft ARP reply packet and send it out
+ if packet_rx.arp.tpa in self.ip4_unicast:
+ self.packet_stats_rx.arp__op_request__tpa_stack__respond += 1
+ self._phtx_arp(
+ ether_src=self.mac_unicast,
+ ether_dst=packet_rx.arp.sha,
+ arp_oper=ArpOperation.REPLY,
+ arp_sha=self.mac_unicast,
+ arp_spa=packet_rx.arp.tpa,
+ arp_tha=packet_rx.arp.sha,
+ arp_tpa=packet_rx.arp.spa,
+ echo_tracker=packet_rx.tracker,
+ )
+
+ # Update ARP cache with the mapping learned from the received
+ # ARP request that was destined to this stack
+ if config.ARP_CACHE_UPDATE_FROM_DIRECT_REQUEST:
+ self.packet_stats_rx.arp__op_request__update_arp_cache += 1
__debug__ and log(
"arp",
- f"{packet_rx.tracker} - Adding/refreshing ARP cache entry "
- f"from direct reply - {packet_rx.arp.spa} "
- f"-> {packet_rx.arp.sha}",
+ f"{packet_rx.tracker} - Adding/refreshing "
+ "ARP cache entry from direct request "
+ f"- {packet_rx.arp.spa} -> {packet_rx.arp.sha}>",
)
stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha)
- return
+ return
- # Update ARP cache with mapping received as gratuitous ARP reply
+ else:
+ # Drop packet if TPA does not match one of our IP addresses
+ self.packet_stats_rx.arp__op_request__tpa_unknown__drop += 1
+ return
+
+
+def _phrx_arp__reply(self: PacketHandler, packet_rx: PacketRx) -> None:
+ """
+ Handle inbound ARP reply packets.
+ """
+
+ self.packet_stats_rx.arp__op_reply += 1
+ # Check for ARP reply that is response to our ARP probe, this indicates
+ # the IP address we trying to claim is in use
+ if packet_rx.ether.dst == self.mac_unicast:
if (
- packet_rx.ether.dst.is_broadcast
- and packet_rx.arp.spa == packet_rx.arp.tpa
- and config.ARP_CACHE_UPDATE_FROM_GRATUITIOUS_REPLY
+ packet_rx.arp.spa in [_.address for _ in self.ip4_host_candidate]
+ and packet_rx.arp.tha == self.mac_unicast
+ and packet_rx.arp.tpa.is_unspecified
):
- self.packet_stats_rx.arp__op_reply__update_arp_cache_gratuitous += 1
+ self.packet_stats_rx.arp__op_reply__ip_conflict += 1
__debug__ and log(
"arp",
- f"{packet_rx.tracker} - Adding/refreshing ARP cache entry "
- f"from gratuitous reply - {packet_rx.arp.spa} "
- f"-> {packet_rx.arp.sha}",
+ f"{packet_rx.tracker} - ARP Probe detected "
+ f"conflict for IP {packet_rx.arp.spa} with host at "
+ f"{packet_rx.arp.sha}>",
)
- stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha)
+ stack.arp_probe_unicast_conflict.add(packet_rx.arp.spa)
return
+
+ # Update ARP cache with mapping received as direct ARP reply
+ if packet_rx.ether.dst == self.mac_unicast:
+ self.packet_stats_rx.arp__op_reply__update_arp_cache += 1
+ __debug__ and log(
+ "arp",
+ f"{packet_rx.tracker} - Adding/refreshing ARP cache entry "
+ f"from direct reply - {packet_rx.arp.spa} "
+ f"-> {packet_rx.arp.sha}",
+ )
+ stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha)
+ return
+
+ # Update ARP cache with mapping received as gratuitous ARP reply
+ if (
+ packet_rx.ether.dst.is_broadcast
+ and packet_rx.arp.spa == packet_rx.arp.tpa
+ and config.ARP_CACHE_UPDATE_FROM_GRATUITIOUS_REPLY
+ ):
+ self.packet_stats_rx.arp__op_reply__update_arp_cache_gratuitous += 1
+ __debug__ and log(
+ "arp",
+ f"{packet_rx.tracker} - Adding/refreshing ARP cache entry "
+ f"from gratuitous reply - {packet_rx.arp.spa} "
+ f"-> {packet_rx.arp.sha}",
+ )
+ stack.arp_cache.add_entry(packet_rx.arp.spa, packet_rx.arp.sha)
+ return
diff --git a/pytcp/protocols/arp/phtx.py b/pytcp/protocols/arp/phtx.py
index a2dc435d..a472345f 100755
--- a/pytcp/protocols/arp/phtx.py
+++ b/pytcp/protocols/arp/phtx.py
@@ -45,7 +45,7 @@
from pytcp.lib.tracker import Tracker
from pytcp.lib.tx_status import TxStatus
from pytcp.protocols.arp.fpa import ArpAssembler
-from pytcp.protocols.arp.ps import ARP_OP_REPLY, ARP_OP_REQUEST
+from pytcp.protocols.arp.ps import ArpOperation
if TYPE_CHECKING:
from pytcp.lib.ip4_address import Ip4Address
@@ -57,7 +57,7 @@ def _phtx_arp(
*,
ether_src: MacAddress,
ether_dst: MacAddress,
- arp_oper: int,
+ arp_oper: ArpOperation,
arp_sha: MacAddress,
arp_spa: Ip4Address,
arp_tha: MacAddress,
@@ -76,11 +76,11 @@ def _phtx_arp(
self.packet_stats_tx.arp__no_proto_support__drop += 1
return TxStatus.DROPED__ARP__NO_PROTOCOL_SUPPORT
- if arp_oper == ARP_OP_REQUEST:
- self.packet_stats_tx.arp__op_request__send += 1
-
- if arp_oper == ARP_OP_REPLY:
- self.packet_stats_tx.arp__op_reply__send += 1
+ match arp_oper:
+ case ArpOperation.REQUEST:
+ self.packet_stats_tx.arp__op_request__send += 1
+ case ArpOperation.REPLY:
+ self.packet_stats_tx.arp__op_reply__send += 1
arp_packet_tx = ArpAssembler(
oper=arp_oper,
diff --git a/pytcp/protocols/arp/ps.py b/pytcp/protocols/arp/ps.py
index c268cd3e..aca19ed2 100755
--- a/pytcp/protocols/arp/ps.py
+++ b/pytcp/protocols/arp/ps.py
@@ -33,7 +33,13 @@
"""
-from __future__ import annotations
+import struct
+from abc import ABC, abstractmethod
+
+from pytcp.lib.ip4_address import Ip4Address
+from pytcp.lib.mac_address import MacAddress
+from pytcp.lib.protocol_enum import ProtocolEnum
+from pytcp.protocols.ether.ps import EtherType
# ARP packet header - IPv4 stack version only
@@ -53,8 +59,168 @@
# | Target IP address |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-
ARP_HEADER_LEN = 28
-ARP_OP_REQUEST = 1
-ARP_OP_REPLY = 2
+
+class ArpHardwareType(ProtocolEnum):
+ ETHERNET = 1
+
+ @staticmethod
+ def _extract(frame: bytes) -> int:
+ return int(struct.unpack("!H", frame[0:2])[0])
+
+
+class ArpProtocolType(ProtocolEnum):
+ IP4 = 0x0800
+
+ @staticmethod
+ def _extract(frame: bytes) -> int:
+ return int(struct.unpack("!H", frame[2:4])[0])
+
+
+class ArpHardwareLength(ProtocolEnum):
+ ETHERNET = 6
+
+ @staticmethod
+ def _extract(frame: bytes) -> int:
+ return int(frame[4])
+
+
+class ArpProtocolLength(ProtocolEnum):
+ IP4 = 4
+
+ @staticmethod
+ def _extract(frame: bytes) -> int:
+ return int(frame[5])
+
+
+class ArpOperation(ProtocolEnum):
+ REQUEST = 1
+ REPLY = 2
+
+ @staticmethod
+ def _extract(frame: bytes) -> int:
+ return int(struct.unpack("!H", frame[6:8])[0])
+
+
+class Arp(ABC):
+ """
+ Base class for ARP packet parser and assembler classes.
+ """
+
+ _ether_type = EtherType.ARP
+
+ _hrtype: ArpHardwareType
+ _prtype: ArpProtocolType
+ _hrlen: ArpHardwareLength
+ _prlen: ArpProtocolLength
+ _oper: ArpOperation
+ _sha: MacAddress
+ _spa: Ip4Address
+ _tha: MacAddress
+ _tpa: Ip4Address
+
+ def __str__(self) -> str:
+ """
+ Packet log string.
+ """
+
+ match self._oper:
+ case ArpOperation.REQUEST:
+ return (
+ f"ARP request {self._spa} / {self._sha}"
+ f" > {self.tpa} / {self.tha}"
+ )
+ case ArpOperation.REPLY:
+ return (
+ f"ARP reply {self._spa} / {self._sha}"
+ f" > {self._tpa} / {self._tha}"
+ )
+
+ @abstractmethod
+ def __len__(self) -> int:
+ """
+ Length of the packet.
+ """
+
+ raise NotImplementedError
+
+ @property
+ def ether_type(self) -> EtherType:
+ """
+ Getter for the '_ether_type' property.
+ """
+
+ return self._ether_type
+
+ @property
+ def hrtype(self) -> ArpHardwareType:
+ """
+ Getter for the '_hrtype' property.
+ """
+
+ return self._hrtype
+
+ @property
+ def prtype(self) -> ArpProtocolType:
+ """
+ Getter for the '_prtype' property.
+ """
+
+ return self._prtype
+
+ @property
+ def hrlen(self) -> ArpHardwareLength:
+ """
+ Getter for the '_hrlen' property.
+ """
+
+ return self._hrlen
+
+ @property
+ def prlen(self) -> ArpProtocolLength:
+ """
+ Getter for the '_prlen' property.
+ """
+
+ return self._prlen
+
+ @property
+ def oper(self) -> ArpOperation:
+ """
+ Getter for the '_oper' property.
+ """
+
+ return self._oper
+
+ @property
+ def sha(self) -> MacAddress:
+ """
+ Getter for the '_sha' property.
+ """
+
+ return self._sha
+
+ @property
+ def spa(self) -> Ip4Address:
+ """
+ Getter for the '_spa' property.
+ """
+
+ return self._spa
+
+ @property
+ def tha(self) -> MacAddress:
+ """
+ Getter for the '_tha' property.
+ """
+
+ return self._tha
+
+ @property
+ def tpa(self) -> Ip4Address:
+ """
+ Getter for the '_tpa' property.
+ """
+
+ return self._tpa
diff --git a/pytcp/protocols/ether/fpa.py b/pytcp/protocols/ether/fpa.py
index 5490f08d..21f35251 100755
--- a/pytcp/protocols/ether/fpa.py
+++ b/pytcp/protocols/ether/fpa.py
@@ -32,21 +32,13 @@
ver 2.7
"""
-
from __future__ import annotations
import struct
from typing import TYPE_CHECKING
from pytcp.lib.mac_address import MacAddress
-from pytcp.protocols.ether.ps import (
- ETHER_HEADER_LEN,
- ETHER_TYPE_ARP,
- ETHER_TYPE_IP4,
- ETHER_TYPE_IP6,
- ETHER_TYPE_RAW,
- ETHER_TYPE_TABLE,
-)
+from pytcp.protocols.ether.ps import ETHER_HEADER_LEN, Ethernet
from pytcp.protocols.raw.fpa import RawAssembler
if TYPE_CHECKING:
@@ -56,7 +48,7 @@
from pytcp.protocols.ip6.fpa import Ip6Assembler
-class EtherAssembler:
+class EtherAssembler(Ethernet):
"""
Ethernet packet assembler support class.
"""
@@ -76,13 +68,6 @@ def __init__(
Class constructor.
"""
- assert carried_packet.ether_type in {
- ETHER_TYPE_ARP,
- ETHER_TYPE_IP4,
- ETHER_TYPE_IP6,
- ETHER_TYPE_RAW,
- }, f"{carried_packet.ether_type=}"
-
self._carried_packet: (
ArpAssembler
| Ip4Assembler
@@ -90,25 +75,17 @@ def __init__(
| Ip6Assembler
| RawAssembler
) = carried_packet
- self._tracker: Tracker = self._carried_packet.tracker
- self._dst: MacAddress = dst
- self._src: MacAddress = src
- self._type: int = self._carried_packet.ether_type
+ self._tracker = self._carried_packet.tracker
+ self._dst = dst
+ self._src = src
+ self._type = self._carried_packet.ether_type
def __len__(self) -> int:
"""
Length of the packet.
"""
- return ETHER_HEADER_LEN + len(self._carried_packet)
- def __str__(self) -> str:
- """
- Packet log string.
- """
- return (
- f"ETHER {self._src} > {self._dst}, 0x{self._type:0>4x} "
- f"({ETHER_TYPE_TABLE.get(self._type, '???')}), plen {len(self)}"
- )
+ return ETHER_HEADER_LEN + len(self._carried_packet)
@property
def tracker(self) -> Tracker:
@@ -117,44 +94,18 @@ def tracker(self) -> Tracker:
"""
return self._tracker
- @property
- def dst(self) -> MacAddress:
- """
- Getter for the '_dst' attribute.
- """
- return self._dst
-
- @dst.setter
- def dst(self, mac_address: MacAddress) -> None:
- """
- Setter for the '_dst' attribute.
- """
- self._dst = mac_address
-
- @property
- def src(self) -> MacAddress:
- """
- Getter for the '_src' attribute.
- """
- return self._src
-
- @src.setter
- def src(self, mac_address: MacAddress) -> None:
- """
- Setter for the '_src' attribute.
- """
- self._src = mac_address
-
def assemble(self, frame: memoryview) -> None:
"""
Assemble packet into the raw form.
"""
+
struct.pack_into(
"! 6s 6s H",
frame,
0,
bytes(self._dst),
bytes(self._src),
- self._type,
+ int(self._type),
)
+
self._carried_packet.assemble(frame[ETHER_HEADER_LEN:])
diff --git a/pytcp/protocols/ether/fpp.py b/pytcp/protocols/ether/fpp.py
index 0d68a32c..1950a044 100755
--- a/pytcp/protocols/ether/fpp.py
+++ b/pytcp/protocols/ether/fpp.py
@@ -34,25 +34,37 @@
ver 2.7
"""
-
from __future__ import annotations
-import struct
from typing import TYPE_CHECKING
-from pytcp import config
+from pytcp.lib.errors import PacketIntegrityError, PacketSanityError
from pytcp.lib.mac_address import MacAddress
-from pytcp.protocols.ether.ps import (
- ETHER_HEADER_LEN,
- ETHER_TYPE_MIN,
- ETHER_TYPE_TABLE,
-)
+from pytcp.protocols.ether.ps import ETHER_HEADER_LEN, Ethernet, EtherType
if TYPE_CHECKING:
from pytcp.lib.packet import PacketRx
-class EtherParser:
+class EtherIntegrityError(PacketIntegrityError):
+ """
+ Exception raised when Ethernet packet integrity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[ETHER] " + message)
+
+
+class EtherSanityError(PacketSanityError):
+ """
+ Exception raised when Ethernet packet sanity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[ETHER] " + message)
+
+
+class EtherParser(Ethernet):
"""
Ethernet packet parser class.
"""
@@ -66,109 +78,74 @@ def __init__(self, packet_rx: PacketRx) -> None:
self._frame = packet_rx.frame
- packet_rx.parse_failed = (
- self._packet_integrity_check() or self._packet_sanity_check()
- )
+ self._packet_integrity_check()
+ self._packet_sanity_check()
- if not packet_rx.parse_failed:
- packet_rx.frame = packet_rx.frame[ETHER_HEADER_LEN:]
+ packet_rx.frame = packet_rx.frame[ETHER_HEADER_LEN:]
+
+ self._dst = MacAddress(self._frame[0:6])
+ self._src = MacAddress(self._frame[6:12])
+ self._type = EtherType.from_frame(self._frame)
def __len__(self) -> int:
"""
- Number of bytes remaining in the frame.
+ Get number of bytes remaining in the frame.
"""
return len(self._frame)
- def __str__(self) -> str:
- """
- Packet log string.
- """
- return (
- f"ETHER {self.src} > {self.dst}, 0x{self.type:0>4x} "
- f"({ETHER_TYPE_TABLE.get(self.type, '???')})"
- )
-
- @property
- def dst(self) -> MacAddress:
- """
- Read the 'Destination MAC address' field.
- """
- if "_cache__dst" not in self.__dict__:
- self._cache__dst = MacAddress(self._frame[0:6])
- return self._cache__dst
-
- @property
- def src(self) -> MacAddress:
- """
- Read the 'Source MAC address' field.
- """
- if "_cache__src" not in self.__dict__:
- self._cache__src = MacAddress(self._frame[6:12])
- return self._cache__src
-
- @property
- def type(self) -> int:
- """
- Read the 'EtherType' field.
- """
- if "_cache__type" not in self.__dict__:
- self._cache__type: int = struct.unpack("!H", self._frame[12:14])[0]
- return self._cache__type
-
@property
def header_copy(self) -> bytes:
"""
- Return copy of packet header.
+ Get copy of packet header.
"""
- if "_cache__header_copy" not in self.__dict__:
- self._cache__header_copy = bytes(self._frame[:ETHER_HEADER_LEN])
- return self._cache__header_copy
+
+ return bytes(self._frame[:ETHER_HEADER_LEN])
@property
def data_copy(self) -> bytes:
"""
- Return copy of packet data.
+ Get copy of packet data.
"""
- if "_cache__data_copy" not in self.__dict__:
- self._cache__data_copy = bytes(self._frame[ETHER_HEADER_LEN:])
- return self._cache__data_copy
+
+ return bytes(self._frame[ETHER_HEADER_LEN:])
@property
def packet_copy(self) -> bytes:
"""
- Return copy of whole packet.
+ Get copy of whole packet.
"""
- if "_cache__packet_copy" not in self.__dict__:
- self._cache__packet_copy = bytes(self._frame[:])
- return self._cache__packet_copy
+
+ return bytes(self._frame[:])
@property
def plen(self) -> int:
"""
- Calculate packet length.
+ Get packet length.
"""
- if "_cache__plen" not in self.__dict__:
- self._cache__plen = len(self)
- return self._cache__plen
- def _packet_integrity_check(self) -> str:
+ return len(self)
+
+ def _packet_integrity_check(self) -> None:
"""
Packet integrity check to be run on raw packet prior to parsing
to make sure parsing is safe.
"""
- if not config.PACKET_INTEGRITY_CHECK:
- return ""
+
if len(self) < ETHER_HEADER_LEN:
- return "ETHER integrity - wrong packet length (I)"
- return ""
+ raise EtherIntegrityError(
+ "The minimum packet length must be "
+ f"'{ETHER_HEADER_LEN}' bytes, got {len(self)} bytes."
+ )
- def _packet_sanity_check(self) -> str:
+ def _packet_sanity_check(self) -> None:
"""
Packet sanity check to be run on parsed packet to make sure packet's
fields contain sane values.
"""
- if not config.PACKET_SANITY_CHECK:
- return ""
- if self.type < ETHER_TYPE_MIN:
- return "ETHER sanity - 'ether_type' must be greater than 0x0600"
- return ""
+
+ if not EtherType.sanity_check(self._frame):
+ raise EtherSanityError(
+ "The 'type' field value must be one of "
+ f"{[type.value for type in EtherType]}, "
+ f"got '{self.type}'."
+ )
diff --git a/pytcp/protocols/ether/phrx.py b/pytcp/protocols/ether/phrx.py
index 6500dc1e..269a7e3f 100755
--- a/pytcp/protocols/ether/phrx.py
+++ b/pytcp/protocols/ether/phrx.py
@@ -34,19 +34,15 @@
ver 2.7
"""
-
from __future__ import annotations
from typing import TYPE_CHECKING
from pytcp import config
+from pytcp.lib.errors import PacketValidationError
from pytcp.lib.logger import log
from pytcp.protocols.ether.fpp import EtherParser
-from pytcp.protocols.ether.ps import (
- ETHER_TYPE_ARP,
- ETHER_TYPE_IP4,
- ETHER_TYPE_IP6,
-)
+from pytcp.protocols.ether.ps import EtherType
if TYPE_CHECKING:
from pytcp.lib.packet import PacketRx
@@ -60,13 +56,13 @@ def _phrx_ether(self: PacketHandler, packet_rx: PacketRx) -> None:
self.packet_stats_rx.ether__pre_parse += 1
- EtherParser(packet_rx)
-
- if packet_rx.parse_failed:
+ try:
+ EtherParser(packet_rx)
+ except PacketValidationError as error:
self.packet_stats_rx.ether__failed_parse__drop += 1
__debug__ and log(
"ether",
- f"{packet_rx.tracker} - {packet_rx.parse_failed}>",
+ f"{packet_rx.tracker} - {error}>",
)
return
@@ -95,14 +91,14 @@ def _phrx_ether(self: PacketHandler, packet_rx: PacketRx) -> None:
if packet_rx.ether.dst == self.mac_broadcast:
self.packet_stats_rx.ether__dst_broadcast += 1
- if packet_rx.ether.type == ETHER_TYPE_ARP and config.IP4_SUPPORT:
+ if packet_rx.ether.type == EtherType.ARP and config.IP4_SUPPORT:
self._phrx_arp(packet_rx)
return
- if packet_rx.ether.type == ETHER_TYPE_IP4 and config.IP4_SUPPORT:
+ if packet_rx.ether.type == EtherType.IP4 and config.IP4_SUPPORT:
self._phrx_ip4(packet_rx)
return
- if packet_rx.ether.type == ETHER_TYPE_IP6 and config.IP6_SUPPORT:
+ if packet_rx.ether.type == EtherType.IP6 and config.IP6_SUPPORT:
self._phrx_ip6(packet_rx)
return
diff --git a/pytcp/protocols/ether/ps.py b/pytcp/protocols/ether/ps.py
index 307f0d4b..27dc283d 100755
--- a/pytcp/protocols/ether/ps.py
+++ b/pytcp/protocols/ether/ps.py
@@ -32,8 +32,11 @@
ver 2.7
"""
+import struct
+from abc import ABC, abstractmethod
-from __future__ import annotations
+from pytcp.lib.mac_address import MacAddress
+from pytcp.lib.protocol_enum import ProtocolEnum
# Ethernet packet header
@@ -50,15 +53,96 @@
ETHER_HEADER_LEN = 14
-ETHER_TYPE_MIN = 0x0600
-ETHER_TYPE_ARP = 0x0806
-ETHER_TYPE_IP4 = 0x0800
-ETHER_TYPE_IP6 = 0x86DD
-ETHER_TYPE_RAW = 0xFFFF
-
-ETHER_TYPE_TABLE = {
- ETHER_TYPE_ARP: "ARP",
- ETHER_TYPE_IP4: "IPv4",
- ETHER_TYPE_IP6: "IPv6",
- ETHER_TYPE_RAW: "raw_data",
-}
+
+class EtherType(ProtocolEnum):
+ ARP = 0x0806
+ IP4 = 0x0800
+ IP6 = 0x86DD
+ RAW = 0xFFFF
+
+ @staticmethod
+ def _extract(frame: bytes) -> int:
+ return int(struct.unpack("!H", frame[12:14])[0])
+
+ def __str__(self) -> str:
+ """
+ Get string representation of this enum.
+ """
+
+ match self:
+ case EtherType.ARP:
+ return "ARP"
+ case EtherType.IP4:
+ return "IPv4"
+ case EtherType.IP6:
+ return "IPv6"
+ case EtherType.RAW:
+ return "raw_data"
+
+
+class Ethernet(ABC):
+ """
+ Base class for ARP packet parser and assembler classes.
+ """
+
+ _dst: MacAddress
+ _src: MacAddress
+ _type: EtherType
+
+ @abstractmethod
+ def __len__(self) -> int:
+ """
+ Length of the packet.
+ """
+
+ raise NotImplementedError
+
+ def __str__(self) -> str:
+ """
+ Packet log string.
+ """
+
+ return (
+ f"ETHER {self._src} > {self._dst}, 0x{int(self._type):0>4x} "
+ f"({self._type}), plen {len(self)}"
+ )
+
+ @property
+ def dst(self) -> MacAddress:
+ """
+ Getter for '_dst' property.
+ """
+
+ return self._dst
+
+ @dst.setter
+ def dst(self, mac_address: MacAddress) -> None:
+ """
+ Setter for the '_dst' attribute.
+ """
+
+ self._dst = mac_address
+
+ @property
+ def src(self) -> MacAddress:
+ """
+ Getter for '_src' property.
+ """
+
+ return self._src
+
+ @src.setter
+ def src(self, mac_address: MacAddress) -> None:
+ """
+ Setter for the '_src' attribute.
+ """
+
+ self._src = mac_address
+
+ @property
+ def type(self) -> EtherType:
+ """
+ Getter for '_type' property.
+ """
+
+ return self._type
diff --git a/pytcp/protocols/icmp4/fpa.py b/pytcp/protocols/icmp4/fpa.py
index ade2cee7..b8dc3707 100755
--- a/pytcp/protocols/icmp4/fpa.py
+++ b/pytcp/protocols/icmp4/fpa.py
@@ -38,21 +38,24 @@
import struct
+from pytcp.lib.errors import UnsupportedCaseError
from pytcp.lib.ip_helper import inet_cksum
from pytcp.lib.tracker import Tracker
from pytcp.protocols.icmp4.ps import (
- ICMP4_ECHO_REPLY,
ICMP4_ECHO_REPLY_LEN,
- ICMP4_ECHO_REQUEST,
ICMP4_ECHO_REQUEST_LEN,
- ICMP4_UNREACHABLE,
- ICMP4_UNREACHABLE__PORT,
ICMP4_UNREACHABLE_LEN,
+ Icmp4,
+ Icmp4Code,
+ Icmp4EchoReplyCode,
+ Icmp4EchoRequestCode,
+ Icmp4Type,
+ Icmp4UnreachableCode,
)
from pytcp.protocols.ip4.ps import IP4_PROTO_ICMP4
-class Icmp4Assembler:
+class Icmp4Assembler(Icmp4):
"""
ICMPv4 packet assembler support class.
"""
@@ -62,63 +65,57 @@ class Icmp4Assembler:
def __init__(
self,
*,
- type: int = 0,
- code: int = 0,
+ type: Icmp4Type,
+ code: Icmp4Code,
ec_id: int | None = None,
ec_seq: int | None = None,
ec_data: bytes | None = None,
un_data: bytes | None = None,
echo_tracker: Tracker | None = None,
) -> None:
- """Class constructor"""
+ """
+ Class constructor.
+ """
self._tracker: Tracker = Tracker(prefix="TX", echo_tracker=echo_tracker)
assert ec_id is None or 0 <= ec_id <= 0xFFFF
assert ec_seq is None or 0 <= ec_seq <= 0xFFFF
- self._type: int = type
- self._code: int = code
-
- if self._type == ICMP4_ECHO_REPLY and self._code == 0:
- self._ec_id = 0 if ec_id is None else ec_id
- self._ec_seq = 0 if ec_seq is None else ec_seq
- self._ec_data = b"" if ec_data is None else ec_data
- return
+ self._type = type
+ self._code = code
- if (
- self._type == ICMP4_UNREACHABLE
- and self._code == ICMP4_UNREACHABLE__PORT
- ):
- self._un_data = b"" if un_data is None else un_data[:520]
- return
+ match (self._type, self._code):
+ case (Icmp4Type.ECHO_REPLY, Icmp4EchoReplyCode.DEFAULT):
+ self._ec_id = 0 if ec_id is None else ec_id
+ self._ec_seq = 0 if ec_seq is None else ec_seq
+ self._ec_data = b"" if ec_data is None else ec_data
- if self._type == ICMP4_ECHO_REQUEST and self._code == 0:
- self._ec_id = 0 if ec_id is None else ec_id
- self._ec_seq = 0 if ec_seq is None else ec_seq
- self._ec_data = b"" if ec_data is None else ec_data
- return
+ case (Icmp4Type.UNREACHABLE, Icmp4UnreachableCode.PORT):
+ self._un_data = b"" if un_data is None else un_data[:520]
- assert False, "Unknown ICMPv4 Type/Code"
+ case (Icmp4Type.ECHO_REQUEST, Icmp4EchoRequestCode.DEFAULT):
+ self._ec_id = 0 if ec_id is None else ec_id
+ self._ec_seq = 0 if ec_seq is None else ec_seq
+ self._ec_data = b"" if ec_data is None else ec_data
def __len__(self) -> int:
"""
Length of the packet.
"""
- if self._type == ICMP4_ECHO_REPLY:
- return ICMP4_ECHO_REPLY_LEN + len(self._ec_data)
+ match (self._type, self._code):
+ case (Icmp4Type.ECHO_REPLY, Icmp4EchoReplyCode.DEFAULT):
+ return ICMP4_ECHO_REPLY_LEN + len(self._ec_data)
- if (
- self._type == ICMP4_UNREACHABLE
- and self._code == ICMP4_UNREACHABLE__PORT
- ):
- return ICMP4_UNREACHABLE_LEN + len(self._un_data)
+ case (Icmp4Type.UNREACHABLE, Icmp4UnreachableCode.PORT):
+ return ICMP4_UNREACHABLE_LEN + len(self._un_data)
- if self._type == ICMP4_ECHO_REQUEST:
- return ICMP4_ECHO_REQUEST_LEN + len(self._ec_data)
+ case (Icmp4Type.ECHO_REQUEST, Icmp4EchoRequestCode.DEFAULT):
+ return ICMP4_ECHO_REQUEST_LEN + len(self._ec_data)
- assert False, "Unknown ICMPv4 Type/Code"
+ case _:
+ raise UnsupportedCaseError
def __str__(self) -> str:
"""
@@ -127,25 +124,24 @@ def __str__(self) -> str:
header = f"ICMPv4 {self._type}/{self._code}"
- if self._type == ICMP4_ECHO_REPLY and self._code == 0:
- return (
- f"{header} (echo_reply), id {self._ec_id}, "
- f"seq {self._ec_seq}, dlen {len(self._ec_data)}"
- )
+ match (self._type, self._code):
+ case (Icmp4Type.ECHO_REPLY, Icmp4EchoReplyCode.DEFAULT):
+ return (
+ f"{header} (echo_reply), id {self._ec_id}, "
+ f"seq {self._ec_seq}, dlen {len(self._ec_data)}"
+ )
- if (
- self._type == ICMP4_UNREACHABLE
- and self._code == ICMP4_UNREACHABLE__PORT
- ):
- return f"{header} (unreachable_port), dlen {len(self._un_data)}"
+ case (Icmp4Type.UNREACHABLE, Icmp4UnreachableCode.PORT):
+ return f"{header} (unreachable_port), dlen {len(self._un_data)}"
- if self._type == ICMP4_ECHO_REQUEST and self._code == 0:
- return (
- f"{header} (echo_request), id {self._ec_id}, "
- f"seq {self._ec_seq}, dlen {len(self._ec_data)}"
- )
+ case (Icmp4Type.ECHO_REQUEST, Icmp4EchoRequestCode.DEFAULT):
+ return (
+ f"{header} (echo_request), id {self._ec_id}, "
+ f"seq {self._ec_seq}, dlen {len(self._ec_data)}"
+ )
- assert False, "Unknown ICMPv4 Type/Code"
+ case _:
+ raise UnsupportedCaseError
@property
def tracker(self) -> Tracker:
@@ -159,51 +155,47 @@ def assemble(self, frame: memoryview, _: int = 0) -> None:
Assemble packet into the raw form.
"""
- if self._type == ICMP4_ECHO_REPLY and self._code == 0:
- struct.pack_into(
- f"! BBH HH {len(self._ec_data)}s",
- frame,
- 0,
- self._type,
- self._code,
- 0,
- self._ec_id,
- self._ec_seq,
- bytes(self._ec_data),
- )
- struct.pack_into("! H", frame, 2, inet_cksum(frame))
- return
-
- if (
- self._type == ICMP4_UNREACHABLE
- and self._code == ICMP4_UNREACHABLE__PORT
- ):
- struct.pack_into(
- f"! BBH L {len(self._un_data)}s",
- frame,
- 0,
- self._type,
- self._code,
- 0,
- 0,
- bytes(self._un_data),
- )
- struct.pack_into("! H", frame, 2, inet_cksum(frame))
- return
-
- if self._type == ICMP4_ECHO_REQUEST and self._code == 0:
- struct.pack_into(
- f"! BBH HH {len(self._ec_data)}s",
- frame,
- 0,
- self._type,
- self._code,
- 0,
- self._ec_id,
- self._ec_seq,
- bytes(self._ec_data),
- )
- struct.pack_into("! H", frame, 2, inet_cksum(frame))
- return
-
- assert False, "Unknown ICMPv4 Type/Code"
+ match (self._type, self._code):
+ case (Icmp4Type.ECHO_REPLY, Icmp4EchoReplyCode.DEFAULT):
+ struct.pack_into(
+ f"! BBH HH {len(self._ec_data)}s",
+ frame,
+ 0,
+ int(self._type),
+ int(self._code),
+ 0,
+ self._ec_id,
+ self._ec_seq,
+ bytes(self._ec_data),
+ )
+ struct.pack_into("! H", frame, 2, inet_cksum(frame))
+
+ case (Icmp4Type.UNREACHABLE, Icmp4UnreachableCode.PORT):
+ struct.pack_into(
+ f"! BBH L {len(self._un_data)}s",
+ frame,
+ 0,
+ int(self._type),
+ int(self._code),
+ 0,
+ 0,
+ bytes(self._un_data),
+ )
+ struct.pack_into("! H", frame, 2, inet_cksum(frame))
+
+ case (Icmp4Type.ECHO_REQUEST, Icmp4EchoRequestCode.DEFAULT):
+ struct.pack_into(
+ f"! BBH HH {len(self._ec_data)}s",
+ frame,
+ 0,
+ int(self._type),
+ int(self._code),
+ 0,
+ self._ec_id,
+ self._ec_seq,
+ bytes(self._ec_data),
+ )
+ struct.pack_into("! H", frame, 2, inet_cksum(frame))
+
+ case _:
+ raise UnsupportedCaseError
diff --git a/pytcp/protocols/icmp4/fpp.py b/pytcp/protocols/icmp4/fpp.py
index 6a904bd1..bc0a15a7 100755
--- a/pytcp/protocols/icmp4/fpp.py
+++ b/pytcp/protocols/icmp4/fpp.py
@@ -41,21 +41,43 @@
import struct
from typing import TYPE_CHECKING
-from pytcp import config
+from pytcp.lib.errors import PacketIntegrityError, PacketSanityError
from pytcp.lib.ip_helper import inet_cksum
from pytcp.protocols.icmp4.ps import (
- ICMP4_ECHO_REPLY,
- ICMP4_ECHO_REQUEST,
ICMP4_HEADER_LEN,
- ICMP4_UNREACHABLE,
- ICMP4_UNREACHABLE__PORT,
+ Icmp4,
+ Icmp4EchoReplyCode,
+ Icmp4EchoReplyMessage,
+ Icmp4EchoRequestCode,
+ Icmp4EchoRequestMessage,
+ Icmp4Type,
+ Icmp4UnreachableCode,
+ Icmp4UnreachablePortMessage,
)
if TYPE_CHECKING:
from pytcp.lib.packet import PacketRx
-class Icmp4Parser:
+class Icmp4IntegrityError(PacketIntegrityError):
+ """
+ Exception raised when ICMPv4 packet integrity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[ICMPv4] " + message)
+
+
+class Icmp4SanityError(PacketSanityError):
+ """
+ Exception raised when ICMPv4 packet sanity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[ICMPv4] " + message)
+
+
+class Icmp4Parser(Icmp4):
"""
ICMPv4 packet parser class.
"""
@@ -65,179 +87,144 @@ def __init__(self, packet_rx: PacketRx) -> None:
Class constructor.
"""
- assert packet_rx.ip4 is not None
-
packet_rx.icmp4 = self
self._frame = packet_rx.frame
self._plen = packet_rx.ip4.dlen
- packet_rx.parse_failed = (
- self._packet_integrity_check() or self._packet_sanity_check()
- )
+ self._packet_integrity_check()
+ self._packet_sanity_check()
+
+ self._type = Icmp4Type.from_frame(self._frame)
+ match self._type:
+ case Icmp4Type.ECHO_REPLY:
+ self._code = Icmp4EchoReplyCode.from_frame(self._frame)
+ self._message = Icmp4EchoReplyMessageParser(self._frame)
+ case Icmp4Type.UNREACHABLE:
+ self._code = Icmp4UnreachableCode.from_frame(self._frame)
+ match self._code:
+ case Icmp4UnreachableCode.PORT:
+ self._message = Icmp4UnreachablePortMessageParser(
+ self._frame
+ )
+ case Icmp4Type.ECHO_REQUEST:
+ self._code = Icmp4EchoRequestCode.from_frame(self._frame)
+ self._message = Icmp4EchoRequestMessageParser(self._frame)
+ self._cksum: int = struct.unpack("!H", self._frame[2:4])[0]
def __len__(self) -> int:
"""
Number of bytes remaining in the frame.
"""
- return len(self._frame)
-
- def __str__(self) -> str:
- """
- Packet log string.
- """
- header = f"ICMPv4 {self.type}/{self.code}"
-
- if self.type == ICMP4_ECHO_REPLY:
- return f"{header} (echo_reply), id {self.ec_id}, seq {self.ec_seq}, dlen {len(self.ec_data)}"
-
- if (
- self.type == ICMP4_UNREACHABLE
- and self.code == ICMP4_UNREACHABLE__PORT
- ):
- return f"{header} (unreachable_port), dlen {len(self.un_data)}"
-
- if self.type == ICMP4_ECHO_REQUEST:
- return f"{header} (echo_request), id {self.ec_id}, seq {self.ec_seq}, dlen {len(self.ec_data)}"
- return f"{header} (unknown)"
+ return len(self._frame)
@property
- def type(self) -> int:
+ def plen(self) -> int:
"""
- Read the 'Type' field.
+ Calculate packet length.
"""
- return self._frame[0]
+ return self._plen
@property
- def code(self) -> int:
+ def packet_copy(self) -> bytes:
"""
- Read the 'Code' field.
+ Read the whole packet.
"""
- return self._frame[1]
- @property
- def cksum(self) -> int:
- """
- Read the 'Checksum' field.
- """
- if "_cache__cksum" not in self.__dict__:
- self._cache__cksum: int = struct.unpack("!H", self._frame[2:4])[0]
- return self._cache__cksum
+ return bytes(self._frame[: self.plen])
- @property
- def ec_id(self) -> int:
+ def _packet_integrity_check(self) -> None:
"""
- Read the Echo 'Id' field.
+ Packet integrity check to be run on raw packet prior to parsing
+ to make sure parsing is safe.
"""
- if "_cache__ec_id" not in self.__dict__:
- assert self.type in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY}
- self._cache__ec_id: int = struct.unpack("!H", self._frame[4:6])[0]
- return self._cache__ec_id
- @property
- def ec_seq(self) -> int:
- """
- Read the Echo 'Seq' field.
- """
- if "_cache__ec_seq" not in self.__dict__:
- assert self.type in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY}
- self._cache__ec_seq: int = struct.unpack("!H", self._frame[6:8])[0]
- return self._cache__ec_seq
+ if inet_cksum(self._frame[: self._plen]):
+ raise Icmp4IntegrityError(
+ "Wrong packet checksum.",
+ )
- @property
- def ec_data(self) -> bytes:
- """
- Read data carried by the Echo message.
- """
- if "_cache__ec_data" not in self.__dict__:
- assert self.type in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY}
- self._cache__ec_data = self._frame[8 : self.plen]
- return self._cache__ec_data
+ if not ICMP4_HEADER_LEN <= self._plen <= len(self):
+ raise Icmp4IntegrityError(
+ "Wrong packet length (I).",
+ )
- @property
- def un_data(self) -> bytes:
+ def _packet_sanity_check(self) -> None:
"""
- Read the data carried by Uneachable message.
+ Packet sanity check to be run on parsed packet to make sure packets's
+ fields contain sane values.
"""
- if "_cache__un_data" not in self.__dict__:
- assert self.type == ICMP4_UNREACHABLE
- self._cache__un_data = self._frame[8 : self.plen]
- return self._cache__un_data
- @property
- def plen(self) -> int:
- """
- Calculate packet length.
- """
- return self._plen
+ if not Icmp4Type.sanity_check(self._frame):
+ raise Icmp4SanityError(
+ "The 'type' field value must be one of "
+ f"{[type.value for type in Icmp4Type]}, "
+ f"got '{self.type}'."
+ )
+
+ match Icmp4Type.from_frame(self._frame):
+ case Icmp4Type.ECHO_REPLY:
+ if not Icmp4EchoReplyCode.sanity_check(self._frame):
+ raise Icmp4SanityError(
+ "The 'code' field value for Echo Reply message must be one of "
+ f"{[code.value for code in Icmp4EchoReplyCode]}, "
+ f"got '{self.code}'."
+ )
+ case Icmp4Type.UNREACHABLE:
+ if not Icmp4UnreachableCode.sanity_check(self._frame):
+ raise Icmp4SanityError(
+ "The 'code' field value Unreachable message must be one of "
+ f"{[code.value for code in Icmp4UnreachableCode]}, "
+ f"got '{self.code}'."
+ )
+ case Icmp4Type.ECHO_REQUEST:
+ if not Icmp4EchoRequestCode.sanity_check(self._frame):
+ raise Icmp4SanityError(
+ "The 'code' field value for Echo Request message must be one of "
+ f"{[code.value for code in Icmp4EchoRequestCode]}, "
+ f"got '{self.code}'."
+ )
+
+
+class Icmp4EchoReplyMessageParser(Icmp4EchoReplyMessage):
+ """
+ Message class for ICMPv4 Echo Reply packet.
+ """
- @property
- def packet_copy(self) -> bytes:
+ def __init__(self, frame: bytes) -> None:
"""
- Read the whole packet.
+ Class constructor.
"""
- if "_cache__packet_copy" not in self.__dict__:
- self._cache__packet_copy = bytes(self._frame[: self.plen])
- return self._cache__packet_copy
- def _packet_integrity_check(self) -> str:
- """
- Packet integrity check to be run on raw packet prior to parsing
- to make sure parsing is safe.
- """
+ self._id: int = struct.unpack("!H", frame[4:6])[0]
+ self._seq: int = struct.unpack("!H", frame[6:8])[0]
+ self._data: bytes = frame[8:]
- if not config.PACKET_INTEGRITY_CHECK:
- return ""
- if inet_cksum(self._frame[: self._plen]):
- return "ICMPv4 integrity - wrong packet checksum"
+class Icmp4UnreachablePortMessageParser(Icmp4UnreachablePortMessage):
+ """
+ Message class for ICMPv4 Unreachable Port packet.
+ """
- if not ICMP4_HEADER_LEN <= self._plen <= len(self):
- return "ICMPv4 integrity - wrong packet length (I)"
+ def __init__(self, frame: bytes) -> None:
+ """
+ Class constructor.
+ """
- if self._frame[0] in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY}:
- if not 8 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
+ self._data: bytes = frame[8:]
- elif self._frame[0] == ICMP4_UNREACHABLE:
- if not 12 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
- return ""
+class Icmp4EchoRequestMessageParser(Icmp4EchoRequestMessage):
+ """
+ Message class for ICMPv4 Echo Request packet class.
+ """
- def _packet_sanity_check(self) -> str:
+ def __init__(self, frame: bytes) -> None:
"""
- Packet sanity check to be run on parsed packet to make sure packets's
- fields contain sane values.
+ Class constructor.
"""
- if not config.PACKET_SANITY_CHECK:
- return ""
-
- if self.type in {ICMP4_ECHO_REQUEST, ICMP4_ECHO_REPLY}:
- if not self.code == 0:
- return "ICMPv4 sanity - 'code' should be set to 0 (RFC 792)"
-
- if self.type == ICMP4_UNREACHABLE:
- if self.code not in {
- 0,
- 1,
- 2,
- 3,
- 4,
- 5,
- 6,
- 7,
- 8,
- 9,
- 10,
- 11,
- 12,
- 13,
- 14,
- 15,
- }:
- return "ICMPv4 sanity - 'code' must be set to [0-15] (RFC 792)"
-
- return ""
+ self._id: int = struct.unpack("!H", frame[4:6])[0]
+ self._seq: int = struct.unpack("!H", frame[6:8])[0]
+ self._data: bytes = frame[8:]
diff --git a/pytcp/protocols/icmp4/phrx.py b/pytcp/protocols/icmp4/phrx.py
index c05e6822..8f98d819 100755
--- a/pytcp/protocols/icmp4/phrx.py
+++ b/pytcp/protocols/icmp4/phrx.py
@@ -41,13 +41,19 @@
from typing import TYPE_CHECKING
from pytcp.lib import stack
+from pytcp.lib.errors import PacketValidationError
from pytcp.lib.ip4_address import Ip4Address
from pytcp.lib.logger import log
-from pytcp.protocols.icmp4.fpp import Icmp4Parser
+from pytcp.protocols.icmp4.fpp import (
+ Icmp4EchoReplyMessage,
+ Icmp4EchoRequestMessage,
+ Icmp4Parser,
+ Icmp4UnreachablePortMessage,
+)
from pytcp.protocols.icmp4.ps import (
- ICMP4_ECHO_REPLY,
- ICMP4_ECHO_REQUEST,
- ICMP4_UNREACHABLE,
+ Icmp4EchoReplyCode,
+ Icmp4Type,
+ Icmp4UnreachableCode,
)
from pytcp.protocols.ip4.ps import IP4_HEADER_LEN, IP4_PROTO_UDP
from pytcp.protocols.udp.metadata import UdpMetadata
@@ -59,96 +65,147 @@
def _phrx_icmp4(self: PacketHandler, packet_rx: PacketRx) -> None:
- """Handle inbound ICMPv4 packets"""
+ """
+ Handle inbound ICMPv4 packets.
+ """
self.packet_stats_rx.icmp4__pre_parse += 1
- Icmp4Parser(packet_rx)
-
- if packet_rx.parse_failed:
+ try:
+ Icmp4Parser(packet_rx)
+ except PacketValidationError as error:
__debug__ and log(
"icmp4",
- f"{packet_rx.tracker} - {packet_rx.parse_failed}>",
+ f"{packet_rx.tracker} - {error}>",
)
self.packet_stats_rx.icmp4__failed_parse__drop += 1
return
__debug__ and log("icmp4", f"{packet_rx.tracker} - {packet_rx.icmp4}")
- # ICMPv4 Echo Request packet
- if packet_rx.icmp4.type == ICMP4_ECHO_REQUEST:
- __debug__ and log(
- "icmp4",
- f"{packet_rx.tracker} - Received ICMPv4 Echo Request "
- f"packet from {packet_rx.ip4.src}, sending reply>",
- )
- self.packet_stats_rx.icmp4__echo_request__respond_echo_reply += 1
-
- self._phtx_icmp4(
- ip4_src=packet_rx.ip4.dst,
- ip4_dst=packet_rx.ip4.src,
- icmp4_type=ICMP4_ECHO_REPLY,
- icmp4_ec_id=packet_rx.icmp4.ec_id,
- icmp4_ec_seq=packet_rx.icmp4.ec_seq,
- icmp4_ec_data=packet_rx.icmp4.ec_data,
- echo_tracker=packet_rx.tracker,
+ match packet_rx.icmp4.type:
+ case Icmp4Type.ECHO_REPLY:
+ _phrx_icmp4__echo_reply(self, packet_rx)
+ case Icmp4Type.ECHO_REQUEST:
+ _phrx_icmp4__echo_request(self, packet_rx)
+ case Icmp4Type.UNREACHABLE:
+ match packet_rx.icmp4.code:
+ case Icmp4UnreachableCode.PORT:
+ _phrx_icmp4__unreachable__port(self, packet_rx)
+
+
+def _phrx_icmp4__echo_reply(self: PacketHandler, packet_rx: PacketRx) -> None:
+ """
+ Handle inbound ICMPv4 Echo Reply packets.
+ """
+
+ assert isinstance(packet_rx.icmp4.message, Icmp4EchoReplyMessage)
+
+ __debug__ and log(
+ "icmp4",
+ f"{packet_rx.tracker} - Received ICMPv4 Echo Request "
+ f"packet from {packet_rx.ip4.src}, sending reply>",
+ )
+ self.packet_stats_rx.icmp4__echo_request__respond_echo_reply += 1
+
+ self._phtx_icmp4(
+ ip4_src=packet_rx.ip4.dst,
+ ip4_dst=packet_rx.ip4.src,
+ icmp4_type=Icmp4Type.ECHO_REPLY,
+ icmp4_code=Icmp4EchoReplyCode.DEFAULT,
+ icmp4_ec_id=packet_rx.icmp4.message.id,
+ icmp4_ec_seq=packet_rx.icmp4.message.seq,
+ icmp4_ec_data=packet_rx.icmp4.message.data,
+ echo_tracker=packet_rx.tracker,
+ )
+
+
+def _phrx_icmp4__echo_request(self: PacketHandler, packet_rx: PacketRx) -> None:
+ """
+ Handle inbound ICMPv4 Echo Reply packets.
+ """
+
+ assert isinstance(packet_rx.icmp4.message, Icmp4EchoRequestMessage)
+
+ __debug__ and log(
+ "icmp4",
+ f"{packet_rx.tracker} - Received ICMPv4 Echo Request "
+ f"packet from {packet_rx.ip4.src}, sending reply>",
+ )
+ self.packet_stats_rx.icmp4__echo_request__respond_echo_reply += 1
+
+ self._phtx_icmp4(
+ ip4_src=packet_rx.ip4.dst,
+ ip4_dst=packet_rx.ip4.src,
+ icmp4_type=Icmp4Type.ECHO_REPLY,
+ icmp4_code=Icmp4EchoReplyCode.DEFAULT,
+ icmp4_ec_id=packet_rx.icmp4.message.id,
+ icmp4_ec_seq=packet_rx.icmp4.message.seq,
+ icmp4_ec_data=packet_rx.icmp4.message.data,
+ echo_tracker=packet_rx.tracker,
+ )
+
+
+def _phrx_icmp4__unreachable__port(
+ self: PacketHandler, packet_rx: PacketRx
+) -> None:
+ """
+ Handle inbound ICMPv4 Unreachable Port packets.
+ """
+
+ assert isinstance(packet_rx.icmp4.message, Icmp4UnreachablePortMessage)
+
+ __debug__ and log(
+ "icmp4",
+ f"{packet_rx.tracker} - Received ICMPv4 Unreachable packet "
+ f"from {packet_rx.ip4.src}, will try to match UDP socket",
+ )
+ self.packet_stats_rx.icmp4__unreachable += 1
+
+ # Quick and dirty way to validate received data and pull useful
+ # information from it.
+ frame = packet_rx.icmp4.message.data
+ if (
+ len(frame) >= IP4_HEADER_LEN
+ and frame[0] >> 4 == 4
+ and len(frame) >= ((frame[0] & 0b00001111) << 2)
+ and frame[9] == IP4_PROTO_UDP
+ and len(frame) >= ((frame[0] & 0b00001111) << 2) + UDP_HEADER_LEN
+ ):
+ # Create UdpMetadata object and try to find matching UDP socket.
+ udp_offset = (frame[0] & 0b00001111) << 2
+ packet = UdpMetadata(
+ local_ip_address=Ip4Address(frame[12:16]),
+ remote_ip_address=Ip4Address(frame[16:20]),
+ local_port=struct.unpack(
+ "!H", frame[udp_offset + 0 : udp_offset + 2]
+ )[0],
+ remote_port=struct.unpack(
+ "!H", frame[udp_offset + 2 : udp_offset + 4]
+ )[0],
)
- return
- # ICMPv4 Unreachable packet
- if packet_rx.icmp4.type == ICMP4_UNREACHABLE:
- __debug__ and log(
- "icmp4",
- f"{packet_rx.tracker} - Received ICMPv4 Unreachable packet "
- f"from {packet_rx.ip4.src}, will try to match UDP socket",
- )
- self.packet_stats_rx.icmp4__unreachable += 1
-
- # Quick and dirty way to validate received data and pull useful
- # information from it
- frame = packet_rx.icmp4.un_data
- if (
- len(frame) >= IP4_HEADER_LEN
- and frame[0] >> 4 == 4
- and len(frame) >= ((frame[0] & 0b00001111) << 2)
- and frame[9] == IP4_PROTO_UDP
- and len(frame) >= ((frame[0] & 0b00001111) << 2) + UDP_HEADER_LEN
- ):
- # Create UdpMetadata object and try to find matching UDP socket
- udp_offset = (frame[0] & 0b00001111) << 2
- packet = UdpMetadata(
- local_ip_address=Ip4Address(frame[12:16]),
- remote_ip_address=Ip4Address(frame[16:20]),
- local_port=struct.unpack(
- "!H", frame[udp_offset + 0 : udp_offset + 2]
- )[0],
- remote_port=struct.unpack(
- "!H", frame[udp_offset + 2 : udp_offset + 4]
- )[0],
- )
-
- for socket_pattern in packet.socket_patterns:
- socket = stack.sockets.get(socket_pattern, None)
- if socket:
- __debug__ and log(
- "icmp4",
- f"{packet_rx.tracker} - Found matching "
- f"listening socket {socket}, for Unreachable "
- f"packet from {packet_rx.ip4.src}>",
- )
- socket.notify_unreachable()
- return
-
- __debug__ and log(
- "icmp4",
- f"{packet_rx.tracker} - Unreachable data doesn't match "
- "any UDP socket",
- )
- return
+ for socket_pattern in packet.socket_patterns:
+ socket = stack.sockets.get(socket_pattern, None)
+ if socket:
+ __debug__ and log(
+ "icmp4",
+ f"{packet_rx.tracker} - Found matching "
+ f"listening socket {socket}, for Unreachable "
+ f"packet from {packet_rx.ip4.src}>",
+ )
+ socket.notify_unreachable()
+ return
__debug__ and log(
"icmp4",
- f"{packet_rx.tracker} - Unreachable data doesn't pass basic "
- "IPv4/UDP integrity check",
+ f"{packet_rx.tracker} - Unreachable data doesn't match "
+ "any UDP socket",
)
return
+
+ __debug__ and log(
+ "icmp4",
+ f"{packet_rx.tracker} - Unreachable data doesn't pass basic "
+ "IPv4/UDP integrity check",
+ )
diff --git a/pytcp/protocols/icmp4/phtx.py b/pytcp/protocols/icmp4/phtx.py
index 4c05d75e..ce4151d2 100755
--- a/pytcp/protocols/icmp4/phtx.py
+++ b/pytcp/protocols/icmp4/phtx.py
@@ -40,15 +40,17 @@
from typing import TYPE_CHECKING
+from pytcp.lib.errors import UnsupportedCaseError
from pytcp.lib.logger import log
from pytcp.lib.tracker import Tracker
from pytcp.lib.tx_status import TxStatus
from pytcp.protocols.icmp4.fpa import Icmp4Assembler
from pytcp.protocols.icmp4.ps import (
- ICMP4_ECHO_REPLY,
- ICMP4_ECHO_REQUEST,
- ICMP4_UNREACHABLE,
- ICMP4_UNREACHABLE__PORT,
+ Icmp4Code,
+ Icmp4EchoReplyCode,
+ Icmp4EchoRequestCode,
+ Icmp4Type,
+ Icmp4UnreachableCode,
)
if TYPE_CHECKING:
@@ -61,8 +63,8 @@ def _phtx_icmp4(
*,
ip4_src: Ip4Address,
ip4_dst: Ip4Address,
- icmp4_type: int,
- icmp4_code: int = 0,
+ icmp4_type: Icmp4Type,
+ icmp4_code: Icmp4Code,
icmp4_ec_id: int | None = None,
icmp4_ec_seq: int | None = None,
icmp4_ec_data: bytes | None = None,
@@ -87,28 +89,24 @@ def _phtx_icmp4(
__debug__ and log("icmp4", f"{icmp4_packet_tx.tracker} - {icmp4_packet_tx}")
- if icmp4_type == ICMP4_ECHO_REPLY and icmp4_code == 0:
- self.packet_stats_tx.icmp4__echo_reply__send += 1
- return self._phtx_ip4(
- ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx
- )
-
- if icmp4_type == ICMP4_ECHO_REQUEST and icmp4_code == 0:
- self.packet_stats_tx.icmp4__echo_request__send += 1
- return self._phtx_ip4(
- ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx
- )
-
- if (
- icmp4_type == ICMP4_UNREACHABLE
- and icmp4_code == ICMP4_UNREACHABLE__PORT
- ):
- self.packet_stats_tx.icmp4__unreachable_port__send += 1
- return self._phtx_ip4(
- ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx
- )
-
- # This code will never be executed in debug mode due to assertions present
- # in Packet Assembler
- self.packet_stats_tx.icmp4__unknown__drop += 1
- return TxStatus.DROPED__ICMP4__UNKNOWN
+ match (icmp4_type, icmp4_code):
+ case (Icmp4Type.ECHO_REPLY, Icmp4EchoReplyCode.DEFAULT):
+ self.packet_stats_tx.icmp4__echo_reply__send += 1
+ return self._phtx_ip4(
+ ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx
+ )
+
+ case (Icmp4Type.UNREACHABLE, Icmp4UnreachableCode.PORT):
+ self.packet_stats_tx.icmp4__unreachable_port__send += 1
+ return self._phtx_ip4(
+ ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx
+ )
+
+ case (Icmp4Type.ECHO_REQUEST, Icmp4EchoRequestCode.DEFAULT):
+ self.packet_stats_tx.icmp4__echo_request__send += 1
+ return self._phtx_ip4(
+ ip4_src=ip4_src, ip4_dst=ip4_dst, carried_packet=icmp4_packet_tx
+ )
+
+ case _:
+ raise UnsupportedCaseError
diff --git a/pytcp/protocols/icmp4/ps.py b/pytcp/protocols/icmp4/ps.py
index afb721c7..39e20fa4 100755
--- a/pytcp/protocols/icmp4/ps.py
+++ b/pytcp/protocols/icmp4/ps.py
@@ -35,6 +35,10 @@
from __future__ import annotations
+from abc import ABC, abstractmethod
+
+from pytcp.lib.protocol_enum import ProtocolEnum
+
# Echo reply message (0/0)
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
@@ -79,16 +83,223 @@
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
ICMP4_HEADER_LEN = 4
-
-ICMP4_ECHO_REPLY = 0
ICMP4_ECHO_REPLY_LEN = 8
-ICMP4_UNREACHABLE = 3
ICMP4_UNREACHABLE_LEN = 8
-ICMP4_UNREACHABLE__NET = 0
-ICMP4_UNREACHABLE__HOST = 1
-ICMP4_UNREACHABLE__PROTOCOL = 2
-ICMP4_UNREACHABLE__PORT = 3
-ICMP4_UNREACHABLE__FRAGMENTATION = 4
-ICMP4_UNREACHABLE__SOURCE_ROUTE_FAILED = 5
-ICMP4_ECHO_REQUEST = 8
ICMP4_ECHO_REQUEST_LEN = 8
+
+
+class Icmp4Type(ProtocolEnum):
+ ECHO_REPLY = 0
+ UNREACHABLE = 3
+ ECHO_REQUEST = 8
+
+ @staticmethod
+ def _extract(frame: bytes) -> int:
+ return int(frame[0])
+
+
+class Icmp4Code(ProtocolEnum):
+ @staticmethod
+ def _extract(frame: bytes) -> int:
+ return int(frame[1])
+
+
+class Icmp4EchoReplyCode(Icmp4Code):
+ DEFAULT = 0
+
+
+class Icmp4UnreachableCode(Icmp4Code):
+ PORT = 3
+
+
+class Icmp4EchoRequestCode(Icmp4Code):
+ DEFAULT = 0
+
+
+class Icmp4(ABC):
+ """
+ Base class for ICMPv4 packet parser and assembler classes.
+ """
+
+ _type: Icmp4Type
+ _code: Icmp4Code
+ _cksum: int
+ _message: Icmp4EchoReplyMessage | Icmp4UnreachablePortMessage | Icmp4EchoRequestMessage
+
+ @abstractmethod
+ def __len__(self) -> int:
+ """
+ Length of the packet.
+ """
+
+ raise NotImplementedError
+
+ def __str__(self) -> str:
+ """
+ Packet log string.
+ """
+
+ return f"ICMPv4 {self._type}/{self._code} {self._message}"
+
+ @property
+ def type(self) -> Icmp4Type:
+ """
+ Getter for the '_type' property.
+ """
+
+ return self._type
+
+ @property
+ def code(self) -> Icmp4Code:
+ """
+ Getter for the '_code' property.
+ """
+
+ return self._code
+
+ @property
+ def cksum(self) -> int:
+ """
+ Getter for the '_cksum' property.
+ """
+
+ return self._cksum
+
+ @property
+ def message(
+ self,
+ ) -> (
+ Icmp4EchoReplyMessage
+ | Icmp4UnreachablePortMessage
+ | Icmp4EchoRequestMessage
+ ):
+ """
+ Getter for the '_message' property.
+ """
+
+ return self._message
+
+
+class Icmp4Message(ABC):
+ """
+ Base class for ICMPv4 message classes.
+ """
+
+ @abstractmethod
+ def __str__(self) -> str:
+ """
+ Packet log string.
+ """
+
+ raise NotImplementedError
+
+
+class Icmp4EchoReplyMessage(Icmp4Message):
+ """
+ Message class for ICMPv4 Echo Reply packet.
+ """
+
+ _id: int
+ _seq: int
+ _data: bytes
+
+ def __str__(self) -> str:
+ """
+ Packet log string.
+ """
+
+ return (
+ f"(echo reply), id {self._id}, seq {self._seq}, "
+ f"dlen {len(self._data)}"
+ )
+
+ @property
+ def id(self) -> int:
+ """
+ Getter for the '_id' property.
+ """
+
+ return self._id
+
+ @property
+ def seq(self) -> int:
+ """
+ Getter for the '_seq' property.
+ """
+
+ return self._seq
+
+ @property
+ def data(self) -> bytes:
+ """
+ Getter for the '_data' property.
+ """
+
+ return self._data
+
+
+class Icmp4UnreachablePortMessage(Icmp4Message):
+ """
+ Message class for ICMPv4 Unreachable Port packet.
+ """
+
+ _data: bytes
+
+ def __str__(self) -> str:
+ """
+ Packet log string.
+ """
+
+ return f"(unreachable - port), dlen {len(self._data)}"
+
+ @property
+ def data(self) -> bytes:
+ """
+ Getter for the '_data' property.
+ """
+
+ return self._data
+
+
+class Icmp4EchoRequestMessage(Icmp4Message):
+ """
+ Message class for ICMPv4 Echo Request packet class.
+ """
+
+ _id: int
+ _seq: int
+ _data: bytes
+
+ def __str__(self) -> str:
+ """
+ Packet log string.
+ """
+
+ return (
+ f"(echo request), id {self._id}, seq {self._seq}, "
+ f"dlen {len(self._data)}"
+ )
+
+ @property
+ def id(self) -> int:
+ """
+ Getter for the '_id' property.
+ """
+
+ return self._id
+
+ @property
+ def seq(self) -> int:
+ """
+ Getter for the '_seq' property.
+ """
+
+ return self._seq
+
+ @property
+ def data(self) -> bytes:
+ """
+ Getter for the '_data' property.
+ """
+
+ return self._data
diff --git a/pytcp/protocols/icmp6/fpp.py b/pytcp/protocols/icmp6/fpp.py
index a23c7beb..38812abd 100755
--- a/pytcp/protocols/icmp6/fpp.py
+++ b/pytcp/protocols/icmp6/fpp.py
@@ -43,7 +43,7 @@
import struct
from typing import TYPE_CHECKING
-from pytcp import config
+from pytcp.lib.errors import PacketIntegrityError, PacketSanityError
from pytcp.lib.ip6_address import Ip6Address, Ip6Mask, Ip6Network
from pytcp.lib.ip_helper import inet_cksum
from pytcp.lib.mac_address import MacAddress
@@ -71,6 +71,24 @@
from pytcp.lib.packet import PacketRx
+class Icmp6IntegrityError(PacketIntegrityError):
+ """
+ Exception raised when ICMPv6 packet integrity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[ICMPv6] " + message)
+
+
+class Icmp6SanityError(PacketSanityError):
+ """
+ Exception raised when ICMPv6 packet sanity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[ICMPv6] " + message)
+
+
class Icmp6Parser:
"""
ICMPv6 packet parser class.
@@ -88,10 +106,13 @@ def __init__(self, packet_rx: PacketRx) -> None:
self._frame = packet_rx.frame
self._plen = packet_rx.ip6.dlen
- packet_rx.parse_failed = self._packet_integrity_check(
- packet_rx.ip6.pshdr_sum
- ) or self._packet_sanity_check(
- packet_rx.ip6.src, packet_rx.ip6.dst, packet_rx.ip6.hop
+ self._packet_integrity_check(
+ pshdr_sum=packet_rx.ip6.pshdr_sum,
+ )
+ self._packet_sanity_check(
+ ip6_src=packet_rx.ip6.src,
+ ip6_dst=packet_rx.ip6.dst,
+ ip6_hop=packet_rx.ip6.hop,
)
def __len__(self) -> int:
@@ -498,217 +519,271 @@ def packet_copy(self) -> bytes:
self._cache__packet_copy = self._frame[: self.plen]
return self._cache__packet_copy
- def _nd_option_integrity_check(self, optr: int) -> str:
+ def _nd_option_integrity_check(self, optr: int) -> None:
"""
Check integrity of ICMPv6 ND options.
"""
+
while optr < len(self._frame):
if optr + 1 > len(self._frame):
- return "ICMPv6 sanity check fail - wrong option length (I)"
+ raise Icmp6IntegrityError(
+ "Wrong option length (I)",
+ )
if self._frame[optr + 1] == 0:
- return "ICMPv6 sanity check fail - wrong option length (II)"
+ raise Icmp6IntegrityError(
+ "Wrong option length (II)",
+ )
optr += self._frame[optr + 1] << 3
if optr > len(self._frame):
- return "ICMPv6 sanity check fail - wrong option length (III)"
- return ""
+ raise Icmp6IntegrityError(
+ "Wrong option length (III)",
+ )
- def _packet_integrity_check(self, pshdr_sum: int) -> str:
+ def _packet_integrity_check(self, *, pshdr_sum: int) -> None:
"""
Packet integrity check to be run on raw packet prior to parsing
to make sure parsing is safe.
"""
- if not config.PACKET_INTEGRITY_CHECK:
- return ""
-
if inet_cksum(self._frame[: self._plen], pshdr_sum):
- return "ICMPv6 integrity - wrong packet checksum"
+ raise Icmp6IntegrityError(
+ "Wrong packet checksum.",
+ )
if not ICMP6_HEADER_LEN <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (I)"
+ raise Icmp6IntegrityError(
+ "Wrong packet length (I)",
+ )
if self._frame[0] == ICMP6_UNREACHABLE:
if not 12 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
+ raise Icmp6IntegrityError(
+ "Wrong packet length (II)",
+ )
elif self._frame[0] in {ICMP6_ECHO_REQUEST, ICMP6_ECHO_REPLY}:
if not 8 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
+ raise Icmp6IntegrityError(
+ "Wrong packet length (II)",
+ )
elif self._frame[0] == ICMP6_MLD2_QUERY:
if not 28 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
+ raise Icmp6IntegrityError(
+ "Wrong packet length (II)",
+ )
if (
self._plen
!= 28 + struct.unpack("!H", self._frame[26:28])[0] * 16
):
- return "ICMPv6 integrity - wrong packet length (III)"
+ raise Icmp6IntegrityError(
+ "Wrong packet length (III)",
+ )
elif self._frame[0] == ICMP6_ND_ROUTER_SOLICITATION:
if not 8 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
- if fail := self._nd_option_integrity_check(8):
- return fail
+ raise Icmp6IntegrityError(
+ "Wrong packet length (II)",
+ )
+ self._nd_option_integrity_check(8)
elif self._frame[0] == ICMP6_ND_ROUTER_ADVERTISEMENT:
if not 16 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
- if fail := self._nd_option_integrity_check(16):
- return fail
+ raise Icmp6IntegrityError(
+ "Wrong packet length (II)",
+ )
+ self._nd_option_integrity_check(16)
elif self._frame[0] == ICMP6_ND_NEIGHBOR_SOLICITATION:
if not 24 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
- if fail := self._nd_option_integrity_check(24):
- return fail
+ raise Icmp6IntegrityError(
+ "Wrong packet length (II)",
+ )
+ self._nd_option_integrity_check(24)
elif self._frame[0] == ICMP6_ND_NEIGHBOR_ADVERTISEMENT:
if not 24 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
- if fail := self._nd_option_integrity_check(24):
- return fail
+ raise Icmp6IntegrityError(
+ "Wrong packet length (II)",
+ )
+ self._nd_option_integrity_check(24)
elif self._frame[0] == ICMP6_MLD2_REPORT:
if not 8 <= self._plen <= len(self):
- return "ICMPv6 integrity - wrong packet length (II)"
+ raise Icmp6IntegrityError(
+ "Wrong packet length (II)",
+ )
optr = 8
for _ in range(struct.unpack("!H", self._frame[6:8])[0]):
if optr + 20 > self._plen:
- return "ICMPv6 integrity - wrong packet length (III)"
+ raise Icmp6IntegrityError(
+ "Wrong packet length (III)",
+ )
optr += (
20
+ self._frame[optr + 1]
+ struct.unpack_from("! H", self._frame, optr + 2)[0] * 16
)
if optr != self._plen:
- return "ICMPv6 integrity - wrong packet length (IV)"
-
- return ""
+ raise Icmp6IntegrityError(
+ "Wrong packet length (IV)",
+ )
def _packet_sanity_check(
- self, ip6_src: Ip6Address, ip6_dst: Ip6Address, ip6_hop: int
- ) -> str:
+ self, *, ip6_src: Ip6Address, ip6_dst: Ip6Address, ip6_hop: int
+ ) -> None:
"""
Packet sanity check to be run on parsed packet to make sure packet's
fields contain sane values.
"""
- if not config.PACKET_SANITY_CHECK:
- return ""
-
if self.type == ICMP6_UNREACHABLE:
if self.code not in {0, 1, 2, 3, 4, 5, 6}:
- return "ICMPv6 sanity - 'code' must be [0-6] (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'code' must be [0-6] (RFC 4861)",
+ )
elif self.type == ICMP6_PACKET_TOO_BIG:
if not self.code == 0:
- return "ICMPv6 sanity - 'code' should be 0 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'code' should be 0 (RFC 4861)",
+ )
elif self.type == ICMP6_TIME_EXCEEDED:
if self.code not in {0, 1}:
- return "ICMPv6 sanity - 'code' must be [0-1] (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'code' must be [0-1] (RFC 4861)",
+ )
elif self.type == ICMP6_PARAMETER_PROBLEM:
if self.code not in {0, 1, 2}:
- return "ICMPv6 sanity - 'code' must be [0-2] (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'code' must be [0-2] (RFC 4861)",
+ )
elif self.type in {ICMP6_ECHO_REQUEST, ICMP6_ECHO_REPLY}:
if not self.code == 0:
- return "ICMPv6 sanity - 'code' should be 0 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'code' should be 0 (RFC 4861)",
+ )
elif self.type == ICMP6_MLD2_QUERY:
if not self.code == 0:
- return "ICMPv6 sanity - 'code' must be 0 (RFC 3810)"
+ raise Icmp6SanityError(
+ "The 'code' must be 0 (RFC 3810)",
+ )
if not ip6_hop == 1:
- return "ICMPv6 sanity - 'hop' must be 255 (RFC 3810)"
+ raise Icmp6SanityError(
+ "The 'hop' must be 255 (RFC 3810)",
+ )
elif self.type == ICMP6_ND_ROUTER_SOLICITATION:
if not self.code == 0:
- return "ICMPv6 sanity - 'code' must be 0 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'code' must be 0 (RFC 4861)",
+ )
if not ip6_hop == 255:
- return "ICMPv6 sanity - 'hop' must be 255 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'hop' must be 255 (RFC 4861)",
+ )
if not (ip6_src.is_unicast or ip6_src.is_unspecified):
- return (
- "ICMPv6 sanity - 'src' must be unicast or unspecified "
- "(RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'src' must be unicast or unspecified (RFC 4861)",
)
if not ip6_dst == Ip6Address("ff02::2"):
- return "ICMPv6 sanity - 'dst' must be all-routers (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'dst' must be all-routers (RFC 4861)",
+ )
if ip6_src.is_unspecified and self.nd_opt_slla:
- return (
- "ICMPv6 sanity - 'nd_opt_slla' must not be included if "
- "'src' is unspecified (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'nd_opt_slla' must not be included if "
+ "'src' is unspecified (RFC 4861)",
)
elif self.type == ICMP6_ND_ROUTER_ADVERTISEMENT:
if not self.code == 0:
- return "ICMPv6 sanity - 'code' must be 0 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'code' must be 0 (RFC 4861)",
+ )
if not ip6_hop == 255:
- return "ICMPv6 sanity - 'hop' must be 255 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'hop' must be 255 (RFC 4861)",
+ )
if not ip6_src.is_link_local:
- return "ICMPv6 sanity - 'src' must be link local (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'src' must be link local (RFC 4861)",
+ )
if not (ip6_dst.is_unicast or ip6_dst == Ip6Address("ff02::1")):
- return (
- "ICMPv6 sanity - 'dst' must be unicast or all-nodes "
- "(RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'dst' must be unicast or all-nodes (RFC 4861)",
)
elif self.type == ICMP6_ND_NEIGHBOR_SOLICITATION:
if not self.code == 0:
- return "ICMPv6 sanity - 'code' must be 0 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'code' must be 0 (RFC 4861)",
+ )
if not ip6_hop == 255:
- return "ICMPv6 sanity - 'hop' must be 255 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'hop' must be 255 (RFC 4861)",
+ )
if not (ip6_src.is_unicast or ip6_src.is_unspecified):
- return (
- "ICMPv6 sanity - 'src' must be unicast or unspecified "
- "(RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'src' must be unicast or unspecified (RFC 4861)",
)
if ip6_dst not in {
self.ns_target_address,
self.ns_target_address.solicited_node_multicast,
}:
- return (
- "ICMPv6 sanity - 'dst' must be 'ns_target_address' or it's "
- "solicited-node multicast (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'dst' must be 'ns_target_address' or it's "
+ "solicited-node multicast (RFC 4861)",
)
if not self.ns_target_address.is_unicast:
- return (
- "ICMPv6 sanity - 'ns_target_address' must be unicast "
- "(RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'ns_target_address' must be unicast (RFC 4861)",
)
if ip6_src.is_unspecified and self.nd_opt_slla is not None:
- return (
- "ICMPv6 sanity - 'nd_opt_slla' must not be included if "
- "'src' is unspecified"
+ raise Icmp6SanityError(
+ "The 'nd_opt_slla' must not be included if "
+ "'src' is unspecified",
)
elif self.type == ICMP6_ND_NEIGHBOR_ADVERTISEMENT:
if not self.code == 0:
- return "ICMPv6 sanity - 'code' must be 0 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'code' must be 0 (RFC 4861)",
+ )
if not ip6_hop == 255:
- return "ICMPv6 sanity - 'hop' must be 255 (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'hop' must be 255 (RFC 4861)",
+ )
if not ip6_src.is_unicast:
- return "ICMPv6 sanity - 'src' must be unicast (RFC 4861)"
+ raise Icmp6SanityError(
+ "The 'src' must be unicast (RFC 4861)",
+ )
if self.na_flag_s is True and not (
ip6_dst.is_unicast or ip6_dst == Ip6Address("ff02::1")
):
- return (
- "ICMPv6 sanity - if 'na_flag_s' is set then 'dst' must be "
- "unicast or all-nodes (RFC 4861)"
+ raise Icmp6SanityError(
+ "If 'na_flag_s' is set then 'dst' must be "
+ "unicast or all-nodes (RFC 4861)",
)
if self.na_flag_s is False and not ip6_dst == Ip6Address("ff02::1"):
- return (
- "ICMPv6 sanity - if 'na_flag_s' is not set then 'dst' must "
- "be all-nodes (RFC 4861)"
+ raise Icmp6SanityError(
+ "If 'na_flag_s' is not set then 'dst' must "
+ "be all-nodes (RFC 4861)",
)
elif self.type == ICMP6_MLD2_REPORT:
if not self.code == 0:
- return "ICMPv6 sanity - 'code' must be 0 (RFC 3810)"
+ raise Icmp6SanityError(
+ "The 'code' must be 0 (RFC 3810)",
+ )
if not ip6_hop == 1:
- return "ICMPv6 sanity - 'hop' must be 1 (RFC 3810)"
-
- return ""
+ raise Icmp6SanityError(
+ "The 'hop' must be 1 (RFC 3810)",
+ )
#
diff --git a/pytcp/protocols/icmp6/phrx.py b/pytcp/protocols/icmp6/phrx.py
index a26d58ea..2da671d2 100755
--- a/pytcp/protocols/icmp6/phrx.py
+++ b/pytcp/protocols/icmp6/phrx.py
@@ -45,6 +45,7 @@
from typing import TYPE_CHECKING
from pytcp.lib import stack
+from pytcp.lib.errors import PacketValidationError
from pytcp.lib.ip6_address import Ip6Address
from pytcp.lib.logger import log
from pytcp.protocols.icmp6.fpa import Icmp6NdOptTLLA
@@ -74,12 +75,12 @@ def _phrx_icmp6(self: PacketHandler, packet_rx: PacketRx) -> None:
self.packet_stats_rx.icmp6__pre_parse += 1
- Icmp6Parser(packet_rx)
-
- if packet_rx.parse_failed:
+ try:
+ Icmp6Parser(packet_rx)
+ except PacketValidationError as error:
__debug__ and log(
"icmp6",
- f"{packet_rx.tracker} - {packet_rx.parse_failed}>",
+ f"{packet_rx.tracker} - {error}>",
)
self.packet_stats_rx.icmp6__failed_parse__drop += 1
return
diff --git a/pytcp/protocols/ip4/fpa.py b/pytcp/protocols/ip4/fpa.py
index 93f168ad..6037a39c 100755
--- a/pytcp/protocols/ip4/fpa.py
+++ b/pytcp/protocols/ip4/fpa.py
@@ -44,7 +44,7 @@
from pytcp.lib.ip4_address import Ip4Address
from pytcp.lib.ip_helper import inet_cksum
from pytcp.lib.tracker import Tracker
-from pytcp.protocols.ether.ps import ETHER_TYPE_IP4
+from pytcp.protocols.ether.ps import EtherType
from pytcp.protocols.ip4.ps import (
IP4_HEADER_LEN,
IP4_OPT_EOL,
@@ -70,7 +70,7 @@ class Ip4Assembler:
IPv4 packet assembler support class.
"""
- ether_type = ETHER_TYPE_IP4
+ ether_type = EtherType.IP4
def __init__(
self,
@@ -158,6 +158,7 @@ def tracker(self) -> Tracker:
"""
Getter for the '_tracker' attribute.
"""
+
return self._tracker
@property
@@ -165,6 +166,7 @@ def dst(self) -> Ip4Address:
"""
Getter for the '_dst' attribute.
"""
+
return self._dst
@property
@@ -172,6 +174,7 @@ def src(self) -> Ip4Address:
"""
Getter for the '_src' attribute.
"""
+
return self._src
@property
@@ -179,6 +182,7 @@ def hlen(self) -> int:
"""
Getter for the '_hlen' attribute.
"""
+
return self._hlen
@property
@@ -186,6 +190,7 @@ def proto(self) -> int:
"""
Getter for the '_proto' attribute.
"""
+
return self._proto
@property
@@ -193,6 +198,7 @@ def dlen(self) -> int:
"""
Calculate data length.
"""
+
return self._plen - self._hlen
@property
@@ -201,6 +207,7 @@ def pshdr_sum(self) -> int:
Create IPv4 pseudo header used by TCP and UDP to compute
their checksums.
"""
+
pseudo_header = struct.pack(
"! 4s 4s BBH",
bytes(self._src),
@@ -209,6 +216,7 @@ def pshdr_sum(self) -> int:
self._proto,
self._plen - self._hlen,
)
+
return sum(struct.unpack("! 3L", pseudo_header))
@property
@@ -216,12 +224,14 @@ def _raw_options(self) -> bytes:
"""
Packet options in raw format.
"""
+
return b"".join(bytes(option) for option in self._options)
def assemble(self, frame: memoryview) -> None:
"""
Assemble packet into the raw form.
"""
+
struct.pack_into(
f"! BBH HH BBH 4s 4s {len(self._raw_options)}s",
frame,
@@ -247,7 +257,7 @@ class Ip4FragAssembler:
IPv4 packet fragment assembler support class.
"""
- ether_type = ETHER_TYPE_IP4
+ ether_type = EtherType.IP4
def __init__(
self,
diff --git a/pytcp/protocols/ip4/fpp.py b/pytcp/protocols/ip4/fpp.py
index 534c11a3..df4535ba 100755
--- a/pytcp/protocols/ip4/fpp.py
+++ b/pytcp/protocols/ip4/fpp.py
@@ -44,6 +44,7 @@
from typing import TYPE_CHECKING
from pytcp import config
+from pytcp.lib.errors import PacketIntegrityError, PacketSanityError
from pytcp.lib.ip4_address import Ip4Address
from pytcp.lib.ip_helper import inet_cksum
from pytcp.protocols.ip4.ps import (
@@ -59,6 +60,24 @@
from pytcp.lib.packet import PacketRx
+class Ip4IntegrityError(PacketIntegrityError):
+ """
+ Exception raised when IPv4 packet integrity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[IPv4] " + message)
+
+
+class Ip4SanityError(PacketSanityError):
+ """
+ Exception raised when IPv4 packet sanity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[IPv4] " + message)
+
+
class Ip4Parser:
"""
IPv4 packet parser class.
@@ -74,12 +93,10 @@ def __init__(self, packet_rx: PacketRx) -> None:
self._frame = packet_rx.frame
- packet_rx.parse_failed = (
- self._packet_integrity_check() or self._packet_sanity_check()
- )
+ self._packet_integrity_check()
+ self._packet_sanity_check()
- if not packet_rx.parse_failed:
- packet_rx.frame = packet_rx.frame[self.hlen :]
+ packet_rx.frame = packet_rx.frame[self.hlen :]
def __len__(self) -> int:
"""
@@ -324,25 +341,28 @@ def pshdr_sum(self) -> int:
)
return self._cache__pshdr_sum
- def _packet_integrity_check(self) -> str:
+ def _packet_integrity_check(self) -> None:
"""
Packet integrity check to be run on raw packet prior to parsing
to make sure parsing is safe.
"""
- if not config.PACKET_INTEGRITY_CHECK:
- return ""
-
if len(self) < IP4_HEADER_LEN:
- return "IPv4 integrity - wrong packet length (I)"
+ raise Ip4IntegrityError(
+ "The wrong packet length (I)",
+ )
if not IP4_HEADER_LEN <= self.hlen <= self.plen <= len(self):
- return "IPv4 integrity - wrong packet length (II)"
+ raise Ip4IntegrityError(
+ "The wrong packet length (II)",
+ )
# Cannot compute checksum earlier because it depends
# on sanity of hlen field
if inet_cksum(self._frame[: self.hlen]):
- return "IPv4 integriy - wrong packet checksum"
+ raise Ip4IntegrityError(
+ "The wrong packet checksum",
+ )
optr = IP4_HEADER_LEN
while optr < self.hlen:
@@ -351,55 +371,69 @@ def _packet_integrity_check(self) -> str:
if self._frame[optr] == IP4_OPT_NOP:
optr += 1
if optr > self.hlen:
- return "IPv4 integrity - wrong option length (I)"
+ raise Ip4IntegrityError(
+ "The integrity - wrong option length (I)",
+ )
continue
if optr + 1 > self.hlen:
- return "IPv4 integrity - wrong option length (II)"
+ raise Ip4IntegrityError(
+ "The wrong option length (II)",
+ )
if self._frame[optr + 1] == 0:
- return "IPv4 integrity - wrong option length (III)"
+ raise Ip4IntegrityError(
+ "The wrong option length (III)",
+ )
optr += self._frame[optr + 1]
if optr > self.hlen:
- return "IPv4 integrity - wrong option length (IV)"
-
- return ""
+ raise Ip4IntegrityError(
+ "The wrong option length (IV)",
+ )
- def _packet_sanity_check(self) -> str:
+ def _packet_sanity_check(self) -> None:
"""
Packet sanity check to be run on parsed packet to make sure packet's
fields contain sane values.
"""
- if not config.PACKET_SANITY_CHECK:
- return ""
-
if self.ver != 4:
- return "IP sanityi - 'ver' must be 4"
+ raise Ip4SanityError(
+ "The 'ver' must be 4",
+ )
if self.ver == 0:
- return "IP sanity - 'ttl' must be greater than 0"
+ raise Ip4SanityError(
+ "The 'ttl' must be greater than 0",
+ )
if self.src.is_multicast:
- return "IP sanity - 'src' must not be multicast"
+ raise Ip4SanityError(
+ "The 'src' must not be multicast",
+ )
if self.src.is_reserved:
- return "IP sanity - 'src' must not be reserved"
+ raise Ip4SanityError(
+ "The 'src' must not be reserved",
+ )
if self.src.is_limited_broadcast:
- return "IP sanity - 'src' must not be limited broadcast"
+ raise Ip4SanityError(
+ "The 'src' must not be limited broadcast",
+ )
if self.flag_df and self.flag_mf:
- return (
- "IP sanity - 'flag_df' and 'flag_mf' must not be set "
- "simultaneously"
+ raise Ip4SanityError(
+ "The 'flag_df' and 'flag_mf' must not be set simultaneously",
)
if self.offset and self.flag_df:
- return "IP sanity - 'offset' must be 0 when 'df_flag' is set"
+ raise Ip4SanityError(
+ "The 'offset' must be 0 when 'df_flag' is set",
+ )
if self.options and config.IP4_OPTION_PACKET_DROP:
- return "IP sanity - packet must not contain options"
-
- return ""
+ raise Ip4SanityError(
+ "The packet must not contain options",
+ )
#
diff --git a/pytcp/protocols/ip4/phrx.py b/pytcp/protocols/ip4/phrx.py
index b660ebdc..f25accf9 100755
--- a/pytcp/protocols/ip4/phrx.py
+++ b/pytcp/protocols/ip4/phrx.py
@@ -42,6 +42,7 @@
from typing import TYPE_CHECKING
from pytcp import config
+from pytcp.lib.errors import PacketValidationError
from pytcp.lib.ip_helper import inet_cksum
from pytcp.lib.logger import log
from pytcp.lib.packet import PacketRx
@@ -135,13 +136,13 @@ def _phrx_ip4(self: PacketHandler, packet_rx: PacketRx) -> None:
self.packet_stats_rx.ip4__pre_parse += 1
- Ip4Parser(packet_rx)
-
- if packet_rx.parse_failed:
+ try:
+ Ip4Parser(packet_rx)
+ except PacketValidationError as error:
self.packet_stats_rx.ip4__failed_parse__drop += 1
__debug__ and log(
"ip4",
- f"{packet_rx.tracker} - {packet_rx.parse_failed}>",
+ f"{packet_rx.tracker} - {error}>",
)
return
diff --git a/pytcp/protocols/ip6/fpa.py b/pytcp/protocols/ip6/fpa.py
index bcad1b77..de30d994 100755
--- a/pytcp/protocols/ip6/fpa.py
+++ b/pytcp/protocols/ip6/fpa.py
@@ -41,7 +41,7 @@
from pytcp import config
from pytcp.lib.ip6_address import Ip6Address
-from pytcp.protocols.ether.ps import ETHER_TYPE_IP6
+from pytcp.protocols.ether.ps import EtherType
from pytcp.protocols.ip6.ps import (
IP6_HEADER_LEN,
IP6_NEXT_EXT_FRAG,
@@ -66,7 +66,7 @@ class Ip6Assembler:
IPv6 packet assembler support class.
"""
- ether_type = ETHER_TYPE_IP6
+ ether_type = EtherType.IP6
def __init__(
self,
diff --git a/pytcp/protocols/ip6/fpp.py b/pytcp/protocols/ip6/fpp.py
index 178bee0d..26037b79 100755
--- a/pytcp/protocols/ip6/fpp.py
+++ b/pytcp/protocols/ip6/fpp.py
@@ -40,7 +40,7 @@
import struct
from typing import TYPE_CHECKING
-from pytcp import config
+from pytcp.lib.errors import PacketIntegrityError, PacketSanityError
from pytcp.lib.ip6_address import Ip6Address
from pytcp.protocols.ip6.ps import IP6_HEADER_LEN, IP6_NEXT_TABLE
@@ -48,6 +48,24 @@
from pytcp.lib.packet import PacketRx
+class Ip6IntegrityError(PacketIntegrityError):
+ """
+ Exception raised when IPv6 packet integrity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[IPv6] " + message)
+
+
+class Ip6SanityError(PacketSanityError):
+ """
+ Exception raised when IPv6 packet sanity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[IPv6] " + message)
+
+
class Ip6Parser:
"""
IPv6 packet parser class.
@@ -63,12 +81,10 @@ def __init__(self, packet_rx: PacketRx) -> None:
self._frame = packet_rx.frame
- packet_rx.parse_failed = (
- self._packet_integrity_check() or self._packet_sanity_check()
- )
+ self._packet_integrity_check()
+ self._packet_sanity_check()
- if not packet_rx.parse_failed:
- packet_rx.frame = packet_rx.frame[IP6_HEADER_LEN:]
+ packet_rx.frame = packet_rx.frame[IP6_HEADER_LEN:]
def __len__(self) -> int:
"""
@@ -235,41 +251,42 @@ def pshdr_sum(self) -> int:
return self._cache__pshdr_sum
- def _packet_integrity_check(self) -> str:
+ def _packet_integrity_check(self) -> None:
"""
Packet integrity check to be run on raw packet prior to parsing
to make sure parsing is safe.
"""
- if not config.PACKET_INTEGRITY_CHECK:
- return ""
if len(self) < IP6_HEADER_LEN:
- return "IPv6 integrity - wrong packet length (I)"
+ raise Ip6IntegrityError(
+ "The wrong packet length (I)",
+ )
if (
struct.unpack("!H", self._frame[4:6])[0]
!= len(self) - IP6_HEADER_LEN
):
- return "IPv6 integrity - wrong packet length (II)"
-
- return ""
+ raise Ip6IntegrityError(
+ "The wrong packet length (II)",
+ )
- def _packet_sanity_check(self) -> str:
+ def _packet_sanity_check(self) -> None:
"""
Packet sanity check to be run on parsed packet to make sure packet's
fields contain sane values.
"""
- if not config.PACKET_SANITY_CHECK:
- return ""
-
if self.ver != 6:
- return "IPv6 sanity - 'ver' must be 6"
+ raise Ip6SanityError(
+ "The 'ver' must be 6",
+ )
if self.hop == 0:
- return "IPv6 sanity - 'hop' must not be 0"
+ raise Ip6SanityError(
+ "The 'hop' must not be 0",
+ )
if self.src.is_multicast:
- return "IPv6 sanity - 'src' must not be multicast"
-
- return ""
+ raise Ip6SanityError(
+ "The 'src' must not be multicast",
+ )
diff --git a/pytcp/protocols/ip6/phrx.py b/pytcp/protocols/ip6/phrx.py
index 06add1fb..161fad57 100755
--- a/pytcp/protocols/ip6/phrx.py
+++ b/pytcp/protocols/ip6/phrx.py
@@ -39,6 +39,7 @@
from typing import TYPE_CHECKING
+from pytcp.lib.errors import PacketValidationError
from pytcp.lib.logger import log
from pytcp.lib.packet import PacketRx
from pytcp.protocols.ip6.fpp import Ip6Parser
@@ -60,13 +61,11 @@ def _phrx_ip6(self: PacketHandler, packet_rx: PacketRx) -> None:
self.packet_stats_rx.ip6__pre_parse += 1
- Ip6Parser(packet_rx)
-
- if packet_rx.parse_failed:
+ try:
+ Ip6Parser(packet_rx)
+ except PacketValidationError as error:
self.packet_stats_rx.ip6__failed_parse__drop += 1
- __debug__ and log(
- "ip6", f"{packet_rx.tracker} - {packet_rx.parse_failed}>"
- )
+ __debug__ and log("ip6", f"{packet_rx.tracker} - {error}>")
return
__debug__ and log("ip6", f"{packet_rx.tracker} - {packet_rx.ip6}")
diff --git a/pytcp/protocols/ip6_ext_frag/fpp.py b/pytcp/protocols/ip6_ext_frag/fpp.py
index 910a38fa..74ea9a11 100755
--- a/pytcp/protocols/ip6_ext_frag/fpp.py
+++ b/pytcp/protocols/ip6_ext_frag/fpp.py
@@ -41,7 +41,7 @@
import struct
from typing import TYPE_CHECKING
-from pytcp import config
+from pytcp.lib.errors import PacketIntegrityError, PacketSanityError
from pytcp.protocols.ip6_ext_frag.ps import (
IP6_EXT_FRAG_HEADER_LEN,
IP6_EXT_FRAG_NEXT_HEADER_TABLE,
@@ -51,6 +51,24 @@
from pytcp.lib.packet import PacketRx
+class Ip6ExtFragIntegrityError(PacketIntegrityError):
+ """
+ Exception raised when IPv6 Ext Frag packet integrity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[IPv6 Ext Frag] " + message)
+
+
+class Ip6ExtFragSanityError(PacketSanityError):
+ """
+ Exception raised when IPv6 Ext Frag packet sanity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[IPv6 Ext Frag] " + message)
+
+
class Ip6ExtFragParser:
"""
IPv6 fragmentation extension header parser class.
@@ -68,12 +86,10 @@ def __init__(self, packet_rx: PacketRx) -> None:
self._frame = packet_rx.frame
self._plen = packet_rx.ip6.dlen
- packet_rx.parse_failed = (
- self._packet_integrity_check() or self._packet_sanity_check()
- )
+ self._packet_integrity_check()
+ self._packet_sanity_check()
- if not packet_rx.parse_failed:
- packet_rx.frame = packet_rx.frame[IP6_EXT_FRAG_HEADER_LEN:]
+ packet_rx.frame = packet_rx.frame[IP6_EXT_FRAG_HEADER_LEN:]
def __len__(self) -> int:
"""
@@ -177,27 +193,21 @@ def packet_copy(self) -> bytes:
self._cache__packet_copy = bytes(self._frame[: self.plen])
return self._cache__packet_copy
- def _packet_integrity_check(self) -> str:
+ def _packet_integrity_check(self) -> None:
"""
Packet integrity check to be run on raw packet prior to parsing
to make sure parsing is safe.
"""
- if not config.PACKET_INTEGRITY_CHECK:
- return ""
-
if len(self) < IP6_EXT_FRAG_HEADER_LEN:
- return "IPv4 integrity - wrong packet length (I)"
-
- return ""
+ raise Ip6ExtFragIntegrityError(
+ "The wrong packet length (I)",
+ )
- def _packet_sanity_check(self) -> str:
+ def _packet_sanity_check(self) -> None:
"""
Packet sanity check to be run on parsed packet to make sure packet's
fields contain sane values.
"""
- if not config.PACKET_SANITY_CHECK:
- return ""
-
- return ""
+ pass
diff --git a/pytcp/protocols/raw/fpa.py b/pytcp/protocols/raw/fpa.py
index 7fa5bb98..71f02a15 100755
--- a/pytcp/protocols/raw/fpa.py
+++ b/pytcp/protocols/raw/fpa.py
@@ -38,7 +38,7 @@
import struct
from pytcp.lib.tracker import Tracker
-from pytcp.protocols.ether.ps import ETHER_TYPE_RAW
+from pytcp.protocols.ether.ps import EtherType
from pytcp.protocols.ip4.ps import IP4_PROTO_RAW
from pytcp.protocols.ip6.ps import IP6_NEXT_RAW
@@ -50,7 +50,7 @@ class RawAssembler:
ip4_proto = IP4_PROTO_RAW
ip6_next = IP6_NEXT_RAW
- ether_type = ETHER_TYPE_RAW
+ ether_type = EtherType.RAW
def __init__(
self, *, data: bytes | None = None, echo_tracker: Tracker | None = None
diff --git a/pytcp/protocols/tcp/fpp.py b/pytcp/protocols/tcp/fpp.py
index f5aa84ab..5e0a3c0c 100755
--- a/pytcp/protocols/tcp/fpp.py
+++ b/pytcp/protocols/tcp/fpp.py
@@ -42,7 +42,7 @@
import struct
from typing import TYPE_CHECKING
-from pytcp import config
+from pytcp.lib.errors import PacketIntegrityError, PacketSanityError
from pytcp.lib.ip_helper import inet_cksum
from pytcp.protocols.tcp.ps import (
TCP_HEADER_LEN,
@@ -60,6 +60,24 @@
from pytcp.lib.packet import PacketRx
+class TcpIntegrityError(PacketIntegrityError):
+ """
+ Exception raised when TCP packet integrity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[TCP] " + message)
+
+
+class TcpSanityError(PacketSanityError):
+ """
+ Exception raised when TCP packet sanity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[TCP] " + message)
+
+
class TcpParser:
"""
TCP packet parser class.
@@ -77,10 +95,10 @@ def __init__(self, packet_rx: PacketRx) -> None:
self._frame = packet_rx.frame
self._plen = packet_rx.ip.dlen
- packet_rx.parse_failed = (
- self._packet_integrity_check(packet_rx.ip.pshdr_sum)
- or self._packet_sanity_check()
+ self._packet_integrity_check(
+ pshdr_sum=packet_rx.ip.pshdr_sum,
)
+ self._packet_sanity_check()
if packet_rx.parse_failed:
packet_rx.frame = packet_rx.frame[self.hlen :]
@@ -432,24 +450,27 @@ def timestamp(self) -> tuple[int, int] | None:
self._cache__timestamp = None
return self._cache__timestamp
- def _packet_integrity_check(self, pshdr_sum: int) -> str:
+ def _packet_integrity_check(self, *, pshdr_sum: int) -> None:
"""
Packet integrity check to be run on raw frame prior to parsing
to make sure parsing is safe.
"""
- if not config.PACKET_INTEGRITY_CHECK:
- return ""
-
if inet_cksum(self._frame[: self._plen], pshdr_sum):
- return "TCP integrity - wrong packet checksum"
+ raise TcpIntegrityError(
+ "The wrong packet checksum",
+ )
if not TCP_HEADER_LEN <= self._plen <= len(self):
- return "TCP integrity - wrong packet length (I)"
+ raise TcpIntegrityError(
+ "The wrong packet length (I)",
+ )
hlen = (self._frame[12] & 0b11110000) >> 2
if not TCP_HEADER_LEN <= hlen <= self._plen <= len(self):
- return "TCP integrity - wrong packet length (II)"
+ raise TcpIntegrityError(
+ "The wrong packet length (II)",
+ )
optr = TCP_HEADER_LEN
while optr < hlen:
@@ -458,61 +479,69 @@ def _packet_integrity_check(self, pshdr_sum: int) -> str:
if self._frame[optr] == TCP_OPT_NOP:
optr += 1
if optr > hlen:
- return "TCP integrity - wrong option length (I)"
+ raise TcpIntegrityError(
+ "The wrong option length (I)",
+ )
continue
if optr + 1 > hlen:
- return "TCP integrity - wrong option length (II)"
+ raise TcpIntegrityError(
+ "The wrong option length (II)",
+ )
if self._frame[optr + 1] == 0:
- return "TCP integrity - wrong option length (III)"
+ raise TcpIntegrityError(
+ "The wrong option length (III)",
+ )
optr += self._frame[optr + 1]
if optr > hlen:
- return "TCP integrity - wrong option length (IV)"
-
- return ""
+ raise TcpIntegrityError(
+ "The wrong option length (IV)",
+ )
- def _packet_sanity_check(self) -> str:
+ def _packet_sanity_check(self) -> None:
"""
Packet sanity check to be run on parsed packet to make sure packets's
fields contain sane values.
"""
- if not config.PACKET_SANITY_CHECK:
- return ""
-
if self.sport == 0:
- return "TCP sanity - 'sport' must be greater than 0"
+ raise TcpSanityError(
+ "The 'sport' must be greater than 0",
+ )
if self.dport == 0:
- return "TCP sanity - 'dport' must be greater than 0"
+ raise TcpSanityError(
+ "The 'dport' must be greater than 0",
+ )
if self.flag_syn and self.flag_fin:
- return (
- "TCP sanity - 'flag_syn' and 'flag_fin' must not be set "
- "simultaneously"
+ raise TcpSanityError(
+ "The 'flag_syn' and 'flag_fin' must not be set simultaneously",
)
if self.flag_syn and self.flag_rst:
- return (
- "TCP sanity - 'flag_syn' and 'flag_rst' must not set "
- "simultaneously"
+ raise TcpSanityError(
+ "The 'flag_syn' and 'flag_rst' must not set simultaneously",
)
if self.flag_fin and self.flag_rst:
- return (
- "TCP sanity - 'flag_fin' and 'flag_rst' must not be set "
- "simultaneously"
+ raise TcpSanityError(
+ "The 'flag_fin' and 'flag_rst' must not be set simultaneously",
)
if self.flag_fin and not self.flag_ack:
- return "TCP sanity - 'flag_ack' must be set when 'flag_fin' is set"
+ raise TcpSanityError(
+ "The 'flag_ack' must be set when 'flag_fin' is set",
+ )
if self.ack and not self.flag_ack:
- return "TCP sanity - 'flag_ack' must be set when 'ack' is not 0"
+ raise TcpSanityError(
+ "The 'flag_ack' must be set when 'ack' is not 0",
+ )
if self.urg and not self.flag_urg:
- return "TCP sanity - 'flag_urg' must be set when 'urg' is not 0"
-
- return ""
+ raise TcpSanityError(
+ "The 'flag_urg' must be set when 'urg' is not 0",
+ )
#
diff --git a/pytcp/protocols/tcp/phrx.py b/pytcp/protocols/tcp/phrx.py
index 3770a2c1..522a2a75 100755
--- a/pytcp/protocols/tcp/phrx.py
+++ b/pytcp/protocols/tcp/phrx.py
@@ -40,6 +40,7 @@
from typing import TYPE_CHECKING
from pytcp.lib import stack
+from pytcp.lib.errors import PacketValidationError
from pytcp.lib.logger import log
from pytcp.lib.packet import PacketRx
from pytcp.protocols.tcp.fpp import TcpParser
@@ -56,13 +57,13 @@ def _phrx_tcp(self: PacketHandler, packet_rx: PacketRx) -> None:
self.packet_stats_rx.tcp__pre_parse += 1
- TcpParser(packet_rx)
-
- if packet_rx.parse_failed:
+ try:
+ TcpParser(packet_rx)
+ except PacketValidationError as error:
self.packet_stats_rx.tcp__failed_parse__drop += 1
__debug__ and log(
"tcp",
- f"{packet_rx.tracker} - {packet_rx.parse_failed}>",
+ f"{packet_rx.tracker} - {error}>",
)
return
diff --git a/pytcp/protocols/udp/fpp.py b/pytcp/protocols/udp/fpp.py
index a3fd51e0..865f7bc4 100755
--- a/pytcp/protocols/udp/fpp.py
+++ b/pytcp/protocols/udp/fpp.py
@@ -40,7 +40,7 @@
import struct
from typing import TYPE_CHECKING
-from pytcp import config
+from pytcp.lib.errors import PacketIntegrityError, PacketSanityError
from pytcp.lib.ip_helper import inet_cksum
from pytcp.protocols.udp.ps import UDP_HEADER_LEN
@@ -48,6 +48,24 @@
from pytcp.lib.packet import PacketRx
+class UdpIntegrityError(PacketIntegrityError):
+ """
+ Exception raised when UDP packet integrity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[UDP] " + message)
+
+
+class UdpSanityError(PacketSanityError):
+ """
+ Exception raised when UDP packet sanity check fails.
+ """
+
+ def __init__(self, message: str):
+ super().__init__("[UDP] " + message)
+
+
class UdpParser:
"""
UDP packet parser class.
@@ -65,10 +83,8 @@ def __init__(self, packet_rx: PacketRx) -> None:
self._frame = packet_rx.frame
self._plen = packet_rx.ip.dlen
- packet_rx.parse_failed = (
- self._packet_integrity_check(packet_rx.ip.pshdr_sum)
- or self._packet_sanity_check()
- )
+ self._packet_integrity_check(pshdr_sum=packet_rx.ip.pshdr_sum)
+ self._packet_sanity_check()
if not packet_rx.parse_failed:
packet_rx.frame = packet_rx.frame[UDP_HEADER_LEN:]
@@ -166,40 +182,40 @@ def packet_copy(self) -> bytes:
self._cache__packet_copy = bytes(self._frame[: self.plen])
return self._cache__packet_copy
- def _packet_integrity_check(self, pshdr_sum: int) -> str:
+ def _packet_integrity_check(self, pshdr_sum: int) -> None:
"""
Packet integrity check to be run on raw frame prior to parsing
to make sure parsing is safe.
"""
- if not config.PACKET_INTEGRITY_CHECK:
- return ""
-
if inet_cksum(self._frame[: self._plen], pshdr_sum):
- return "UDP integrity - wrong packet checksum"
+ raise UdpIntegrityError(
+ "The wrong packet checksum",
+ )
if not UDP_HEADER_LEN <= self._plen <= len(self):
- return "UDP integrity - wrong packet length (I)"
+ raise UdpIntegrityError(
+ "The wrong packet length (I)",
+ )
plen = struct.unpack("!H", self._frame[4:6])[0]
if not UDP_HEADER_LEN <= plen == self._plen <= len(self):
- return "UDP integrity - wrong packet length (II)"
-
- return ""
+ raise UdpIntegrityError(
+ "The wrong packet length (II)",
+ )
- def _packet_sanity_check(self) -> str:
+ def _packet_sanity_check(self) -> None:
"""
Packet sanity check to be run on parsed packet to make sure packet's
fields contain sane values.
"""
- if not config.PACKET_SANITY_CHECK:
- return ""
-
if self.sport == 0:
- return "UDP sanity - 'udp_sport' must be greater than 0"
+ raise UdpSanityError(
+ "The 'udp_sport' must be greater than 0",
+ )
if self.dport == 0:
- return "UDP sanity - 'udp_dport' must be greater then 0"
-
- return ""
+ raise UdpSanityError(
+ "The 'udp_dport' must be greater then 0",
+ )
diff --git a/pytcp/protocols/udp/phrx.py b/pytcp/protocols/udp/phrx.py
index 091a70e7..eb1aa021 100755
--- a/pytcp/protocols/udp/phrx.py
+++ b/pytcp/protocols/udp/phrx.py
@@ -41,9 +41,10 @@
from pytcp import config
from pytcp.lib import stack
+from pytcp.lib.errors import PacketValidationError
from pytcp.lib.logger import log
from pytcp.lib.packet import PacketRx
-from pytcp.protocols.icmp4.ps import ICMP4_UNREACHABLE, ICMP4_UNREACHABLE__PORT
+from pytcp.protocols.icmp4.ps import Icmp4Type, Icmp4UnreachableCode
from pytcp.protocols.icmp6.ps import ICMP6_UNREACHABLE, ICMP6_UNREACHABLE__PORT
from pytcp.protocols.udp.fpp import UdpParser
from pytcp.protocols.udp.metadata import UdpMetadata
@@ -59,13 +60,13 @@ def _phrx_udp(self: PacketHandler, packet_rx: PacketRx) -> None:
self.packet_stats_rx.udp__pre_parse += 1
- UdpParser(packet_rx)
-
- if packet_rx.parse_failed:
+ try:
+ UdpParser(packet_rx)
+ except PacketValidationError as error:
self.packet_stats_rx.udp__failed_parse__drop += 1
__debug__ and log(
"udp",
- f"{packet_rx.tracker} - {packet_rx.parse_failed}>",
+ f"{packet_rx.tracker} - {error}>",
)
return
@@ -158,8 +159,8 @@ def _phrx_udp(self: PacketHandler, packet_rx: PacketRx) -> None:
self._phtx_icmp4(
ip4_src=packet_rx.ip4.dst,
ip4_dst=packet_rx.ip4.src,
- icmp4_type=ICMP4_UNREACHABLE,
- icmp4_code=ICMP4_UNREACHABLE__PORT,
+ icmp4_type=Icmp4Type.UNREACHABLE,
+ icmp4_code=Icmp4UnreachableCode.PORT,
icmp4_un_data=packet_rx.ip.packet_copy,
echo_tracker=packet_rx.tracker,
)
diff --git a/pytcp/subsystems/arp_cache.py b/pytcp/subsystems/arp_cache.py
index 8973c46e..6d33954b 100755
--- a/pytcp/subsystems/arp_cache.py
+++ b/pytcp/subsystems/arp_cache.py
@@ -46,7 +46,7 @@
from pytcp.lib.ip4_address import Ip4Address
from pytcp.lib.logger import log
from pytcp.lib.mac_address import MacAddress
-from pytcp.protocols.arp.ps import ARP_OP_REQUEST
+from pytcp.protocols.arp.ps import ArpOperation
class ArpCache:
@@ -181,7 +181,7 @@ def _send_arp_request(self, arp_tpa: Ip4Address) -> None:
stack.packet_handler._phtx_arp(
ether_src=stack.packet_handler.mac_unicast,
ether_dst=MacAddress(0xFFFFFFFFFFFF),
- arp_oper=ARP_OP_REQUEST,
+ arp_oper=ArpOperation.REQUEST,
arp_sha=stack.packet_handler.mac_unicast,
arp_spa=stack.packet_handler.ip4_unicast[0]
if stack.packet_handler.ip4_unicast
diff --git a/pytcp/subsystems/packet_handler.py b/pytcp/subsystems/packet_handler.py
index 742ffc55..a69e51a3 100755
--- a/pytcp/subsystems/packet_handler.py
+++ b/pytcp/subsystems/packet_handler.py
@@ -56,12 +56,13 @@
from pytcp.lib.packet_stats import PacketStatsRx, PacketStatsTx
from pytcp.protocols.arp.phrx import _phrx_arp
from pytcp.protocols.arp.phtx import _phtx_arp
-from pytcp.protocols.arp.ps import ARP_OP_REPLY, ARP_OP_REQUEST
+from pytcp.protocols.arp.ps import ArpOperation
from pytcp.protocols.dhcp4.client import Dhcp4Client
from pytcp.protocols.ether.phrx import _phrx_ether
from pytcp.protocols.ether.phtx import _phtx_ether
from pytcp.protocols.icmp4.phrx import _phrx_icmp4
from pytcp.protocols.icmp4.phtx import _phtx_icmp4
+from pytcp.protocols.icmp4.ps import Icmp4Code, Icmp4Type
from pytcp.protocols.icmp6.fpa import (
Icmp6MulticastAddressRecord,
Icmp6NdOptPI,
@@ -473,7 +474,7 @@ def _send_arp_probe(self, ip4_unicast: Ip4Address) -> None:
self._phtx_arp(
ether_src=self.mac_unicast,
ether_dst=MacAddress(0xFFFFFFFFFFFF),
- arp_oper=ARP_OP_REQUEST,
+ arp_oper=ArpOperation.REQUEST,
arp_sha=self.mac_unicast,
arp_spa=Ip4Address(0),
arp_tha=MacAddress(0),
@@ -488,7 +489,7 @@ def _send_arp_announcement(self, ip4_unicast: Ip4Address) -> None:
self._phtx_arp(
ether_src=self.mac_unicast,
ether_dst=MacAddress(0xFFFFFFFFFFFF),
- arp_oper=ARP_OP_REQUEST,
+ arp_oper=ArpOperation.REQUEST,
arp_sha=self.mac_unicast,
arp_spa=ip4_unicast,
arp_tha=MacAddress(0),
@@ -505,7 +506,7 @@ def _send_gratitous_arp(self, ip4_unicast: Ip4Address) -> None:
self._phtx_arp(
ether_src=self.mac_unicast,
ether_dst=MacAddress(0xFFFFFFFFFFFF),
- arp_oper=ARP_OP_REPLY,
+ arp_oper=ArpOperation.REPLY,
arp_sha=self.mac_unicast,
arp_spa=ip4_unicast,
arp_tha=MacAddress(0),
@@ -684,8 +685,8 @@ def send_icmp4_packet(
self,
local_ip_address: Ip4Address,
remote_ip_address: Ip4Address,
- type: int,
- code: int = 0,
+ type: Icmp4Type,
+ code: Icmp4Code,
ec_id: int | None = None,
ec_seq: int | None = None,
ec_data: bytes | None = None,
diff --git a/tests/integration/packet_flows_rx.py b/tests/integration/packet_flows_rx.py
index 41d0e6e6..cb28956b 100755
--- a/tests/integration/packet_flows_rx.py
+++ b/tests/integration/packet_flows_rx.py
@@ -69,8 +69,6 @@
"LOG_CHANEL": set(),
"IP6_SUPPORT": True,
"IP4_SUPPORT": True,
- "PACKET_INTEGRITY_CHECK": True,
- "PACKET_SANITY_CHECK": True,
"TAP_MTU": 1500,
"UDP_ECHO_NATIVE_DISABLE": False,
}
diff --git a/tests/integration/packet_flows_rx_tx.py b/tests/integration/packet_flows_rx_tx.py
index 9c8d5cf6..4afe5588 100755
--- a/tests/integration/packet_flows_rx_tx.py
+++ b/tests/integration/packet_flows_rx_tx.py
@@ -74,8 +74,6 @@
"LOG_CHANEL": set(),
"IP6_SUPPORT": True,
"IP4_SUPPORT": True,
- "PACKET_INTEGRITY_CHECK": True,
- "PACKET_SANITY_CHECK": True,
"TAP_MTU": 1500,
"UDP_ECHO_NATIVE_DISABLE": False,
}
diff --git a/tests/unit/mock_network.py b/tests/unit/mock_network.py
index 10df3d55..92fc43b3 100755
--- a/tests/unit/mock_network.py
+++ b/tests/unit/mock_network.py
@@ -142,8 +142,6 @@ def __init__(self) -> None:
"LOG_CHANEL": set(),
"IP6_SUPPORT": True,
"IP4_SUPPORT": True,
- "PACKET_INTEGRITY_CHECK": True,
- "PACKET_SANITY_CHECK": True,
"TAP_MTU": 1500,
"UDP_ECHO_NATIVE_DISABLE": False,
"IP4_DEFAULT_TTL": 64,
diff --git a/tests/unit/protocols__arp__fpa.py b/tests/unit/protocols__arp__fpa.py
index 32b9f053..bfa71c0d 100755
--- a/tests/unit/protocols__arp__fpa.py
+++ b/tests/unit/protocols__arp__fpa.py
@@ -36,8 +36,8 @@
from pytcp.lib.ip4_address import Ip4Address
from pytcp.lib.mac_address import MacAddress
from pytcp.protocols.arp.fpa import ArpAssembler
-from pytcp.protocols.arp.ps import ARP_HEADER_LEN, ARP_OP_REPLY, ARP_OP_REQUEST
-from pytcp.protocols.ether.ps import ETHER_TYPE_ARP
+from pytcp.protocols.arp.ps import ARP_HEADER_LEN, ArpOperation
+from pytcp.protocols.ether.ps import EtherType
class TestArpAssembler(TestCase):
@@ -50,7 +50,8 @@ def test_arp_fpa__ethertype(self) -> None:
Make sure the 'ArpAssembler' class has the proper
'ethertype' value assigned.
"""
- self.assertEqual(ArpAssembler.ether_type, ETHER_TYPE_ARP)
+
+ self.assertEqual(ArpAssembler._ether_type, EtherType.ARP)
def test_arp_fpa____init__(self) -> None:
"""
@@ -61,13 +62,13 @@ def test_arp_fpa____init__(self) -> None:
spa=Ip4Address("1.2.3.4"),
tha=MacAddress("66:77:88:99:AA:BB"),
tpa=Ip4Address("5.6.7.8"),
- oper=ARP_OP_REPLY,
+ oper=ArpOperation.REPLY,
)
self.assertEqual(packet._sha, MacAddress("00:11:22:33:44:55"))
self.assertEqual(packet._spa, Ip4Address("1.2.3.4"))
self.assertEqual(packet._tha, MacAddress("66:77:88:99:AA:BB"))
self.assertEqual(packet._tpa, Ip4Address("5.6.7.8"))
- self.assertEqual(packet._oper, ARP_OP_REPLY)
+ self.assertEqual(packet._oper, ArpOperation.REPLY)
def test_arp_fpa____init____defaults(self) -> None:
"""
@@ -78,26 +79,19 @@ def test_arp_fpa____init____defaults(self) -> None:
self.assertEqual(packet._spa, Ip4Address("0.0.0.0"))
self.assertEqual(packet._tha, MacAddress("00:00:00:00:00:00"))
self.assertEqual(packet._tpa, Ip4Address("0.0.0.0"))
- self.assertEqual(packet._oper, ARP_OP_REQUEST)
+ self.assertEqual(packet._oper, ArpOperation.REQUEST)
def test_arp_fpa____init____assert_oper_request(self) -> None:
"""
Test assertion for the request operation.
"""
- ArpAssembler(oper=ARP_OP_REQUEST)
+ ArpAssembler(oper=ArpOperation.REQUEST)
def test_arp_fpa____init____assert_oper_reply(self) -> None:
"""
Test assertion for the request operation.
"""
- ArpAssembler(oper=ARP_OP_REPLY)
-
- def test_arp_fpa____init____assert_oper_unknown(self) -> None:
- """
- Test assertion for the unknown operation.
- """
- with self.assertRaises(AssertionError):
- ArpAssembler(oper=-1)
+ ArpAssembler(oper=ArpOperation.REPLY)
def test_arp_fpa____len__(self) -> None:
"""
@@ -115,7 +109,7 @@ def test_arp_fpa____str____request(self) -> None:
spa=Ip4Address("1.2.3.4"),
tha=MacAddress("66:77:88:99:AA:BB"),
tpa=Ip4Address("5.6.7.8"),
- oper=ARP_OP_REQUEST,
+ oper=ArpOperation.REQUEST,
)
self.assertEqual(
str(packet),
@@ -132,7 +126,7 @@ def test_arp_fpa____str____reply(self) -> None:
spa=Ip4Address("1.2.3.4"),
tha=MacAddress("66:77:88:99:AA:BB"),
tpa=Ip4Address("5.6.7.8"),
- oper=ARP_OP_REPLY,
+ oper=ArpOperation.REPLY,
)
self.assertEqual(
str(packet),
@@ -158,7 +152,7 @@ def test_ether_fpa__assemble(self) -> None:
spa=Ip4Address("1.2.3.4"),
tha=MacAddress("66:77:88:99:AA:BB"),
tpa=Ip4Address("5.6.7.8"),
- oper=ARP_OP_REPLY,
+ oper=ArpOperation.REPLY,
)
frame = memoryview(bytearray(len(packet)))
packet.assemble(frame)
diff --git a/tests/unit/protocols__arp__phtx.py b/tests/unit/protocols__arp__phtx.py
index d2c7d487..b6c88071 100755
--- a/tests/unit/protocols__arp__phtx.py
+++ b/tests/unit/protocols__arp__phtx.py
@@ -36,7 +36,7 @@
from pytcp.lib.packet_stats import PacketStatsTx
from pytcp.lib.tx_status import TxStatus
-from pytcp.protocols.arp.ps import ARP_OP_REPLY, ARP_OP_REQUEST
+from pytcp.protocols.arp.ps import ArpOperation
from pytcp.subsystems.packet_handler import PacketHandler
from tests.unit.mock_network import (
MockNetworkSettings,
@@ -72,7 +72,7 @@ def test_arp_phtx__arp_request(self) -> None:
tx_status = self.packet_handler._phtx_arp(
ether_src=self.mns.stack_mac_address,
ether_dst=self.mns.mac_broadcast,
- arp_oper=ARP_OP_REQUEST,
+ arp_oper=ArpOperation.REQUEST,
arp_sha=self.mns.stack_mac_address,
arp_spa=self.mns.stack_ip4_host.address,
arp_tha=self.mns.mac_unspecified,
@@ -100,7 +100,7 @@ def test_arp_phtx__arp_reply(self) -> None:
tx_status = self.packet_handler._phtx_arp(
ether_src=self.mns.stack_mac_address,
ether_dst=self.mns.host_a_mac_address,
- arp_oper=ARP_OP_REPLY,
+ arp_oper=ArpOperation.REPLY,
arp_sha=self.mns.stack_mac_address,
arp_spa=self.mns.stack_ip4_host.address,
arp_tha=self.mns.host_a_mac_address,
diff --git a/tests/unit/protocols__ether__fpa.py b/tests/unit/protocols__ether__fpa.py
index 3208e2f2..f65b6353 100755
--- a/tests/unit/protocols__ether__fpa.py
+++ b/tests/unit/protocols__ether__fpa.py
@@ -37,7 +37,7 @@
from pytcp.lib.tracker import Tracker
from pytcp.protocols.arp.fpa import ArpAssembler
from pytcp.protocols.ether.fpa import EtherAssembler
-from pytcp.protocols.ether.ps import ETHER_HEADER_LEN, ETHER_TYPE_RAW
+from pytcp.protocols.ether.ps import ETHER_HEADER_LEN, EtherType
from pytcp.protocols.ip4.fpa import Ip4Assembler
from pytcp.protocols.ip6.fpa import Ip6Assembler
from pytcp.protocols.raw.fpa import RawAssembler
@@ -61,7 +61,7 @@ def test_ether_fpa____init__(self) -> None:
self.assertEqual(packet._tracker, packet._carried_packet._tracker)
self.assertEqual(packet._src, MacAddress("00:11:22:33:44:55"))
self.assertEqual(packet._dst, MacAddress("66:77:88:99:AA:BB"))
- self.assertEqual(packet._type, ETHER_TYPE_RAW)
+ self.assertEqual(packet._type, EtherType.RAW)
def test_ether_fpa____init____defaults(self) -> None:
"""
@@ -72,7 +72,7 @@ def test_ether_fpa____init____defaults(self) -> None:
self.assertEqual(packet._tracker, packet._carried_packet._tracker)
self.assertEqual(packet._src, MacAddress("00:00:00:00:00:00"))
self.assertEqual(packet._dst, MacAddress("00:00:00:00:00:00"))
- self.assertEqual(packet._type, ETHER_TYPE_RAW)
+ self.assertEqual(packet._type, EtherType.RAW)
def test_ether_fpa____init____assert_ethertype_arp(self) -> None:
"""
@@ -98,16 +98,6 @@ def test_ether_fpa____init____assert_ethertype_raw(self) -> None:
"""
EtherAssembler(carried_packet=RawAssembler())
- def test_ether_fpa____init____assert_ethertype_unknown(self) -> None:
- """
- Test assertion for carried packet 'ether_type' attribute.
- """
- with self.assertRaises(AssertionError):
- carried_packet_mock = StrictMock()
- carried_packet_mock.ether_type = -1
- carried_packet_mock.tracker = StrictMock(Tracker)
- EtherAssembler(carried_packet=carried_packet_mock) # type: ignore[arg-type]
-
def test_ether_fpa____len__(self) -> None:
"""
Test the '__len__' dunder.
@@ -139,7 +129,7 @@ def test_ether_fpa__tracker_getter(self) -> None:
"""
carried_packet_mock = StrictMock()
- carried_packet_mock.ether_type = ETHER_TYPE_RAW
+ carried_packet_mock.ether_type = EtherType.RAW
carried_packet_mock.tracker = StrictMock(Tracker)
packet = EtherAssembler(carried_packet=carried_packet_mock) # type: ignore[arg-type]
diff --git a/tests/unit/protocols__icmp4__fpa.py b/tests/unit/protocols__icmp4__fpa.py
index f4c309d7..c8e4a0e7 100755
--- a/tests/unit/protocols__icmp4__fpa.py
+++ b/tests/unit/protocols__icmp4__fpa.py
@@ -36,13 +36,13 @@
from pytcp.lib.tracker import Tracker
from pytcp.protocols.icmp4.fpa import Icmp4Assembler
from pytcp.protocols.icmp4.ps import (
- ICMP4_ECHO_REPLY,
ICMP4_ECHO_REPLY_LEN,
- ICMP4_ECHO_REQUEST,
ICMP4_ECHO_REQUEST_LEN,
- ICMP4_UNREACHABLE,
- ICMP4_UNREACHABLE__PORT,
ICMP4_UNREACHABLE_LEN,
+ Icmp4EchoReplyCode,
+ Icmp4EchoRequestCode,
+ Icmp4Type,
+ Icmp4UnreachableCode,
)
from pytcp.protocols.ip4.ps import IP4_PROTO_ICMP4
@@ -63,8 +63,8 @@ def test_icmp4_fpa____init____echo_request(self) -> None:
Test the packet constructor for the 'Echo Request' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=0,
+ type=Icmp4Type.ECHO_REQUEST,
+ code=Icmp4EchoRequestCode.DEFAULT,
ec_id=12345,
ec_seq=54321,
ec_data=b"0123456789ABCDEF",
@@ -79,28 +79,6 @@ def test_icmp4_fpa____init____echo_request(self) -> None:
)
)
- def test_icmp4_fpa____init____echo_request__assert_code__under(
- self,
- ) -> None:
- """
- Test packet constructor for the 'Echo Request' message.
- """
- with self.assertRaises(AssertionError):
- Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=-1,
- )
-
- def test_icmp4_fpa____init____echo_request__assert_code__over(self) -> None:
- """
- Test packet constructor for the 'Echo Request' message.
- """
- with self.assertRaises(AssertionError):
- Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=1,
- )
-
def test_icmp4_fpa____init____echo_request__assert_ec_id__under(
self,
) -> None:
@@ -109,8 +87,8 @@ def test_icmp4_fpa____init____echo_request__assert_ec_id__under(
"""
with self.assertRaises(AssertionError):
Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=0,
+ type=Icmp4Type.ECHO_REQUEST,
+ code=Icmp4EchoRequestCode.DEFAULT,
ec_id=-1,
)
@@ -122,8 +100,8 @@ def test_icmp4_fpa____init____echo_request__assert_ec_id__over(
"""
with self.assertRaises(AssertionError):
Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=0,
+ type=Icmp4Type.ECHO_REQUEST,
+ code=Icmp4EchoRequestCode.DEFAULT,
ec_id=0x10000,
)
@@ -135,8 +113,8 @@ def test_icmp4_fpa____init____echo_request__assert_ec_seq__under(
"""
with self.assertRaises(AssertionError):
Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=0,
+ type=Icmp4Type.ECHO_REQUEST,
+ code=Icmp4EchoRequestCode.DEFAULT,
ec_seq=-1,
)
@@ -148,8 +126,8 @@ def test_icmp4_fpa____init____echo_request__assert_ec_seq__over(
"""
with self.assertRaises(AssertionError):
Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=0,
+ type=Icmp4Type.ECHO_REQUEST,
+ code=Icmp4EchoRequestCode.DEFAULT,
ec_seq=0x10000,
)
@@ -158,8 +136,8 @@ def test_icmp4_fpa____init____unreachable_port(self) -> None:
Test packet constructor for the 'Unreachable Port' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_UNREACHABLE,
- code=ICMP4_UNREACHABLE__PORT,
+ type=Icmp4Type.UNREACHABLE,
+ code=Icmp4UnreachableCode.PORT,
un_data=b"0123456789ABCDEF" * 50,
echo_tracker=Tracker(prefix="TX"),
)
@@ -170,37 +148,13 @@ def test_icmp4_fpa____init____unreachable_port(self) -> None:
)
)
- def test_icmp4_fpa____init____unreachable_port__assert_code__under(
- self,
- ) -> None:
- """
- Test packet constructor for the 'Unreachable Port' message.
- """
- with self.assertRaises(AssertionError):
- Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=ICMP4_UNREACHABLE__PORT - 1,
- )
-
- def test_icmp4_fpa____init____unreachable_port__assert_code__over(
- self,
- ) -> None:
- """
- Test packet constructor for the 'Unreachable Port' message.
- """
- with self.assertRaises(AssertionError):
- Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=ICMP4_UNREACHABLE__PORT + 1,
- )
-
def test_icmp4_fpa____init____echo_reply(self) -> None:
"""
Test packet constructor for the 'Echo Reply' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=0,
+ type=Icmp4Type.ECHO_REPLY,
+ code=Icmp4EchoReplyCode.DEFAULT,
ec_id=12345,
ec_seq=54321,
ec_data=b"0123456789ABCDEF",
@@ -215,34 +169,14 @@ def test_icmp4_fpa____init____echo_reply(self) -> None:
)
)
- def test_icmp4_fpa____init____echo_reply__assert_code__under(self) -> None:
- """
- Test packet constructor for the 'Echo Reply' message.
- """
- with self.assertRaises(AssertionError):
- Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=-1,
- )
-
- def test_icmp4_fpa____init____echo_reply__assert_code__over(self) -> None:
- """
- Test packet constructor for the 'Echo Reply' message.
- """
- with self.assertRaises(AssertionError):
- Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=1,
- )
-
def test_icmp4_fpa____init____echo_reply__assert_ec_id__under(self) -> None:
"""
Test assertion for the 'ec_id' argument.
"""
with self.assertRaises(AssertionError):
Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=0,
+ type=Icmp4Type.ECHO_REPLY,
+ code=Icmp4EchoReplyCode.DEFAULT,
ec_id=-1,
)
@@ -252,8 +186,8 @@ def test_icmp4_fpa____init____echo_reply__assert_ec_id__over(self) -> None:
"""
with self.assertRaises(AssertionError):
Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=0,
+ type=Icmp4Type.ECHO_REPLY,
+ code=Icmp4EchoReplyCode.DEFAULT,
ec_id=0x10000,
)
@@ -265,8 +199,8 @@ def test_icmp4_fpa____init____echo_reply__assert_ec_seq__under(
"""
with self.assertRaises(AssertionError):
Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=0,
+ type=Icmp4Type.ECHO_REPLY,
+ code=Icmp4EchoReplyCode.DEFAULT,
ec_seq=-1,
)
@@ -276,27 +210,18 @@ def test_icmp4_fpa____init____echo_reply__assert_ec_seq__over(self) -> None:
"""
with self.assertRaises(AssertionError):
Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=0,
+ type=Icmp4Type.ECHO_REPLY,
+ code=Icmp4EchoReplyCode.DEFAULT,
ec_seq=0x10000,
)
- def test_icmp4_fpa____init____unknown(self) -> None:
- """
- Test packet constructor for the message with unknown type.
- """
- with self.assertRaises(AssertionError):
- Icmp4Assembler(
- type=255,
- )
-
def test_icmp4_fpa____len____echo_reply(self) -> None:
"""
Test the '__len__()' dunder for the 'Echo Reply' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=0,
+ type=Icmp4Type.ECHO_REPLY,
+ code=Icmp4EchoReplyCode.DEFAULT,
ec_data=b"0123456789ABCDEF",
)
self.assertEqual(len(packet), ICMP4_ECHO_REPLY_LEN + 16)
@@ -306,8 +231,8 @@ def test_icmp4_fpa____len____unreachable_port(self) -> None:
Test the '__len__()' dunder for the 'Unreachable Port' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_UNREACHABLE,
- code=ICMP4_UNREACHABLE__PORT,
+ type=Icmp4Type.UNREACHABLE,
+ code=Icmp4UnreachableCode.PORT,
un_data=b"0123456789ABCDEF",
)
self.assertEqual(len(packet), ICMP4_UNREACHABLE_LEN + 16)
@@ -317,8 +242,8 @@ def test_icmp4_fpa____len____echo_request(self) -> None:
Test the '__len__() dudner for the 'Echo Request' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=0,
+ type=Icmp4Type.ECHO_REQUEST,
+ code=Icmp4EchoRequestCode.DEFAULT,
ec_data=b"0123456789ABCDEF",
)
self.assertEqual(len(packet), ICMP4_ECHO_REQUEST_LEN + 16)
@@ -328,8 +253,8 @@ def test_icmp4_fpa____str____echo_reply(self) -> None:
Test the '__str__()' dunder for the 'Echo Reply' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=0,
+ type=Icmp4Type.ECHO_REPLY,
+ code=Icmp4EchoReplyCode.DEFAULT,
ec_id=12345,
ec_seq=54321,
ec_data=b"0123456789ABCDEF",
@@ -343,8 +268,8 @@ def test_icmp4_fpa____str____unreachable_port(self) -> None:
Test the '__str__() dunder for the 'Unreachable Port' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_UNREACHABLE,
- code=ICMP4_UNREACHABLE__PORT,
+ type=Icmp4Type.UNREACHABLE,
+ code=Icmp4UnreachableCode.PORT,
un_data=b"0123456789ABCDEF",
)
self.assertEqual(str(packet), "ICMPv4 3/3 (unreachable_port), dlen 16")
@@ -354,8 +279,8 @@ def test_icmp4_fpa____str____echo_request(self) -> None:
Test the '__str__()' dunder for the 'Echo Request' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=0,
+ type=Icmp4Type.ECHO_REQUEST,
+ code=Icmp4EchoRequestCode.DEFAULT,
ec_id=12345,
ec_seq=54321,
ec_data=b"0123456789ABCDEF",
@@ -369,7 +294,10 @@ def test_icmp4_fpa__tracker_getter(self) -> None:
"""
Test the '_tracker' attribute getter.
"""
- packet = Icmp4Assembler()
+ packet = Icmp4Assembler(
+ type=Icmp4Type.ECHO_REQUEST,
+ code=Icmp4EchoRequestCode.DEFAULT,
+ )
self.assertTrue(
repr(packet.tracker).startswith("Tracker(serial='TX")
)
@@ -379,8 +307,8 @@ def test_icmp4_fpa__assemble__echo_reply(self) -> None:
Test the 'assemble()' method for the 'Echo Reply' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_ECHO_REPLY,
- code=0,
+ type=Icmp4Type.ECHO_REPLY,
+ code=Icmp4EchoReplyCode.DEFAULT,
ec_id=12345,
ec_seq=54321,
ec_data=b"0123456789ABCDEF",
@@ -394,8 +322,8 @@ def test_icmp4_fpa__asssemble__unreachable_port(self) -> None:
Test the 'assemble()' method for the 'Unreachable Port' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_UNREACHABLE,
- code=ICMP4_UNREACHABLE__PORT,
+ type=Icmp4Type.UNREACHABLE,
+ code=Icmp4UnreachableCode.PORT,
un_data=b"0123456789ABCDEF",
)
frame = memoryview(bytearray(len(packet)))
@@ -409,8 +337,8 @@ def test_icmp4_fpa__assemble__echo_request(self) -> None:
Test the 'assemble()' method for the 'Echo Request' message.
"""
packet = Icmp4Assembler(
- type=ICMP4_ECHO_REQUEST,
- code=0,
+ type=Icmp4Type.ECHO_REQUEST,
+ code=Icmp4EchoRequestCode.DEFAULT,
ec_id=12345,
ec_seq=54321,
ec_data=b"0123456789ABCDEF",
diff --git a/tests/unit/protocols__icmp4__phtx.py b/tests/unit/protocols__icmp4__phtx.py
index 06792888..c3344616 100755
--- a/tests/unit/protocols__icmp4__phtx.py
+++ b/tests/unit/protocols__icmp4__phtx.py
@@ -37,10 +37,10 @@
from pytcp.lib.packet_stats import PacketStatsTx
from pytcp.lib.tx_status import TxStatus
from pytcp.protocols.icmp4.ps import (
- ICMP4_ECHO_REPLY,
- ICMP4_ECHO_REQUEST,
- ICMP4_UNREACHABLE,
- ICMP4_UNREACHABLE__PORT,
+ Icmp4EchoReplyCode,
+ Icmp4EchoRequestCode,
+ Icmp4Type,
+ Icmp4UnreachableCode,
)
from pytcp.subsystems.packet_handler import PacketHandler
from tests.unit.mock_network import (
@@ -78,7 +78,8 @@ def test_icmp4_phtx__ip4_icmp4_echo_request(self) -> None:
tx_status = self.packet_handler._phtx_icmp4(
ip4_src=self.mns.stack_ip4_host.address,
ip4_dst=self.mns.host_a_ip4_address,
- icmp4_type=ICMP4_ECHO_REQUEST,
+ icmp4_type=Icmp4Type.ECHO_REQUEST,
+ icmp4_code=Icmp4EchoRequestCode.DEFAULT,
icmp4_ec_id=12345,
icmp4_ec_seq=54320,
icmp4_ec_data=b"0123456789ABCDEF" * 20,
@@ -108,7 +109,8 @@ def test_icmp4_phtx__ip4_icmp4_echo_reply(self) -> None:
tx_status = self.packet_handler._phtx_icmp4(
ip4_src=self.mns.stack_ip4_host.address,
ip4_dst=self.mns.host_a_ip4_address,
- icmp4_type=ICMP4_ECHO_REPLY,
+ icmp4_type=Icmp4Type.ECHO_REPLY,
+ icmp4_code=Icmp4EchoReplyCode.DEFAULT,
icmp4_ec_id=12345,
icmp4_ec_seq=54320,
icmp4_ec_data=b"0123456789ABCDEF" * 20,
@@ -138,8 +140,8 @@ def test_icmp4_phtx__ip4_icmp4_unreachable_port(self) -> None:
tx_status = self.packet_handler._phtx_icmp4(
ip4_src=self.mns.stack_ip4_host.address,
ip4_dst=self.mns.host_a_ip4_address,
- icmp4_type=ICMP4_UNREACHABLE,
- icmp4_code=ICMP4_UNREACHABLE__PORT,
+ icmp4_type=Icmp4Type.UNREACHABLE,
+ icmp4_code=Icmp4UnreachableCode.PORT,
icmp4_un_data=b"0123456789ABCDEF" * 100,
)
self.assertEqual(tx_status, TxStatus.PASSED__ETHER__TO_TX_RING)
diff --git a/tests/unit/protocols__ip4__fpa.py b/tests/unit/protocols__ip4__fpa.py
index 137eb220..226771b6 100755
--- a/tests/unit/protocols__ip4__fpa.py
+++ b/tests/unit/protocols__ip4__fpa.py
@@ -36,8 +36,9 @@
from pytcp.config import IP4_DEFAULT_TTL
from pytcp.lib.ip4_address import Ip4Address
from pytcp.lib.tracker import Tracker
-from pytcp.protocols.ether.ps import ETHER_TYPE_IP4
+from pytcp.protocols.ether.ps import EtherType
from pytcp.protocols.icmp4.fpa import Icmp4Assembler
+from pytcp.protocols.icmp4.ps import Icmp4EchoReplyCode, Icmp4Type
from pytcp.protocols.ip4.fpa import (
Ip4Assembler,
Ip4FragAssembler,
@@ -68,7 +69,7 @@ def test_ip4_fpa__ethertype(self) -> None:
Make sure the 'Ip4Assembler' class has the proper
'ethertype' value assigned.
"""
- self.assertEqual(Ip4Assembler.ether_type, ETHER_TYPE_IP4)
+ self.assertEqual(Ip4Assembler.ether_type, EtherType.IP4)
def test_ip4_fpa____init__(self) -> None:
"""
@@ -203,7 +204,12 @@ def test_ip4_fpa____init____assert_proto_icmp4(self) -> None:
"""
Test assertion for the carried packet 'ip4_proto' attribute.
"""
- Ip4Assembler(carried_packet=Icmp4Assembler())
+ Ip4Assembler(
+ carried_packet=Icmp4Assembler(
+ type=Icmp4Type.ECHO_REPLY,
+ code=Icmp4EchoReplyCode.DEFAULT,
+ )
+ )
def test_ip4_fpa____init____assert_proto_raw(self) -> None:
"""
@@ -417,7 +423,7 @@ def test_ip4_frag_fpa__ethertype(self) -> None:
"""
Test the 'ethertype' property of the 'Ip4FragAssembler' class.
"""
- self.assertEqual(Ip4Assembler.ether_type, ETHER_TYPE_IP4)
+ self.assertEqual(Ip4Assembler.ether_type, EtherType.IP4)
def test_ip4_frag_fpa____init__(self) -> None:
"""
diff --git a/tests/unit/protocols__ip6__fpa.py b/tests/unit/protocols__ip6__fpa.py
index 2b7f482a..41170a09 100755
--- a/tests/unit/protocols__ip6__fpa.py
+++ b/tests/unit/protocols__ip6__fpa.py
@@ -36,7 +36,7 @@
from pytcp.config import IP6_DEFAULT_HOP
from pytcp.lib.ip6_address import Ip6Address
from pytcp.lib.tracker import Tracker
-from pytcp.protocols.ether.ps import ETHER_TYPE_IP6
+from pytcp.protocols.ether.ps import EtherType
from pytcp.protocols.icmp6.fpa import Icmp6Assembler
from pytcp.protocols.ip6.fpa import Ip6Assembler
from pytcp.protocols.ip6.ps import IP6_HEADER_LEN, IP6_NEXT_RAW
@@ -55,7 +55,7 @@ def test_ip6_fpa__ethertype(self) -> None:
Make sure the 'Ip6Assembler' class has the proper
'ethertype' value assigned.
"""
- self.assertEqual(Ip6Assembler.ether_type, ETHER_TYPE_IP6)
+ self.assertEqual(Ip6Assembler.ether_type, EtherType.IP6)
def test_ip6_fpa____init__(self) -> None:
"""