diff --git a/anta/custom_types.py b/anta/custom_types.py index c1e1f6428..b4e1113e4 100644 --- a/anta/custom_types.py +++ b/anta/custom_types.py @@ -4,9 +4,9 @@ """Module that provides predefined types for AntaTest.Input instances.""" import re -from typing import Annotated, Literal +from typing import Annotated, Literal, Self, get_args -from pydantic import Field +from pydantic import BaseModel, Field, model_validator from pydantic.functional_validators import AfterValidator, BeforeValidator # Regular Expression definition @@ -204,3 +204,36 @@ def validate_regex(value: str) -> str: ] BgpUpdateError = Literal["inUpdErrWithdraw", "inUpdErrIgnore", "inUpdErrDisableAfiSafi", "disabledAfiSafi", "lastUpdErrTime"] BfdProtocol = Literal["bgp", "isis", "lag", "ospf", "ospfv3", "pim", "route-input", "static-bfd", "static-route", "vrrp", "vxlan"] + + +class APISSLCertificate(BaseModel): + """Model for an API SSL certificate.""" + + certificate_name: str + """The name of the certificate to be verified.""" + expiry_threshold: int + """The expiry threshold of the certificate in days.""" + common_name: str + """The common subject name of the certificate.""" + encryption_algorithm: EncryptionAlgorithm + """The encryption algorithm of the certificate.""" + key_size: RsaKeySize | EcdsaKeySize + """The encryption algorithm key size of the certificate.""" + + @model_validator(mode="after") + def validate_inputs(self) -> Self: + """Validate the key size provided to the APISSLCertificates class. + + If encryption_algorithm is RSA then key_size should be in {2048, 3072, 4096}. + + If encryption_algorithm is ECDSA then key_size should be in {256, 384, 521}. + """ + if self.encryption_algorithm == "RSA" and self.key_size not in get_args(RsaKeySize): + msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for RSA encryption. Allowed sizes are {get_args(RsaKeySize)}." + raise ValueError(msg) + + if self.encryption_algorithm == "ECDSA" and self.key_size not in get_args(EcdsaKeySize): + msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for ECDSA encryption. Allowed sizes are {get_args(EcdsaKeySize)}." + raise ValueError(msg) + + return self diff --git a/anta/tests/field_notices.py b/anta/tests/field_notices.py index 71a11749f..6f98a2c9a 100644 --- a/anta/tests/field_notices.py +++ b/anta/tests/field_notices.py @@ -196,4 +196,4 @@ def test(self) -> None: self.result.is_success("FN72 is mitigated") return # We should never hit this point - self.result.is_error("Error in running test - FixedSystemvrm1 not found") + self.result.is_failure("Error in running test - Component FixedSystemvrm1 not found in 'show version'") diff --git a/anta/tests/interfaces.py b/anta/tests/interfaces.py index 9ff1cf357..32b85d493 100644 --- a/anta/tests/interfaces.py +++ b/anta/tests/interfaces.py @@ -71,7 +71,7 @@ def test(self) -> None: if ((duplex := (interface := interfaces["interfaces"][intf]).get("duplex", None)) is not None and duplex != duplex_full) or ( (members := interface.get("memberInterfaces", None)) is not None and any(stats["duplex"] != duplex_full for stats in members.values()) ): - self.result.is_error(f"Interface {intf} or one of its member interfaces is not Full-Duplex. VerifyInterfaceUtilization has not been implemented.") + self.result.is_failure(f"Interface {intf} or one of its member interfaces is not Full-Duplex. VerifyInterfaceUtilization has not been implemented.") return if (bandwidth := interfaces["interfaces"][intf]["bandwidth"]) == 0: @@ -705,7 +705,7 @@ def test(self) -> None: input_interface_detail = interface break else: - self.result.is_error(f"Could not find `{intf}` in the input interfaces. {GITHUB_SUGGESTION}") + self.result.is_failure(f"Could not find `{intf}` in the input interfaces. {GITHUB_SUGGESTION}") continue input_primary_ip = str(input_interface_detail.primary_ip) diff --git a/anta/tests/mlag.py b/anta/tests/mlag.py index 1d17ab642..c894b98b6 100644 --- a/anta/tests/mlag.py +++ b/anta/tests/mlag.py @@ -123,10 +123,7 @@ class VerifyMlagConfigSanity(AntaTest): def test(self) -> None: """Main test function for VerifyMlagConfigSanity.""" command_output = self.instance_commands[0].json_output - if (mlag_status := get_value(command_output, "mlagActive")) is None: - self.result.is_error(message="Incorrect JSON response - 'mlagActive' state was not found") - return - if mlag_status is False: + if command_output["mlagActive"] is False: self.result.is_skipped("MLAG is disabled") return keys_to_verify = ["globalConfiguration", "interfaceConfiguration"] diff --git a/anta/tests/routing/bgp.py b/anta/tests/routing/bgp.py index 97f919876..8efc1ac3e 100644 --- a/anta/tests/routing/bgp.py +++ b/anta/tests/routing/bgp.py @@ -8,7 +8,7 @@ from __future__ import annotations from ipaddress import IPv4Address, IPv4Network, IPv6Address -from typing import Any, ClassVar +from typing import Any, ClassVar, Self from pydantic import BaseModel, Field, PositiveInt, model_validator from pydantic.v1.utils import deep_update @@ -235,7 +235,7 @@ class BgpAfi(BaseModel): """Number of expected BGP peer(s).""" @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: + def validate_inputs(self) -> Self: """Validate the inputs provided to the BgpAfi class. If afi is either ipv4 or ipv6, safi must be provided. @@ -375,7 +375,7 @@ class BgpAfi(BaseModel): """ @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: + def validate_inputs(self) -> Self: """Validate the inputs provided to the BgpAfi class. If afi is either ipv4 or ipv6, safi must be provided. @@ -522,7 +522,7 @@ class BgpAfi(BaseModel): """List of BGP IPv4 or IPv6 peer.""" @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: + def validate_inputs(self) -> Self: """Validate the inputs provided to the BgpAfi class. If afi is either ipv4 or ipv6, safi must be provided and vrf must NOT be all. @@ -1485,7 +1485,7 @@ class BgpPeer(BaseModel): """Outbound route map applied, defaults to None.""" @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: + def validate_inputs(self) -> Self: """Validate the inputs provided to the BgpPeer class. At least one of 'inbound' or 'outbound' route-map must be provided. diff --git a/anta/tests/routing/generic.py b/anta/tests/routing/generic.py index cd9cf0d24..7e64acda7 100644 --- a/anta/tests/routing/generic.py +++ b/anta/tests/routing/generic.py @@ -9,7 +9,7 @@ from functools import cache from ipaddress import IPv4Address, IPv4Interface -from typing import ClassVar, Literal +from typing import ClassVar, Literal, Self from pydantic import model_validator @@ -89,8 +89,8 @@ class Input(AntaTest.Input): maximum: int """Expected maximum routing table size.""" - @model_validator(mode="after") # type: ignore[misc] - def check_min_max(self) -> AntaTest.Input: + @model_validator(mode="after") + def check_min_max(self) -> Self: """Validate that maximum is greater than minimum.""" if self.minimum > self.maximum: msg = f"Minimum {self.minimum} is greater than maximum {self.maximum}" diff --git a/anta/tests/security.py b/anta/tests/security.py index ae5b9bebd..159e0155e 100644 --- a/anta/tests/security.py +++ b/anta/tests/security.py @@ -11,9 +11,9 @@ from ipaddress import IPv4Address from typing import ClassVar -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field -from anta.custom_types import EcdsaKeySize, EncryptionAlgorithm, PositiveInteger, RsaKeySize +from anta.custom_types import APISSLCertificate, PositiveInteger from anta.models import AntaCommand, AntaTemplate, AntaTest from anta.tools import get_failed_logs, get_item, get_value @@ -47,7 +47,7 @@ def test(self) -> None: try: line = next(line for line in command_output.split("\n") if line.startswith("SSHD status")) except StopIteration: - self.result.is_error("Could not find SSH status in returned output.") + self.result.is_failure("Could not find SSH status in returned output.") return status = line.split("is ")[1] @@ -401,38 +401,6 @@ class Input(AntaTest.Input): certificates: list[APISSLCertificate] """List of API SSL certificates.""" - class APISSLCertificate(BaseModel): - """Model for an API SSL certificate.""" - - certificate_name: str - """The name of the certificate to be verified.""" - expiry_threshold: int - """The expiry threshold of the certificate in days.""" - common_name: str - """The common subject name of the certificate.""" - encryption_algorithm: EncryptionAlgorithm - """The encryption algorithm of the certificate.""" - key_size: RsaKeySize | EcdsaKeySize - """The encryption algorithm key size of the certificate.""" - - @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: - """Validate the key size provided to the APISSLCertificates class. - - If encryption_algorithm is RSA then key_size should be in {2048, 3072, 4096}. - - If encryption_algorithm is ECDSA then key_size should be in {256, 384, 521}. - """ - if self.encryption_algorithm == "RSA" and self.key_size not in RsaKeySize.__args__: - msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for RSA encryption. Allowed sizes are {RsaKeySize.__args__}." - raise ValueError(msg) - - if self.encryption_algorithm == "ECDSA" and self.key_size not in EcdsaKeySize.__args__: - msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for ECDSA encryption. Allowed sizes are {EcdsaKeySize.__args__}." - raise ValueError(msg) - - return self - @AntaTest.anta_test def test(self) -> None: """Main test function for VerifyAPISSLCertificate.""" diff --git a/anta/tests/system.py b/anta/tests/system.py index 486e5e1ed..d620d533b 100644 --- a/anta/tests/system.py +++ b/anta/tests/system.py @@ -89,9 +89,6 @@ class VerifyReloadCause(AntaTest): def test(self) -> None: """Main test function for VerifyReloadCause.""" command_output = self.instance_commands[0].json_output - if "resetCauses" not in command_output: - self.result.is_error(message="No reload causes available") - return if len(command_output["resetCauses"]) == 0: # No reload causes self.result.is_success() diff --git a/tests/benchmark/test_anta.py b/tests/benchmark/test_anta.py index 6d0b585e3..82d08cf6e 100644 --- a/tests/benchmark/test_anta.py +++ b/tests/benchmark/test_anta.py @@ -106,3 +106,5 @@ def bench() -> ResultManager: "---------------------------------------" ) logger.info(bench_info) + assert manager.get_total_results({AntaTestStatus.ERROR}) == 0 + assert manager.get_total_results({AntaTestStatus.UNSET}) == 0 diff --git a/tests/benchmark/utils.py b/tests/benchmark/utils.py index c59400130..1017cfe0a 100644 --- a/tests/benchmark/utils.py +++ b/tests/benchmark/utils.py @@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any import httpx -from pydantic import ValidationError from anta.catalog import AntaCatalog, AntaTestDefinition from anta.models import AntaCommand, AntaTest @@ -91,15 +90,10 @@ def import_test_modules() -> Generator[ModuleType, None, None]: for test_data in module.DATA: test = test_data["test"] result_overwrite = AntaTest.Input.ResultOverwrite(custom_field=test_data["name"]) - # Some unit tests purposely have invalid inputs, we just skip them - try: - if test_data["inputs"] is None: - inputs = test.Input(result_overwrite=result_overwrite) - else: - inputs = test.Input(**test_data["inputs"], result_overwrite=result_overwrite) - except ValidationError: - continue - + if test_data["inputs"] is None: + inputs = test.Input(result_overwrite=result_overwrite) + else: + inputs = test.Input(**test_data["inputs"], result_overwrite=result_overwrite) test_definition = AntaTestDefinition( test=test, inputs=inputs, diff --git a/tests/units/anta_tests/routing/test_generic.py b/tests/units/anta_tests/routing/test_generic.py index 0ac43f3c5..0d87bcc34 100644 --- a/tests/units/anta_tests/routing/test_generic.py +++ b/tests/units/anta_tests/routing/test_generic.py @@ -7,6 +7,9 @@ from typing import Any +import pytest +from pydantic import ValidationError + from anta.tests.routing.generic import VerifyRoutingProtocolModel, VerifyRoutingTableEntry, VerifyRoutingTableSize from tests.units.anta_tests import test @@ -66,16 +69,6 @@ "inputs": {"minimum": 42, "maximum": 666}, "expected": {"result": "failure", "messages": ["routing-table has 1000 routes and not between min (42) and maximum (666)"]}, }, - { - "name": "error-max-smaller-than-min", - "test": VerifyRoutingTableSize, - "eos_data": [{}], - "inputs": {"minimum": 666, "maximum": 42}, - "expected": { - "result": "error", - "messages": ["Minimum 666 is greater than maximum 42"], - }, - }, { "name": "success", "test": VerifyRoutingTableEntry, @@ -310,11 +303,14 @@ "inputs": {"vrf": "default", "routes": ["10.1.0.1", "10.1.0.2"], "collect": "all"}, "expected": {"result": "failure", "messages": ["The following route(s) are missing from the routing table of VRF default: ['10.1.0.2']"]}, }, - { - "name": "collect-input-error", - "test": VerifyRoutingTableEntry, - "eos_data": {}, - "inputs": {"vrf": "default", "routes": ["10.1.0.1", "10.1.0.2"], "collect": "not-valid"}, - "expected": {"result": "error", "messages": ["Inputs are not valid"]}, - }, ] + + +class TestVerifyRoutingTableSize: # pylint: disable=too-few-public-methods + """Test VerifyRoutingTableSize.""" + + def test_inputs(self) -> None: + """Test VerifyRoutingTableSize inputs.""" + VerifyRoutingTableSize.Input(minimum=1, maximum=2) + with pytest.raises(ValidationError): + VerifyRoutingTableSize.Input(minimum=2, maximum=1) diff --git a/tests/units/anta_tests/test_configuration.py b/tests/units/anta_tests/test_configuration.py index dbe22d365..d8f86beaa 100644 --- a/tests/units/anta_tests/test_configuration.py +++ b/tests/units/anta_tests/test_configuration.py @@ -60,14 +60,4 @@ "inputs": {"regex_patterns": ["bla", "bleh"]}, "expected": {"result": "failure", "messages": ["Following patterns were not found: 'bla','bleh'"]}, }, - { - "name": "failure-invalid-regex", - "test": VerifyRunningConfigLines, - "eos_data": ["enable password something\nsome other line"], - "inputs": {"regex_patterns": ["["]}, - "expected": { - "result": "error", - "messages": ["1 validation error for Input\nregex_patterns.0\n Value error, Invalid regex: unterminated character set at position 0"], - }, - }, ] diff --git a/tests/units/anta_tests/test_field_notices.py b/tests/units/anta_tests/test_field_notices.py index a30604b8b..8e7c9d8b3 100644 --- a/tests/units/anta_tests/test_field_notices.py +++ b/tests/units/anta_tests/test_field_notices.py @@ -358,8 +358,8 @@ ], "inputs": None, "expected": { - "result": "error", - "messages": ["Error in running test - FixedSystemvrm1 not found"], + "result": "failure", + "messages": ["Error in running test - Component FixedSystemvrm1 not found in 'show version'"], }, }, ] diff --git a/tests/units/anta_tests/test_interfaces.py b/tests/units/anta_tests/test_interfaces.py index 73ef6c6aa..ea8106e84 100644 --- a/tests/units/anta_tests/test_interfaces.py +++ b/tests/units/anta_tests/test_interfaces.py @@ -652,7 +652,7 @@ ], "inputs": {"threshold": 70.0}, "expected": { - "result": "error", + "result": "failure", "messages": ["Interface Ethernet1/1 or one of its member interfaces is not Full-Duplex. VerifyInterfaceUtilization has not been implemented."], }, }, @@ -797,7 +797,7 @@ ], "inputs": {"threshold": 70.0}, "expected": { - "result": "error", + "result": "failure", "messages": ["Interface Port-Channel31 or one of its member interfaces is not Full-Duplex. VerifyInterfaceUtilization has not been implemented."], }, }, diff --git a/tests/units/anta_tests/test_mlag.py b/tests/units/anta_tests/test_mlag.py index 1ef547259..193d69c2d 100644 --- a/tests/units/anta_tests/test_mlag.py +++ b/tests/units/anta_tests/test_mlag.py @@ -110,17 +110,6 @@ "inputs": None, "expected": {"result": "skipped", "messages": ["MLAG is disabled"]}, }, - { - "name": "error", - "test": VerifyMlagConfigSanity, - "eos_data": [ - { - "dummy": False, - }, - ], - "inputs": None, - "expected": {"result": "error", "messages": ["Incorrect JSON response - 'mlagActive' state was not found"]}, - }, { "name": "failure-global", "test": VerifyMlagConfigSanity, diff --git a/tests/units/anta_tests/test_security.py b/tests/units/anta_tests/test_security.py index 792b06595..474539da9 100644 --- a/tests/units/anta_tests/test_security.py +++ b/tests/units/anta_tests/test_security.py @@ -39,7 +39,7 @@ "test": VerifySSHStatus, "eos_data": ["SSH per host connection limit is 20\nFIPS status: disabled\n\n"], "inputs": None, - "expected": {"result": "error", "messages": ["Could not find SSH status in returned output."]}, + "expected": {"result": "failure", "messages": ["Could not find SSH status in returned output."]}, }, { "name": "failure-ssh-disabled", @@ -581,40 +581,6 @@ ], }, }, - { - "name": "error-wrong-input-rsa", - "test": VerifyAPISSLCertificate, - "eos_data": [], - "inputs": { - "certificates": [ - { - "certificate_name": "ARISTA_ROOT_CA.crt", - "expiry_threshold": 30, - "common_name": "Arista Networks Internal IT Root Cert Authority", - "encryption_algorithm": "RSA", - "key_size": 256, - }, - ] - }, - "expected": {"result": "error", "messages": ["Allowed sizes are (2048, 3072, 4096)."]}, - }, - { - "name": "error-wrong-input-ecdsa", - "test": VerifyAPISSLCertificate, - "eos_data": [], - "inputs": { - "certificates": [ - { - "certificate_name": "ARISTA_SIGNING_CA.crt", - "expiry_threshold": 30, - "common_name": "AristaIT-ICA ECDSA Issuing Cert Authority", - "encryption_algorithm": "ECDSA", - "key_size": 2048, - }, - ] - }, - "expected": {"result": "error", "messages": ["Allowed sizes are (256, 384, 512)."]}, - }, { "name": "success", "test": VerifyBannerLogin, diff --git a/tests/units/anta_tests/test_system.py b/tests/units/anta_tests/test_system.py index 22b9787b2..1eda8a1d5 100644 --- a/tests/units/anta_tests/test_system.py +++ b/tests/units/anta_tests/test_system.py @@ -76,13 +76,6 @@ "inputs": None, "expected": {"result": "failure", "messages": ["Reload cause is: 'Reload after crash.'"]}, }, - { - "name": "error", - "test": VerifyReloadCause, - "eos_data": [{}], - "inputs": None, - "expected": {"result": "error", "messages": ["No reload causes available"]}, - }, { "name": "success-without-minidump", "test": VerifyCoredump, diff --git a/tests/units/test_custom_types.py b/tests/units/test_custom_types.py index e3dc09d25..b6fcf1c98 100644 --- a/tests/units/test_custom_types.py +++ b/tests/units/test_custom_types.py @@ -11,8 +11,10 @@ from __future__ import annotations import re +from typing import Any import pytest +from pydantic import ValidationError from anta.custom_types import ( REGEX_BGP_IPV4_MPLS_VPN, @@ -26,10 +28,12 @@ REGEXP_TYPE_EOS_INTERFACE, REGEXP_TYPE_HOSTNAME, REGEXP_TYPE_VXLAN_SRC_INTERFACE, + APISSLCertificate, aaa_group_prefix, bgp_multiprotocol_capabilities_abbreviations, interface_autocomplete, interface_case_sensitivity, + validate_regex, ) # ------------------------------------------------------------------------------ @@ -281,3 +285,73 @@ def test_interface_case_sensitivity_uppercase() -> None: assert interface_case_sensitivity("ETHERNET") == "ETHERNET" assert interface_case_sensitivity("VLAN") == "VLAN" assert interface_case_sensitivity("LOOPBACK") == "LOOPBACK" + + +@pytest.mark.parametrize( + "str_input", + [ + REGEX_BGP_IPV4_MPLS_VPN, + REGEX_BGP_IPV4_UNICAST, + REGEX_TYPE_PORTCHANNEL, + REGEXP_BGP_IPV4_MPLS_LABELS, + REGEXP_BGP_L2VPN_AFI, + REGEXP_INTERFACE_ID, + REGEXP_PATH_MARKERS, + REGEXP_TYPE_EOS_INTERFACE, + REGEXP_TYPE_HOSTNAME, + REGEXP_TYPE_VXLAN_SRC_INTERFACE, + ], +) +def test_validate_regex_valid(str_input: str) -> None: + """Test validate_regex with valid regex.""" + assert validate_regex(str_input) == str_input + + +@pytest.mark.parametrize( + ("str_input", "error"), + [ + pytest.param("[", "Invalid regex: unterminated character set at position 0", id="unterminated character"), + pytest.param("\\", r"Invalid regex: bad escape \(end of pattern\) at position 0", id="bad escape"), + ], +) +def test_validate_regex_invalid(str_input: str, error: str) -> None: + """Test validate_regex with invalid regex.""" + with pytest.raises(ValueError, match=error): + validate_regex(str_input) + + +class TestAPISSLCertificate: # pylint: disable=too-few-public-methods + """Test anta.custom_types.APISSLCertificate.""" + + @pytest.mark.parametrize( + ("model_params", "error"), + [ + pytest.param( + { + "certificate_name": "ARISTA_ROOT_CA.crt", + "expiry_threshold": 30, + "common_name": "Arista Networks Internal IT Root Cert Authority", + "encryption_algorithm": "RSA", + "key_size": 256, + }, + "Value error, `ARISTA_ROOT_CA.crt` key size 256 is invalid for RSA encryption. Allowed sizes are (2048, 3072, 4096).", + id="RSA_wrong_size", + ), + pytest.param( + { + "certificate_name": "ARISTA_SIGNING_CA.crt", + "expiry_threshold": 30, + "common_name": "AristaIT-ICA ECDSA Issuing Cert Authority", + "encryption_algorithm": "ECDSA", + "key_size": 2048, + }, + "Value error, `ARISTA_SIGNING_CA.crt` key size 2048 is invalid for ECDSA encryption. Allowed sizes are (256, 384, 512).", + id="ECDSA_wrong_size", + ), + ], + ) + def test_invalid(self, model_params: dict[str, Any], error: str) -> None: + """Test invalid inputs for anta.custom_types.APISSLCertificate.""" + with pytest.raises(ValidationError) as exec_info: + APISSLCertificate.model_validate(model_params) + assert error == exec_info.value.errors()[0]["msg"]