diff --git a/anta/reporter/md_reporter.py b/anta/reporter/md_reporter.py index 5142a5708..4c8ac54a3 100644 --- a/anta/reporter/md_reporter.py +++ b/anta/reporter/md_reporter.py @@ -12,6 +12,7 @@ from anta.constants import MD_REPORT_TOC from anta.logger import anta_log_exception +from anta.result_manager.models import AntaTestStatus if TYPE_CHECKING: from collections.abc import Generator @@ -203,10 +204,10 @@ def generate_rows(self) -> Generator[str, None, None]: """Generate the rows of the summary totals table.""" yield ( f"| {self.results.get_total_results()} " - f"| {self.results.get_total_results({'success'})} " - f"| {self.results.get_total_results({'skipped'})} " - f"| {self.results.get_total_results({'failure'})} " - f"| {self.results.get_total_results({'error'})} |\n" + f"| {self.results.get_total_results({AntaTestStatus.success})} " + f"| {self.results.get_total_results({AntaTestStatus.skipped})} " + f"| {self.results.get_total_results({AntaTestStatus.failure})} " + f"| {self.results.get_total_results({AntaTestStatus.error})} |\n" ) def generate_section(self) -> None: diff --git a/anta/result_manager/__init__.py b/anta/result_manager/__init__.py index 4074e3a22..8282a70b5 100644 --- a/anta/result_manager/__init__.py +++ b/anta/result_manager/__init__.py @@ -151,7 +151,7 @@ def _update_status(self, test_status: AntaTestStatus) -> None: if self.status == "unset" or self.status == "skipped" and test_status in {"success", "failure"}: self.status = test_status elif self.status == "success" and test_status == "failure": - self.status = "failure" + self.status = AntaTestStatus.failure def _update_stats(self, result: TestResult) -> None: """Update the statistics based on the test result. diff --git a/tests/units/result_manager/test__init__.py b/tests/units/result_manager/test__init__.py index ae3c683fc..6274a1cfa 100644 --- a/tests/units/result_manager/test__init__.py +++ b/tests/units/result_manager/test__init__.py @@ -141,29 +141,27 @@ def test_sorted_category_stats(self, list_result_factory: Callable[[int], list[T nullcontext(), id="failure, add success", ), - pytest.param( - "unset", "unknown", None, pytest.raises(ValueError, match="Input should be 'unset', 'success', 'failure', 'error' or 'skipped'"), id="wrong status" - ), + pytest.param("unset", "unknown", None, pytest.raises(ValueError, match="'unknown' is not a valid AntaTestStatus"), id="wrong status"), ], ) def test_add( self, test_result_factory: Callable[[], TestResult], - starting_status: AntaTestStatus, - test_status: AntaTestStatus, + starting_status: str, + test_status: str, expected_status: str, expected_raise: AbstractContextManager[Exception], ) -> None: # pylint: disable=too-many-arguments """Test ResultManager_update_status.""" result_manager = ResultManager() - result_manager.status = starting_status + result_manager.status = AntaTestStatus(starting_status) assert result_manager.error_status is False assert len(result_manager) == 0 test = test_result_factory() - test.result = test_status with expected_raise: + test.result = AntaTestStatus(test_status) result_manager.add(test) if test_status == "error": assert result_manager.error_status is True @@ -199,12 +197,12 @@ def test_add_clear_cache(self, result_manager: ResultManager, test_result_factor def test_get_results(self, result_manager: ResultManager) -> None: """Test ResultManager.get_results.""" # Check for single status - success_results = result_manager.get_results(status={"success"}) + success_results = result_manager.get_results(status={AntaTestStatus.success}) assert len(success_results) == 7 assert all(r.result == "success" for r in success_results) # Check for multiple statuses - failure_results = result_manager.get_results(status={"failure", "error"}) + failure_results = result_manager.get_results(status={AntaTestStatus.failure, AntaTestStatus.error}) assert len(failure_results) == 21 assert all(r.result in {"failure", "error"} for r in failure_results) @@ -226,7 +224,7 @@ def test_get_results_sort_by(self, result_manager: ResultManager) -> None: assert all_results[-1].name == "DC1-SPINE1" # Check multiple statuses with sort_by categories - success_skipped_results = result_manager.get_results(status={"success", "skipped"}, sort_by=["categories"]) + success_skipped_results = result_manager.get_results(status={AntaTestStatus.success, AntaTestStatus.skipped}, sort_by=["categories"]) assert len(success_skipped_results) == 9 assert success_skipped_results[0].categories == ["Interfaces"] assert success_skipped_results[-1].categories == ["VXLAN"] @@ -246,15 +244,15 @@ def test_get_total_results(self, result_manager: ResultManager) -> None: assert result_manager.get_total_results() == 30 # Test single status - assert result_manager.get_total_results(status={"success"}) == 7 - assert result_manager.get_total_results(status={"failure"}) == 19 - assert result_manager.get_total_results(status={"error"}) == 2 - assert result_manager.get_total_results(status={"skipped"}) == 2 + assert result_manager.get_total_results(status={AntaTestStatus.success}) == 7 + assert result_manager.get_total_results(status={AntaTestStatus.failure}) == 19 + assert result_manager.get_total_results(status={AntaTestStatus.error}) == 2 + assert result_manager.get_total_results(status={AntaTestStatus.skipped}) == 2 # Test multiple statuses - assert result_manager.get_total_results(status={"success", "failure"}) == 26 - assert result_manager.get_total_results(status={"success", "failure", "error"}) == 28 - assert result_manager.get_total_results(status={"success", "failure", "error", "skipped"}) == 30 + assert result_manager.get_total_results(status={AntaTestStatus.success, AntaTestStatus.failure}) == 26 + assert result_manager.get_total_results(status={AntaTestStatus.success, AntaTestStatus.failure, AntaTestStatus.error}) == 28 + assert result_manager.get_total_results(status={AntaTestStatus.success, AntaTestStatus.failure, AntaTestStatus.error, AntaTestStatus.skipped}) == 30 @pytest.mark.parametrize( ("status", "error_status", "ignore_error", "expected_status"),