Skip to content

Commit

Permalink
refactor(anta): Change TestStatus to be an Enum for coding clarity (#758
Browse files Browse the repository at this point in the history
)

Co-authored-by: Carl Baillargeon <[email protected]
  • Loading branch information
gmuloc authored Aug 30, 2024
1 parent 30f731c commit 7ff8043
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 95 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ repos:
- '<!--| ~| -->'

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.2
rev: v0.6.3
hooks:
- id: ruff
name: Run Ruff linter
Expand Down
6 changes: 3 additions & 3 deletions anta/cli/nrfu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from __future__ import annotations

from typing import TYPE_CHECKING, get_args
from typing import TYPE_CHECKING

import click

from anta.cli.nrfu import commands
from anta.cli.utils import AliasedGroup, catalog_options, inventory_options
from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from anta.catalog import AntaCatalog
Expand Down Expand Up @@ -49,7 +49,7 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]:
return super().parse_args(ctx, args)


HIDE_STATUS: list[str] = list(get_args(TestStatus))
HIDE_STATUS: list[str] = list(AntaTestStatus)
HIDE_STATUS.remove("unset")


Expand Down
3 changes: 0 additions & 3 deletions anta/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,6 @@ def validate_regex(value: str) -> str:
return value


# ANTA framework
TestStatus = Literal["unset", "success", "failure", "error", "skipped"]

# AntaTest.Input types
AAAAuthMethod = Annotated[str, AfterValidator(aaa_group_prefix)]
Vlan = Annotated[int, Field(ge=0, le=4094)]
Expand Down
13 changes: 6 additions & 7 deletions anta/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
if TYPE_CHECKING:
import pathlib

from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult
from anta.result_manager.models import AntaTestStatus, TestResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,19 +79,19 @@ def _build_headers(self, headers: list[str], table: Table) -> Table:
table.add_column(header, justify="left")
return table

def _color_result(self, status: TestStatus) -> str:
"""Return a colored string based on the status value.
def _color_result(self, status: AntaTestStatus) -> str:
"""Return a colored string based on an AntaTestStatus.
Parameters
----------
status (TestStatus): status value to color.
status: AntaTestStatus enum to color.
Returns
-------
str: the colored string
The colored string.
"""
color = RICH_COLOR_THEME.get(status, "")
color = RICH_COLOR_THEME.get(str(status), "")
return f"[{color}]{status}" if color != "" else str(status)

def report_all(self, manager: ResultManager, title: str = "All tests results") -> Table:
Expand Down
9 changes: 5 additions & 4 deletions anta/reporter/md_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from anta.constants import MD_REPORT_TOC
from anta.logger import anta_log_exception
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -203,10 +204,10 @@ def generate_rows(self) -> Generator[str, None, None]:
"""Generate the rows of the summary totals table."""
yield (
f"| {self.results.get_total_results()} "
f"| {self.results.get_total_results({'success'})} "
f"| {self.results.get_total_results({'skipped'})} "
f"| {self.results.get_total_results({'failure'})} "
f"| {self.results.get_total_results({'error'})} |\n"
f"| {self.results.get_total_results({AntaTestStatus.SUCCESS})} "
f"| {self.results.get_total_results({AntaTestStatus.SKIPPED})} "
f"| {self.results.get_total_results({AntaTestStatus.FAILURE})} "
f"| {self.results.get_total_results({AntaTestStatus.ERROR})} |\n"
)

def generate_section(self) -> None:
Expand Down
36 changes: 15 additions & 21 deletions anta/result_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@
from collections import defaultdict
from functools import cached_property
from itertools import chain
from typing import get_args

from pydantic import TypeAdapter

from anta.constants import ACRONYM_CATEGORIES
from anta.custom_types import TestStatus
from anta.result_manager.models import TestResult
from anta.result_manager.models import AntaTestStatus, TestResult

from .models import CategoryStats, DeviceStats, TestStats

Expand Down Expand Up @@ -95,7 +91,7 @@ def __init__(self) -> None:
error_status is set to True.
"""
self._result_entries: list[TestResult] = []
self.status: TestStatus = "unset"
self.status: AntaTestStatus = AntaTestStatus.UNSET
self.error_status = False

self.device_stats: defaultdict[str, DeviceStats] = defaultdict(DeviceStats)
Expand All @@ -116,7 +112,7 @@ def results(self, value: list[TestResult]) -> None:
"""Set the list of TestResult."""
# When setting the results, we need to reset the state of the current instance
self._result_entries = []
self.status = "unset"
self.status = AntaTestStatus.UNSET
self.error_status = False

# Also reset the stats attributes
Expand All @@ -138,26 +134,24 @@ def sorted_category_stats(self) -> dict[str, CategoryStats]:
return dict(sorted(self.category_stats.items()))

@cached_property
def results_by_status(self) -> dict[TestStatus, list[TestResult]]:
def results_by_status(self) -> dict[AntaTestStatus, list[TestResult]]:
"""A cached property that returns the results grouped by status."""
return {status: [result for result in self._result_entries if result.result == status] for status in get_args(TestStatus)}
return {status: [result for result in self._result_entries if result.result == status] for status in AntaTestStatus}

def _update_status(self, test_status: TestStatus) -> None:
def _update_status(self, test_status: AntaTestStatus) -> None:
"""Update the status of the ResultManager instance based on the test status.
Parameters
----------
test_status: TestStatus to update the ResultManager status.
test_status: AntaTestStatus to update the ResultManager status.
"""
result_validator: TypeAdapter[TestStatus] = TypeAdapter(TestStatus)
result_validator.validate_python(test_status)
if test_status == "error":
self.error_status = True
return
if self.status == "unset" or self.status == "skipped" and test_status in {"success", "failure"}:
self.status = test_status
elif self.status == "success" and test_status == "failure":
self.status = "failure"
self.status = AntaTestStatus.FAILURE

def _update_stats(self, result: TestResult) -> None:
"""Update the statistics based on the test result.
Expand Down Expand Up @@ -209,14 +203,14 @@ def add(self, result: TestResult) -> None:
# Every time a new result is added, we need to clear the cached property
self.__dict__.pop("results_by_status", None)

def get_results(self, status: set[TestStatus] | None = None, sort_by: list[str] | None = None) -> list[TestResult]:
def get_results(self, status: set[AntaTestStatus] | None = None, sort_by: list[str] | None = None) -> list[TestResult]:
"""Get the results, optionally filtered by status and sorted by TestResult fields.
If no status is provided, all results are returned.
Parameters
----------
status: Optional set of TestStatus literals to filter the results.
status: Optional set of AntaTestStatus enum members to filter the results.
sort_by: Optional list of TestResult fields to sort the results.
Returns
Expand All @@ -235,14 +229,14 @@ def get_results(self, status: set[TestStatus] | None = None, sort_by: list[str]

return results

def get_total_results(self, status: set[TestStatus] | None = None) -> int:
def get_total_results(self, status: set[AntaTestStatus] | None = None) -> int:
"""Get the total number of results, optionally filtered by status.
If no status is provided, the total number of results is returned.
Parameters
----------
status: Optional set of TestStatus literals to filter the results.
status: Optional set of AntaTestStatus enum members to filter the results.
Returns
-------
Expand All @@ -259,18 +253,18 @@ def get_status(self, *, ignore_error: bool = False) -> str:
"""Return the current status including error_status if ignore_error is False."""
return "error" if self.error_status and not ignore_error else self.status

def filter(self, hide: set[TestStatus]) -> ResultManager:
def filter(self, hide: set[AntaTestStatus]) -> ResultManager:
"""Get a filtered ResultManager based on test status.
Parameters
----------
hide: set of TestStatus literals to select tests to hide based on their status.
hide: Set of AntaTestStatus enum members to select tests to hide based on their status.
Returns
-------
A filtered `ResultManager`.
"""
possible_statuses = set(get_args(TestStatus))
possible_statuses = set(AntaTestStatus)
manager = ResultManager()
manager.results = self.get_results(possible_statuses - hide)
return manager
Expand Down
44 changes: 30 additions & 14 deletions anta/result_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,48 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum

from pydantic import BaseModel

from anta.custom_types import TestStatus

class AntaTestStatus(str, Enum):
"""Test status Enum for the TestResult.
NOTE: This could be updated to StrEnum when Python 3.11 is the minimum supported version in ANTA.
"""

UNSET = "unset"
SUCCESS = "success"
FAILURE = "failure"
ERROR = "error"
SKIPPED = "skipped"

def __str__(self) -> str:
"""Override the __str__ method to return the value of the Enum, mimicking the behavior of StrEnum."""
return self.value


class TestResult(BaseModel):
"""Describe the result of a test from a single device.
Attributes
----------
name: Device name where the test has run.
test: Test name runs on the device.
categories: List of categories the TestResult belongs to, by default the AntaTest categories.
description: TestResult description, by default the AntaTest description.
result: Result of the test. Can be one of "unset", "success", "failure", "error" or "skipped".
messages: Message to report after the test if any.
custom_field: Custom field to store a string for flexibility in integrating with ANTA
name: Name of the device where the test was run.
test: Name of the test run on the device.
categories: List of categories the TestResult belongs to. Defaults to the AntaTest categories.
description: Description of the TestResult. Defaults to the AntaTest description.
result: Result of the test. Must be one of the AntaTestStatus Enum values: unset, success, failure, error or skipped.
messages: Messages to report after the test, if any.
custom_field: Custom field to store a string for flexibility in integrating with ANTA.
"""

name: str
test: str
categories: list[str]
description: str
result: TestStatus = "unset"
result: AntaTestStatus = AntaTestStatus.UNSET
messages: list[str] = []
custom_field: str | None = None

Expand All @@ -43,7 +59,7 @@ def is_success(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status("success", message)
self._set_status(AntaTestStatus.SUCCESS, message)

def is_failure(self, message: str | None = None) -> None:
"""Set status to failure.
Expand All @@ -53,7 +69,7 @@ def is_failure(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status("failure", message)
self._set_status(AntaTestStatus.FAILURE, message)

def is_skipped(self, message: str | None = None) -> None:
"""Set status to skipped.
Expand All @@ -63,7 +79,7 @@ def is_skipped(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status("skipped", message)
self._set_status(AntaTestStatus.SKIPPED, message)

def is_error(self, message: str | None = None) -> None:
"""Set status to error.
Expand All @@ -73,9 +89,9 @@ def is_error(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status("error", message)
self._set_status(AntaTestStatus.ERROR, message)

def _set_status(self, status: TestStatus, message: str | None = None) -> None:
def _set_status(self, status: AntaTestStatus, message: str | None = None) -> None:
"""Set status and insert optional message.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion asynceapi/aio_portcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# -----------------------------------------------------------------------------


async def port_check_url(url: URL, timeout: int = 5) -> bool: # noqa: ASYNC109
async def port_check_url(url: URL, timeout: int = 5) -> bool:
"""
Open the port designated by the URL given the timeout in seconds.
Expand Down
19 changes: 9 additions & 10 deletions tests/units/reporter/test__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from anta import RICH_COLOR_PALETTE
from anta.reporter import ReportJinja, ReportTable
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from anta.custom_types import TestStatus
from anta.result_manager import ResultManager


Expand Down Expand Up @@ -73,15 +73,14 @@ def test__build_headers(self, headers: list[str]) -> None:
@pytest.mark.parametrize(
("status", "expected_status"),
[
pytest.param("unknown", "unknown", id="unknown status"),
pytest.param("unset", "[grey74]unset", id="unset status"),
pytest.param("skipped", "[bold orange4]skipped", id="skipped status"),
pytest.param("failure", "[bold red]failure", id="failure status"),
pytest.param("error", "[indian_red]error", id="error status"),
pytest.param("success", "[green4]success", id="success status"),
pytest.param(AntaTestStatus.UNSET, "[grey74]unset", id="unset status"),
pytest.param(AntaTestStatus.SKIPPED, "[bold orange4]skipped", id="skipped status"),
pytest.param(AntaTestStatus.FAILURE, "[bold red]failure", id="failure status"),
pytest.param(AntaTestStatus.ERROR, "[indian_red]error", id="error status"),
pytest.param(AntaTestStatus.SUCCESS, "[green4]success", id="success status"),
],
)
def test__color_result(self, status: TestStatus, expected_status: str) -> None:
def test__color_result(self, status: AntaTestStatus, expected_status: str) -> None:
"""Test _build_headers."""
# pylint: disable=protected-access
report = ReportTable()
Expand Down Expand Up @@ -140,7 +139,7 @@ def test_report_summary_tests(
new_results = [result.model_copy() for result in manager.results]
for result in new_results:
result.name = "test_device"
result.result = "failure"
result.result = AntaTestStatus.FAILURE

report = ReportTable()
kwargs = {"tests": [test] if test is not None else None, "title": title}
Expand Down Expand Up @@ -175,7 +174,7 @@ def test_report_summary_devices(
new_results = [result.model_copy() for result in manager.results]
for result in new_results:
result.name = dev or "test_device"
result.result = "failure"
result.result = AntaTestStatus.FAILURE
manager.results = new_results

report = ReportTable()
Expand Down
Loading

0 comments on commit 7ff8043

Please sign in to comment.