diff --git a/flow/record/fieldtypes/net/ip.py b/flow/record/fieldtypes/net/ip.py index 0eda412..73a9371 100644 --- a/flow/record/fieldtypes/net/ip.py +++ b/flow/record/fieldtypes/net/ip.py @@ -1,38 +1,54 @@ -from ipaddress import ip_address, ip_network +from __future__ import annotations + +from ipaddress import ( + IPv4Address, + IPv4Network, + IPv6Address, + IPv6Network, + ip_address, + ip_network, +) +from typing import Union from flow.record.base import FieldType from flow.record.fieldtypes import defang +_IPNetwork = Union[IPv4Network, IPv6Network] +_IPAddress = Union[IPv4Address, IPv6Address] + class ipaddress(FieldType): val = None _type = "net.ipaddress" - def __init__(self, addr): + def __init__(self, addr: str | int | bytes): self.val = ip_address(addr) - def __eq__(self, b): + def __eq__(self, b: str | int | bytes | _IPAddress) -> bool: try: return self.val == ip_address(b) except ValueError: return False - def __str__(self): + def __hash__(self) -> int: + return hash(self.val) + + def __str__(self) -> str: return str(self.val) - def __repr__(self): - return "{}({!r})".format(self._type, str(self)) + def __repr__(self) -> str: + return f"{self._type}({str(self)!r})" - def __format__(self, spec): + def __format__(self, spec: str) -> str: if spec == "defang": return defang(str(self)) return str.__format__(str(self), spec) - def _pack(self): + def _pack(self) -> int: return int(self.val) @staticmethod - def _unpack(data): + def _unpack(data: int) -> ipaddress: return ipaddress(data) @@ -40,17 +56,20 @@ class ipnetwork(FieldType): val = None _type = "net.ipnetwork" - def __init__(self, addr): + def __init__(self, addr: str | int | bytes): self.val = ip_network(addr) - def __eq__(self, b): + def __eq__(self, b: str | int | bytes | _IPNetwork) -> bool: try: return self.val == ip_network(b) except ValueError: return False + def __hash__(self) -> int: + return hash(self.val) + @staticmethod - def _is_subnet_of(a, b): + def _is_subnet_of(a: _IPNetwork, b: _IPNetwork) -> bool: try: # Always false if one is v4 and the other is v6. if a._version != b._version: @@ -59,23 +78,23 @@ def _is_subnet_of(a, b): except AttributeError: raise TypeError("Unable to test subnet containment " "between {} and {}".format(a, b)) - def __contains__(self, b): + def __contains__(self, b: str | int | bytes | _IPAddress) -> bool: try: return self._is_subnet_of(ip_network(b), self.val) except (ValueError, TypeError): return False - def __str__(self): + def __str__(self) -> str: return str(self.val) - def __repr__(self): - return "{}({!r})".format(self._type, str(self)) + def __repr__(self) -> str: + return f"{self._type}({str(self)!r})" - def _pack(self): + def _pack(self) -> str: return self.val.compressed @staticmethod - def _unpack(data): + def _unpack(data: str) -> ipnetwork: return ipnetwork(data) diff --git a/tests/test_fieldtype_ip.py b/tests/test_fieldtype_ip.py index a4c389f..b43984c 100644 --- a/tests/test_fieldtype_ip.py +++ b/tests/test_fieldtype_ip.py @@ -48,12 +48,19 @@ def test_record_ipaddress(): assert TestRecord("0.0.0.0").ip == "0.0.0.0" assert TestRecord("192.168.0.1").ip == "192.168.0.1" assert TestRecord("255.255.255.255").ip == "255.255.255.255" + assert hash(TestRecord("192.168.0.1").ip) == hash(net.ipaddress("192.168.0.1")) # ipv6 assert TestRecord("::1").ip == "::1" assert TestRecord("2001:4860:4860::8888").ip == "2001:4860:4860::8888" assert TestRecord("2001:4860:4860::4444").ip == "2001:4860:4860::4444" + # Test whether it functions in a set + data = {TestRecord(ip).ip for ip in ["192.168.0.1", "192.168.0.1", "::1", "::1"]} + assert len(data) == 2 + assert net.ipaddress("::1") in data + assert net.ipaddress("192.168.0.1") in data + # instantiate from different types assert TestRecord(1).ip == "0.0.0.1" assert TestRecord(0x7F0000FF).ip == "127.0.0.255" @@ -90,6 +97,7 @@ def test_record_ipnetwork(): assert "192.168.1.1" not in r.subnet assert isinstance(r.subnet, net.ipnetwork) assert repr(r.subnet) == "net.ipnetwork('192.168.0.0/24')" + assert hash(r.subnet) == hash(net.ipnetwork("192.168.0.0/24")) r = TestRecord("192.168.1.1/32") assert r.subnet == "192.168.1.1" @@ -111,6 +119,13 @@ def test_record_ipnetwork(): assert "64:ff9b::0.0.0.0" in r.subnet assert "64:ff9b::255.255.255.255" in r.subnet + # Test whether it functions in a set + data = {TestRecord(x).subnet for x in ["192.168.0.0/24", "192.168.0.0/24", "::1", "::1"]} + assert len(data) == 2 + assert net.ipnetwork("::1") in data + assert net.ipnetwork("192.168.0.0/24") in data + assert "::1" not in data + @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) def test_selector_ipaddress(PSelector):