Skip to content

Commit

Permalink
Update to AntaTestStatus
Browse files Browse the repository at this point in the history
  • Loading branch information
carl-baillargeon committed Aug 29, 2024
1 parent 3824f14 commit b848a82
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 73 deletions.
4 changes: 2 additions & 2 deletions anta/cli/nrfu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from anta.cli.nrfu import commands
from anta.cli.utils import AliasedGroup, catalog_options, inventory_options
from anta.result_manager import ResultManager
from anta.result_manager.models import TestStatus
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from anta.catalog import AntaCatalog
Expand Down Expand Up @@ -49,7 +49,7 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]:
return super().parse_args(ctx, args)


HIDE_STATUS: list[str] = list(TestStatus)
HIDE_STATUS: list[str] = list(AntaTestStatus)
HIDE_STATUS.remove("unset")


Expand Down
8 changes: 4 additions & 4 deletions anta/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pathlib

from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult, TestStatus
from anta.result_manager.models import AntaTestStatus, TestResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -79,20 +79,20 @@ def _build_headers(self, headers: list[str], table: Table) -> Table:
table.add_column(header, justify="left")
return table

def _color_result(self, status: TestStatus) -> str:
def _color_result(self, status: AntaTestStatus) -> str:
"""Return a colored string based on the status value.
Parameters
----------
status (TestStatus): status value to color.
status (AntaTestStatus): status value to color.
Returns
-------
str: the colored string
"""
color = RICH_COLOR_THEME.get(status.value, "")
return f"[{color}]{status.value}" if color != "" else str(status.value)
return f"[{color}]{status.value}" if color != "" else status.value

def report_all(self, manager: ResultManager, title: str = "All tests results") -> Table:
"""Create a table report with all tests for one or all devices.
Expand Down
2 changes: 1 addition & 1 deletion anta/reporter/md_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def generate_rows(self) -> Generator[str, None, None]:
categories = ", ".join(result.categories)
yield (
f"| {result.name or '-'} | {categories or '-'} | {result.test or '-'} "
f"| {result.description or '-'} | {self.safe_markdown(result.custom_field) or '-'} | {result.result or '-'} | {messages or '-'} |\n"
f"| {result.description or '-'} | {self.safe_markdown(result.custom_field) or '-'} | {result.result.value or '-'} | {messages or '-'} |\n"
)

def generate_section(self) -> None:
Expand Down
37 changes: 16 additions & 21 deletions anta/result_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
from collections import defaultdict
from functools import cached_property
from itertools import chain
from typing import get_args

from pydantic import TypeAdapter

from anta.constants import ACRONYM_CATEGORIES
from anta.result_manager.models import TestResult, TestStatus
from anta.result_manager.models import AntaTestStatus, TestResult

from .models import CategoryStats, DeviceStats, TestStats

Expand Down Expand Up @@ -94,7 +91,7 @@ def __init__(self) -> None:
error_status is set to True.
"""
self._result_entries: list[TestResult] = []
self.status: TestStatus = TestStatus.unset
self.status: AntaTestStatus = AntaTestStatus.unset
self.error_status = False

self.device_stats: defaultdict[str, DeviceStats] = defaultdict(DeviceStats)
Expand All @@ -115,7 +112,7 @@ def results(self, value: list[TestResult]) -> None:
"""Set the list of TestResult."""
# When setting the results, we need to reset the state of the current instance
self._result_entries = []
self.status = TestStatus.unset
self.status = AntaTestStatus.unset
self.error_status = False

# Also reset the stats attributes
Expand All @@ -137,19 +134,17 @@ def sorted_category_stats(self) -> dict[str, CategoryStats]:
return dict(sorted(self.category_stats.items()))

@cached_property
def results_by_status(self) -> dict[TestStatus, list[TestResult]]:
def results_by_status(self) -> dict[AntaTestStatus, list[TestResult]]:
"""A cached property that returns the results grouped by status."""
return {status: [result for result in self._result_entries if result.result == status] for status in get_args(TestStatus)}
return {status: [result for result in self._result_entries if result.result == status] for status in AntaTestStatus}

def _update_status(self, test_status: TestStatus) -> None:
def _update_status(self, test_status: AntaTestStatus) -> None:
"""Update the status of the ResultManager instance based on the test status.
Parameters
----------
test_status: TestStatus to update the ResultManager status.
test_status: AntaTestStatus to update the ResultManager status.
"""
result_validator: TypeAdapter[TestStatus] = TypeAdapter(TestStatus)
result_validator.validate_python(test_status)
if test_status == "error":
self.error_status = True
return
Expand All @@ -168,7 +163,7 @@ def _update_stats(self, result: TestResult) -> None:
result.categories = [
" ".join(word.upper() if word.lower() in ACRONYM_CATEGORIES else word.title() for word in category.split()) for category in result.categories
]
count_attr = f"tests_{result.result}_count"
count_attr = f"tests_{result.result.value}_count"

# Update device stats
device_stats: DeviceStats = self.device_stats[result.name]
Expand All @@ -185,7 +180,7 @@ def _update_stats(self, result: TestResult) -> None:
setattr(category_stats, count_attr, getattr(category_stats, count_attr) + 1)

# Update test stats
count_attr = f"devices_{result.result}_count"
count_attr = f"devices_{result.result.value}_count"
test_stats: TestStats = self.test_stats[result.test]
setattr(test_stats, count_attr, getattr(test_stats, count_attr) + 1)
if result.result in ("failure", "error"):
Expand All @@ -208,14 +203,14 @@ def add(self, result: TestResult) -> None:
# Every time a new result is added, we need to clear the cached property
self.__dict__.pop("results_by_status", None)

def get_results(self, status: set[TestStatus] | None = None, sort_by: list[str] | None = None) -> list[TestResult]:
def get_results(self, status: set[AntaTestStatus] | None = None, sort_by: list[str] | None = None) -> list[TestResult]:
"""Get the results, optionally filtered by status and sorted by TestResult fields.
If no status is provided, all results are returned.
Parameters
----------
status: Optional set of TestStatus literals to filter the results.
status: Optional set of AntaTestStatus enum members to filter the results.
sort_by: Optional list of TestResult fields to sort the results.
Returns
Expand All @@ -234,14 +229,14 @@ def get_results(self, status: set[TestStatus] | None = None, sort_by: list[str]

return results

def get_total_results(self, status: set[TestStatus] | None = None) -> int:
def get_total_results(self, status: set[AntaTestStatus] | None = None) -> int:
"""Get the total number of results, optionally filtered by status.
If no status is provided, the total number of results is returned.
Parameters
----------
status: Optional set of TestStatus literals to filter the results.
status: Optional set of AntaTestStatus enum members to filter the results.
Returns
-------
Expand All @@ -258,18 +253,18 @@ def get_status(self, *, ignore_error: bool = False) -> str:
"""Return the current status including error_status if ignore_error is False."""
return "error" if self.error_status and not ignore_error else self.status

def filter(self, hide: set[TestStatus]) -> ResultManager:
def filter(self, hide: set[AntaTestStatus]) -> ResultManager:
"""Get a filtered ResultManager based on test status.
Parameters
----------
hide: set of TestStatus literals to select tests to hide based on their status.
hide: Set of AntaTestStatus enum members to select tests to hide based on their status.
Returns
-------
A filtered `ResultManager`.
"""
possible_statuses = set(get_args(TestStatus))
possible_statuses = set(AntaTestStatus)
manager = ResultManager()
manager.results = self.get_results(possible_statuses - hide)
return manager
Expand Down
34 changes: 15 additions & 19 deletions anta/result_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,8 @@
from pydantic import BaseModel


class TestStatus(str, Enum):
"""TestStatus enum."""

# This is to prevent pytest to collecting this
# TODO: find a way to ignore this only in test and not in library code.
__test__ = False
class AntaTestStatus(str, Enum):
"""Test status Enum for the TestResult."""

unset = "unset"
success = "success"
Expand All @@ -30,21 +26,21 @@ class TestResult(BaseModel):
Attributes
----------
name: Device name where the test has run.
test: Test name runs on the device.
categories: List of categories the TestResult belongs to, by default the AntaTest categories.
description: TestResult description, by default the AntaTest description.
result: Result of the test. Can be one of "unset", "success", "failure", "error" or "skipped".
messages: Message to report after the test if any.
custom_field: Custom field to store a string for flexibility in integrating with ANTA
name: Name of the device where the test was run.
test: Name of the test run on the device.
categories: List of categories the TestResult belongs to. Defaults to the AntaTest categories.
description: Description of the TestResult. Defaults to the AntaTest description.
result: Result of the test. Must be one of the Status Enum values: unset, success, failure, error, skipped.
messages: Messages to report after the test, if any.
custom_field: Custom field to store a string for flexibility in integrating with ANTA.
"""

name: str
test: str
categories: list[str]
description: str
result: TestStatus = TestStatus.unset
result: AntaTestStatus = AntaTestStatus.unset
messages: list[str] = []
custom_field: str | None = None

Expand All @@ -56,7 +52,7 @@ def is_success(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status(TestStatus.success, message)
self._set_status(AntaTestStatus.success, message)

def is_failure(self, message: str | None = None) -> None:
"""Set status to failure.
Expand All @@ -66,7 +62,7 @@ def is_failure(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status(TestStatus.failure, message)
self._set_status(AntaTestStatus.failure, message)

def is_skipped(self, message: str | None = None) -> None:
"""Set status to skipped.
Expand All @@ -76,7 +72,7 @@ def is_skipped(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status(TestStatus.skipped, message)
self._set_status(AntaTestStatus.skipped, message)

def is_error(self, message: str | None = None) -> None:
"""Set status to error.
Expand All @@ -86,9 +82,9 @@ def is_error(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status(TestStatus.error, message)
self._set_status(AntaTestStatus.error, message)

def _set_status(self, status: TestStatus, message: str | None = None) -> None:
def _set_status(self, status: AntaTestStatus, message: str | None = None) -> None:
"""Set status and insert optional message.
Parameters
Expand Down
18 changes: 9 additions & 9 deletions tests/units/reporter/test__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from anta import RICH_COLOR_PALETTE
from anta.reporter import ReportJinja, ReportTable
from anta.result_manager.models import TestStatus
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from anta.result_manager import ResultManager
Expand Down Expand Up @@ -73,14 +73,14 @@ def test__build_headers(self, headers: list[str]) -> None:
@pytest.mark.parametrize(
("status", "expected_status"),
[
pytest.param(TestStatus.unset, "[grey74]unset", id="unset status"),
pytest.param(TestStatus.skipped, "[bold orange4]skipped", id="skipped status"),
pytest.param(TestStatus.failure, "[bold red]failure", id="failure status"),
pytest.param(TestStatus.error, "[indian_red]error", id="error status"),
pytest.param(TestStatus.success, "[green4]success", id="success status"),
pytest.param(AntaTestStatus.unset, "[grey74]unset", id="unset status"),
pytest.param(AntaTestStatus.skipped, "[bold orange4]skipped", id="skipped status"),
pytest.param(AntaTestStatus.failure, "[bold red]failure", id="failure status"),
pytest.param(AntaTestStatus.error, "[indian_red]error", id="error status"),
pytest.param(AntaTestStatus.success, "[green4]success", id="success status"),
],
)
def test__color_result(self, status: TestStatus, expected_status: str) -> None:
def test__color_result(self, status: AntaTestStatus, expected_status: str) -> None:
"""Test _build_headers."""
# pylint: disable=protected-access
report = ReportTable()
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_report_summary_tests(
new_results = [result.model_copy() for result in manager.results]
for result in new_results:
result.name = "test_device"
result.result = TestStatus.failure
result.result = AntaTestStatus.failure

report = ReportTable()
kwargs = {"tests": [test] if test is not None else None, "title": title}
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_report_summary_devices(
new_results = [result.model_copy() for result in manager.results]
for result in new_results:
result.name = dev or "test_device"
result.result = TestStatus.failure
result.result = AntaTestStatus.failure
manager.results = new_results

report = ReportTable()
Expand Down
30 changes: 15 additions & 15 deletions tests/units/result_manager/test__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytest

from anta.result_manager import ResultManager, models
from anta.result_manager.models import TestStatus
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from anta.result_manager.models import TestResult
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_json(self, list_result_factory: Callable[[int], list[TestResult]]) -> N

success_list = list_result_factory(3)
for test in success_list:
test.result = TestStatus.success
test.result = AntaTestStatus.success
result_manager.results = success_list

json_res = result_manager.json
Expand Down Expand Up @@ -149,8 +149,8 @@ def test_sorted_category_stats(self, list_result_factory: Callable[[int], list[T
def test_add(
self,
test_result_factory: Callable[[], TestResult],
starting_status: TestStatus,
test_status: TestStatus,
starting_status: AntaTestStatus,
test_status: AntaTestStatus,
expected_status: str,
expected_raise: AbstractContextManager[Exception],
) -> None:
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_get_total_results(self, result_manager: ResultManager) -> None:
)
def test_get_status(
self,
status: TestStatus,
status: AntaTestStatus,
error_status: bool,
ignore_error: bool,
expected_status: str,
Expand All @@ -284,28 +284,28 @@ def test_filter(self, test_result_factory: Callable[[], TestResult], list_result

success_list = list_result_factory(3)
for test in success_list:
test.result = TestStatus.success
test.result = AntaTestStatus.success
result_manager.results = success_list

test = test_result_factory()
test.result = TestStatus.failure
test.result = AntaTestStatus.failure
result_manager.add(test)

test = test_result_factory()
test.result = TestStatus.error
test.result = AntaTestStatus.error
result_manager.add(test)

test = test_result_factory()
test.result = TestStatus.skipped
test.result = AntaTestStatus.skipped
result_manager.add(test)

assert len(result_manager) == 6
assert len(result_manager.filter({TestStatus.failure})) == 5
assert len(result_manager.filter({TestStatus.error})) == 5
assert len(result_manager.filter({TestStatus.skipped})) == 5
assert len(result_manager.filter({TestStatus.failure, TestStatus.error})) == 4
assert len(result_manager.filter({TestStatus.failure, TestStatus.error, TestStatus.skipped})) == 3
assert len(result_manager.filter({TestStatus.success, TestStatus.failure, TestStatus.error, TestStatus.skipped})) == 0
assert len(result_manager.filter({AntaTestStatus.failure})) == 5
assert len(result_manager.filter({AntaTestStatus.error})) == 5
assert len(result_manager.filter({AntaTestStatus.skipped})) == 5
assert len(result_manager.filter({AntaTestStatus.failure, AntaTestStatus.error})) == 4
assert len(result_manager.filter({AntaTestStatus.failure, AntaTestStatus.error, AntaTestStatus.skipped})) == 3
assert len(result_manager.filter({AntaTestStatus.success, AntaTestStatus.failure, AntaTestStatus.error, AntaTestStatus.skipped})) == 0

def test_get_by_tests(self, test_result_factory: Callable[[], TestResult], result_manager_factory: Callable[[int], ResultManager]) -> None:
"""Test ResultManager.get_by_tests."""
Expand Down
Loading

0 comments on commit b848a82

Please sign in to comment.