From 238c24f9be6bdfc762535ed8f4c0b311611017c9 Mon Sep 17 00:00:00 2001 From: Sebastian Majewski Date: Thu, 11 Jul 2024 20:18:50 -0500 Subject: [PATCH] Converted TCP to protocol header class --- pytcp/protocols/tcp/fpa.py | 40 ++++---- pytcp/protocols/tcp/fpp.py | 40 +++----- pytcp/protocols/tcp/ps.py | 187 ++++++++++++++++++++++--------------- 3 files changed, 150 insertions(+), 117 deletions(-) diff --git a/pytcp/protocols/tcp/fpa.py b/pytcp/protocols/tcp/fpa.py index cd61c894..e3ef73c5 100755 --- a/pytcp/protocols/tcp/fpa.py +++ b/pytcp/protocols/tcp/fpa.py @@ -45,6 +45,7 @@ from pytcp.protocols.tcp.ps import ( TCP_HEADER_LEN, Tcp, + TcpHeader, TcpOption, TcpOptionEol, TcpOptionMss, @@ -76,6 +77,7 @@ def __init__( tcp__flag_syn: bool = False, tcp__flag_fin: bool = False, tcp__win: int = 0, + tcp_cksum: int = 0, tcp__urg: int = 0, tcp__options: list[TcpOption] | None = None, tcp__data: bytes | None = None, @@ -94,31 +96,37 @@ def __init__( self._tracker: Tracker = Tracker(prefix="TX", echo_tracker=echo_tracker) - self._sport = tcp__sport - self._dport = tcp__dport - self._seq = tcp__seq - self._ack = tcp__ack - self._flag_ns = tcp__flag_ns - self._flag_cwr = tcp__flag_cwr - self._flag_ece = tcp__flag_ece - self._flag_urg = tcp__flag_urg - self._flag_ack = tcp__flag_ack - self._flag_psh = tcp__flag_psh - self._flag_rst = tcp__flag_rst - self._flag_syn = tcp__flag_syn - self._flag_fin = tcp__flag_fin - self._win = tcp__win - self._urg = tcp__urg + self._data = b"" if tcp__data is None else tcp__data + self._options: list[TcpOption] = ( [] if tcp__options is None else tcp__options ) - self._data = b"" if tcp__data is None else tcp__data self._olen = sum(len(option) for option in self._options) self._hlen = TCP_HEADER_LEN + self._olen self._dlen = len(self._data) self._plen = self._hlen + self._dlen + self._header = TcpHeader( + sport=tcp__sport, + dport=tcp__dport, + seq=tcp__seq, + ack=tcp__ack, + hlen=self._hlen, + flag_ns=tcp__flag_ns, + flag_cwr=tcp__flag_cwr, + flag_ece=tcp__flag_ece, + flag_urg=tcp__flag_urg, + flag_ack=tcp__flag_ack, + flag_psh=tcp__flag_psh, + flag_rst=tcp__flag_rst, + flag_syn=tcp__flag_syn, + flag_fin=tcp__flag_fin, + win=tcp__win, + cksum=tcp_cksum, + urg=tcp__urg, + ) + assert self._hlen % 4 == 0, ( f"TCP header len {self._hlen!r} is not multiplication of 4 bytes, " f"check options: {self._options!r}" diff --git a/pytcp/protocols/tcp/fpp.py b/pytcp/protocols/tcp/fpp.py index a93f24ae..ae83525a 100755 --- a/pytcp/protocols/tcp/fpp.py +++ b/pytcp/protocols/tcp/fpp.py @@ -49,6 +49,7 @@ TCP_HEADER_LEN, TCP_OPTION_LEN__NOP, Tcp, + TcpHeader, TcpOption, TcpOptionEol, TcpOptionMss, @@ -69,7 +70,7 @@ class TcpIntegrityError(PacketIntegrityError): Exception raised when TCP packet integrity check fails. """ - def __init__(self, message: str): + def __init__(self, /, message: str): super().__init__("[TCP] " + message) @@ -78,7 +79,7 @@ class TcpSanityError(PacketSanityError): Exception raised when TCP packet sanity check fails. """ - def __init__(self, message: str): + def __init__(self, /, message: str): super().__init__("[TCP] " + message) @@ -167,25 +168,10 @@ def _parse(self) -> None: Parse TCP packet. """ - self._sport = struct.unpack("!H", self._frame[0:2])[0] - self._dport = struct.unpack("!H", self._frame[2:4])[0] - self._seq = struct.unpack("!L", self._frame[4:8])[0] - self._ack = struct.unpack("!L", self._frame[8:12])[0] - self._hlen = (self._frame[12] & 0b11110000) >> 2 - self._flag_ns = bool(self._frame[12] & 0b00000001) - self._flag_cwr = bool(self._frame[13] & 0b10000000) - self._flag_ece = bool(self._frame[13] & 0b01000000) - self._flag_urg = bool(self._frame[13] & 0b00100000) - self._flag_ack = bool(self._frame[13] & 0b00010000) - self._flag_psh = bool(self._frame[13] & 0b00001000) - self._flag_rst = bool(self._frame[13] & 0b00000100) - self._flag_syn = bool(self._frame[13] & 0b00000010) - self._flag_fin = bool(self._frame[13] & 0b00000001) - self._win = struct.unpack("!H", self._frame[14:16])[0] - self._cksum = struct.unpack("!H", self._frame[16:18])[0] - self._urg = struct.unpack("!H", self._frame[18:20])[0] + self._header = TcpHeader.from_frame(self._frame) self._plen = self._ip__dlen + self._hlen = self._header.hlen self._olen = self._hlen - TCP_HEADER_LEN self._dlen = self._plen - self._hlen @@ -236,42 +222,42 @@ def _validate_sanity(self) -> None: Check sanity of incoming packet after it has been parsed. """ - if self._sport == 0: + if self._header.sport == 0: raise TcpSanityError( "The 'sport' must be greater than 0.", ) - if self._dport == 0: + if self._header.dport == 0: raise TcpSanityError( "The 'dport' must be greater than 0.", ) - if self._flag_syn and self._flag_fin: + if self._header.flag_syn and self._header.flag_fin: raise TcpSanityError( "The 'flag_syn' and 'flag_fin' must not be set simultaneously.", ) - if self._flag_syn and self._flag_rst: + if self._header.flag_syn and self._header.flag_rst: raise TcpSanityError( "The 'flag_syn' and 'flag_rst' must not set simultaneously.", ) - if self._flag_fin and self._flag_rst: + if self._header.flag_fin and self._header.flag_rst: raise TcpSanityError( "The 'flag_fin' and 'flag_rst' must not be set simultaneously.", ) - if self._flag_fin and not self._flag_ack: + if self._header.flag_fin and not self._header.flag_ack: raise TcpSanityError( "The 'flag_ack' must be set when 'flag_fin' is set.", ) - if self._ack and not self._flag_ack: + if self._header.ack and not self._header.flag_ack: raise TcpSanityError( "The 'flag_ack' must be set when 'ack' is not 0.", ) - if self._urg and not self._flag_urg: + if self._header.urg and not self._header.flag_urg: raise TcpSanityError( "The 'flag_urg' must be set when 'urg' is not 0.", ) diff --git a/pytcp/protocols/tcp/ps.py b/pytcp/protocols/tcp/ps.py index 47e183b2..be179f4d 100755 --- a/pytcp/protocols/tcp/ps.py +++ b/pytcp/protocols/tcp/ps.py @@ -35,6 +35,7 @@ from __future__ import annotations import struct +from dataclasses import dataclass from typing import override from pytcp.lib.enum import ProtoEnum @@ -59,12 +60,67 @@ # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +@dataclass +class TcpHeader: + """ + Data class representing the TCP header. + """ + + sport: int + dport: int + seq: int + ack: int + hlen: int + flag_ns: bool + flag_cwr: bool + flag_ece: bool + flag_urg: bool + flag_ack: bool + flag_psh: bool + flag_rst: bool + flag_syn: bool + flag_fin: bool + win: int + cksum: int + urg: int + + @staticmethod + def from_frame(frame: bytes) -> TcpHeader: + """ + Populate the header from the raw frame. + """ + + return TcpHeader( + sport=struct.unpack("!H", frame[0:2])[0], + dport=struct.unpack("!H", frame[2:4])[0], + seq=struct.unpack("!L", frame[4:8])[0], + ack=struct.unpack("!L", frame[8:12])[0], + hlen=(frame[12] & 0b11110000) >> 2, + flag_ns=bool(frame[12] & 0b00000001), + flag_cwr=bool(frame[13] & 0b10000000), + flag_ece=bool(frame[13] & 0b01000000), + flag_urg=bool(frame[13] & 0b00100000), + flag_ack=bool(frame[13] & 0b00010000), + flag_psh=bool(frame[13] & 0b00001000), + flag_rst=bool(frame[13] & 0b00000100), + flag_syn=bool(frame[13] & 0b00000010), + flag_fin=bool(frame[13] & 0b00000001), + win=struct.unpack("!H", frame[14:16])[0], + cksum=struct.unpack("!H", frame[16:18])[0], + urg=struct.unpack("!H", frame[18:20])[0], + ) + + TCP_HEADER_LEN = 20 TCP_DEFAULT_MSS = 536 class TcpOptionType(ProtoEnum): + """ + TCP option types. + """ + EOL = 0 NOP = 1 MSS = 2 @@ -93,25 +149,10 @@ class Tcp(Proto): _ip6_next = Ip6Next.TCP _ip4_proto = Ip4Proto.TCP - _sport: int - _dport: int - _seq: int - _ack: int - _res: int - _flag_ns: bool - _flag_cwr: bool - _flag_ece: bool - _flag_urg: bool - _flag_ack: bool - _flag_psh: bool - _flag_rst: bool - _flag_syn: bool - _flag_fin: bool - _win: int - _cksum: int - _urg: int + _header: TcpHeader _options: list[TcpOption] _data: bytes + _plen: int _hlen: int _olen: int @@ -124,13 +165,13 @@ def __str__(self) -> str: """ log = ( - f"TCP {self._sport} > {self._dport}, " - f"{'N' if self._flag_ns else ''}{'C' if self._flag_cwr else ''}" - f"{'E' if self._flag_ece else ''}{'U' if self._flag_urg else ''}" - f"{'A' if self._flag_ack else ''}{'P' if self._flag_psh else ''}" - f"{'R' if self._flag_rst else ''}{'S' if self._flag_syn else ''}" - f"{'F' if self._flag_fin else ''}, seq {self._seq}, " - f"ack {self._ack}, win {self._win}, dlen {len(self._data)}" + f"TCP {self._header.sport} > {self._header.dport}, " + f"{'N' if self._header.flag_ns else ''}{'C' if self._header.flag_cwr else ''}" + f"{'E' if self._header.flag_ece else ''}{'U' if self._header.flag_urg else ''}" + f"{'A' if self._header.flag_ack else ''}{'P' if self._header.flag_psh else ''}" + f"{'R' if self._header.flag_rst else ''}{'S' if self._header.flag_syn else ''}" + f"{'F' if self._header.flag_fin else ''}, seq {self._header.seq}, " + f"ack {self._header.ack}, win {self._header.win}, dlen {len(self._data)}" ) for option in self._options: @@ -146,26 +187,24 @@ def __repr__(self) -> str: return ( "Tcp(" - f"sport={self._sport!r}, " - f"dport={self._dport!r}, " - f"seq={self._seq!r}, " - f"ack={self._ack!r}, " - f"flag_ns={self._flag_ns!r}, " - f"flag_cwr={self._flag_cwr!r}, " - f"flag_ece={self._flag_ece!r}, " - f"flag_urg={self._flag_urg!r}, " - f"flag_ack={self._flag_ack!r}, " - f"flag_psh={self._flag_psh!r}, " - f"flag_rst={self._flag_rst!r}, " - f"flag_syn={self._flag_syn!r}, " - f"flag_fin={self._flag_fin!r}, " - f"window={self._win!r}, " - f"urg={self._urg!r}, " + f"sport={self._header.sport!r}, " + f"dport={self._header.dport!r}, " + f"seq={self._header.seq!r}, " + f"ack={self._header.ack!r}, " + f"hlen={self._header.hlen!r}, " + f"flag_ns={self._header.flag_ns!r}, " + f"flag_cwr={self._header.flag_cwr!r}, " + f"flag_ece={self._header.flag_ece!r}, " + f"flag_urg={self._header.flag_urg!r}, " + f"flag_ack={self._header.flag_ack!r}, " + f"flag_psh={self._header.flag_psh!r}, " + f"flag_rst={self._header.flag_rst!r}, " + f"flag_syn={self._header.flag_syn!r}, " + f"flag_fin={self._header.flag_fin!r}, " + f"window={self._header.win!r}, " + f"cksum={self._header.cksum!r}," + f"urg={self._header.urg!r}, " f"options={self._options!r})" - f"plen={self._hlen!r}, " - f"hlen={self._hlen!r}, " - f"olen={self._hlen!r}, " - f"dlen={self._hlen!r}, " ) @override @@ -178,22 +217,22 @@ def __bytes__(self) -> bytes: return struct.pack( f"! HH L L BBH HH {len(raw_options)}s {len(self._data)}s", - self._sport, - self._dport, - self._seq, - self._ack, - self._hlen << 2 | self._flag_ns, - self._flag_cwr << 7 - | self._flag_ece << 6 - | self._flag_urg << 5 - | self._flag_ack << 4 - | self._flag_psh << 3 - | self._flag_rst << 2 - | self._flag_syn << 1 - | self._flag_fin, - self._win, + self._header.sport, + self._header.dport, + self._header.seq, + self._header.ack, + self._hlen << 2 | self._header.flag_ns, + self._header.flag_cwr << 7 + | self._header.flag_ece << 6 + | self._header.flag_urg << 5 + | self._header.flag_ack << 4 + | self._header.flag_psh << 3 + | self._header.flag_rst << 2 + | self._header.flag_syn << 1 + | self._header.flag_fin, + self._header.win, 0, - self._urg, + self._header.urg, raw_options, self._data, ) @@ -220,7 +259,7 @@ def sport(self) -> int: Get the source port. """ - return self._sport + return self._header.sport @property def dport(self) -> int: @@ -228,7 +267,7 @@ def dport(self) -> int: Get the destination port. """ - return self._dport + return self._header.dport @property def seq(self) -> int: @@ -236,7 +275,7 @@ def seq(self) -> int: Get the sequence number. """ - return self._seq + return self._header.seq @property def ack(self) -> int: @@ -244,7 +283,7 @@ def ack(self) -> int: Get the acknowledgment number. """ - return self._ack + return self._header.ack @property def flag_ns(self) -> bool: @@ -252,7 +291,7 @@ def flag_ns(self) -> bool: Get the NS flag. """ - return self._flag_ns + return self._header.flag_ns @property def flag_cwr(self) -> bool: @@ -260,7 +299,7 @@ def flag_cwr(self) -> bool: Get the CWR flag. """ - return self._flag_cwr + return self._header.flag_cwr @property def flag_ece(self) -> bool: @@ -268,7 +307,7 @@ def flag_ece(self) -> bool: Get the ECE flag. """ - return self._flag_ece + return self._header.flag_ece @property def flag_urg(self) -> bool: @@ -276,7 +315,7 @@ def flag_urg(self) -> bool: Get the URG flag. """ - return self._flag_urg + return self._header.flag_urg @property def flag_ack(self) -> bool: @@ -284,7 +323,7 @@ def flag_ack(self) -> bool: Get the ACK flag. """ - return self._flag_ack + return self._header.flag_ack @property def flag_psh(self) -> bool: @@ -292,7 +331,7 @@ def flag_psh(self) -> bool: Get the PSH flag. """ - return self._flag_psh + return self._header.flag_psh @property def flag_rst(self) -> bool: @@ -300,7 +339,7 @@ def flag_rst(self) -> bool: Get the RST flag. """ - return self._flag_rst + return self._header.flag_rst @property def flag_syn(self) -> bool: @@ -308,7 +347,7 @@ def flag_syn(self) -> bool: Get the SYN flag. """ - return self._flag_syn + return self._header.flag_syn @property def flag_fin(self) -> bool: @@ -316,7 +355,7 @@ def flag_fin(self) -> bool: Get the FIN flag. """ - return self._flag_fin + return self._header.flag_fin @property def win(self) -> int: @@ -324,7 +363,7 @@ def win(self) -> int: Get the window size. """ - return self._win + return self._header.win @property def cksum(self) -> int: @@ -332,7 +371,7 @@ def cksum(self) -> int: Get the checksum. """ - return self._cksum + return self._header.cksum @property def urg(self) -> int: @@ -340,7 +379,7 @@ def urg(self) -> int: Get the urgent pointer. """ - return self._urg + return self._header.urg @property def options(self) -> list[TcpOption]: