diff --git a/pytcp/lib/packet_stats.py b/pytcp/lib/packet_stats.py index e81847e2..cee8cc56 100755 --- a/pytcp/lib/packet_stats.py +++ b/pytcp/lib/packet_stats.py @@ -134,6 +134,8 @@ class PacketStatsRx: tcp__no_socket_match__rst__drop: int = 0 tcp__no_socket_match__respond_rst: int = 0 + raw__socket_match: int = 0 + @dataclass class PacketStatsTx: diff --git a/pytcp/socket/raw__metadata.py b/pytcp/socket/raw__metadata.py index 91b19ac1..ad6711c5 100755 --- a/pytcp/socket/raw__metadata.py +++ b/pytcp/socket/raw__metadata.py @@ -36,10 +36,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from pytcp.protocols.enums import IpProto from pytcp.socket.socket import AddressFamily, SocketType +from pytcp.socket.socket_id import SocketId if TYPE_CHECKING: from net_addr import IpAddress @@ -62,13 +63,13 @@ class RawMetadata: tracker: Tracker | None = None @property - def socket_ids(self) -> list[tuple[Any, ...]]: + def socket_ids(self) -> list[SocketId]: """ Get list of the listening socket IDs that match the metadata. """ return [ - ( + SocketId( AddressFamily.from_ver(self.ip__ver), SocketType.RAW, self.ip__local_address.unspecified, diff --git a/pytcp/stack/packet_handler/packet_handler__ip4__rx.py b/pytcp/stack/packet_handler/packet_handler__ip4__rx.py index a7c3cab3..5c758292 100755 --- a/pytcp/stack/packet_handler/packet_handler__ip4__rx.py +++ b/pytcp/stack/packet_handler/packet_handler__ip4__rx.py @@ -38,7 +38,7 @@ import struct from abc import ABC from time import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from pytcp import stack from pytcp.lib.inet_cksum import inet_cksum @@ -48,6 +48,8 @@ from pytcp.protocols.errors import PacketValidationError from pytcp.protocols.ip4.ip4__header import IP4__HEADER__LEN from pytcp.protocols.ip4.ip4__parser import Ip4Parser +from pytcp.socket.raw__metadata import RawMetadata +from pytcp.socket.raw__socket import RawSocket class PacketHandlerIp4Rx(ABC): @@ -134,6 +136,29 @@ def _phrx_ip4(self, packet_rx: PacketRx, /) -> None: packet_rx = defragmented_packet_rx self.packet_stats_rx.ip4__defrag += 1 + # Create RawMetadata object and try to find matching RAW socket + packet_rx_md = RawMetadata( + ip__ver=packet_rx.ip.ver, + ip__local_address=packet_rx.ip.dst, + ip__remote_address=packet_rx.ip.src, + ip__proto=packet_rx.ip4.proto, + raw__data=bytes( + packet_rx.ip4.payload_bytes + ), # memoryview: conversion for end-user interface + tracker=packet_rx.tracker, + ) + + for socket_id in packet_rx_md.socket_ids: + if socket := cast(RawSocket, stack.sockets.get(socket_id, None)): + self.packet_stats_rx.raw__socket_match += 1 + __debug__ and log( + "ip4", + f"{packet_rx_md.tracker} - Found matching listening " + f"socket [{socket}]", + ) + socket.process_raw_packet(packet_rx_md) + return + match packet_rx.ip4.proto: case IpProto.ICMP4: self._phrx_icmp4(packet_rx) diff --git a/pytcp/stack/packet_handler/packet_handler__ip6__rx.py b/pytcp/stack/packet_handler/packet_handler__ip6__rx.py index bc235793..e8e69d02 100755 --- a/pytcp/stack/packet_handler/packet_handler__ip6__rx.py +++ b/pytcp/stack/packet_handler/packet_handler__ip6__rx.py @@ -36,13 +36,16 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast +from pytcp import stack from pytcp.lib.logger import log from pytcp.lib.packet import PacketRx from pytcp.protocols.enums import IpProto from pytcp.protocols.errors import PacketValidationError from pytcp.protocols.ip6.ip6__parser import Ip6Parser +from pytcp.socket.raw__metadata import RawMetadata +from pytcp.socket.raw__socket import RawSocket class PacketHandlerIp6Rx(ABC): @@ -105,6 +108,29 @@ def _phrx_ip6(self, packet_rx: PacketRx, /) -> None: if packet_rx.ip6.dst in self.ip6_multicast: self.packet_stats_rx.ip6__dst_multicast += 1 + # Create RawMetadata object and try to find matching RAW socket + packet_rx_md = RawMetadata( + ip__ver=packet_rx.ip.ver, + ip__local_address=packet_rx.ip.dst, + ip__remote_address=packet_rx.ip.src, + ip__proto=packet_rx.ip6.next, + raw__data=bytes( + packet_rx.ip6.payload_bytes + ), # memoryview: conversion for end-user interface + tracker=packet_rx.tracker, + ) + + for socket_id in packet_rx_md.socket_ids: + if socket := cast(RawSocket, stack.sockets.get(socket_id, None)): + self.packet_stats_rx.raw__socket_match += 1 + __debug__ and log( + "ip6", + f"{packet_rx_md.tracker} - Found matching listening " + f"socket [{socket}]", + ) + socket.process_raw_packet(packet_rx_md) + return + match packet_rx.ip6.next: case IpProto.IP6_FRAG: self._phrx_ip6_frag(packet_rx)