Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(anta): Change TestStatus to be an Enum for coding clarity #758

Merged
merged 10 commits into from
Aug 30, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ repos:
- '<!--| ~| -->'

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.2
rev: v0.6.3
hooks:
- id: ruff
name: Run Ruff linter
Expand Down
6 changes: 3 additions & 3 deletions anta/cli/nrfu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from __future__ import annotations

from typing import TYPE_CHECKING, get_args
from typing import TYPE_CHECKING

import click

from anta.cli.nrfu import commands
from anta.cli.utils import AliasedGroup, catalog_options, inventory_options
from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from anta.catalog import AntaCatalog
Expand Down Expand Up @@ -49,7 +49,7 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]:
return super().parse_args(ctx, args)


HIDE_STATUS: list[str] = list(get_args(TestStatus))
HIDE_STATUS: list[str] = list(AntaTestStatus)
HIDE_STATUS.remove("unset")


Expand Down
3 changes: 0 additions & 3 deletions anta/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,6 @@ def validate_regex(value: str) -> str:
return value


# ANTA framework
TestStatus = Literal["unset", "success", "failure", "error", "skipped"]

# AntaTest.Input types
AAAAuthMethod = Annotated[str, AfterValidator(aaa_group_prefix)]
Vlan = Annotated[int, Field(ge=0, le=4094)]
Expand Down
13 changes: 6 additions & 7 deletions anta/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
if TYPE_CHECKING:
import pathlib

from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult
from anta.result_manager.models import AntaTestStatus, TestResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,19 +79,19 @@ def _build_headers(self, headers: list[str], table: Table) -> Table:
table.add_column(header, justify="left")
return table

def _color_result(self, status: TestStatus) -> str:
"""Return a colored string based on the status value.
def _color_result(self, status: AntaTestStatus) -> str:
"""Return a colored string based on an AntaTestStatus.

Parameters
----------
status (TestStatus): status value to color.
status: AntaTestStatus enum to color.

Returns
-------
str: the colored string
The colored string.

"""
color = RICH_COLOR_THEME.get(status, "")
color = RICH_COLOR_THEME.get(str(status), "")
return f"[{color}]{status}" if color != "" else str(status)

def report_all(self, manager: ResultManager, title: str = "All tests results") -> Table:
Expand Down
9 changes: 5 additions & 4 deletions anta/reporter/md_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from anta.constants import MD_REPORT_TOC
from anta.logger import anta_log_exception
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -203,10 +204,10 @@ def generate_rows(self) -> Generator[str, None, None]:
"""Generate the rows of the summary totals table."""
yield (
f"| {self.results.get_total_results()} "
f"| {self.results.get_total_results({'success'})} "
f"| {self.results.get_total_results({'skipped'})} "
f"| {self.results.get_total_results({'failure'})} "
f"| {self.results.get_total_results({'error'})} |\n"
f"| {self.results.get_total_results({AntaTestStatus.SUCCESS})} "
f"| {self.results.get_total_results({AntaTestStatus.SKIPPED})} "
f"| {self.results.get_total_results({AntaTestStatus.FAILURE})} "
f"| {self.results.get_total_results({AntaTestStatus.ERROR})} |\n"
)

def generate_section(self) -> None:
Expand Down
36 changes: 15 additions & 21 deletions anta/result_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@
from collections import defaultdict
from functools import cached_property
from itertools import chain
from typing import get_args

from pydantic import TypeAdapter

from anta.constants import ACRONYM_CATEGORIES
from anta.custom_types import TestStatus
from anta.result_manager.models import TestResult
from anta.result_manager.models import AntaTestStatus, TestResult

from .models import CategoryStats, DeviceStats, TestStats

Expand Down Expand Up @@ -95,7 +91,7 @@ def __init__(self) -> None:
error_status is set to True.
"""
self._result_entries: list[TestResult] = []
self.status: TestStatus = "unset"
self.status: AntaTestStatus = AntaTestStatus.UNSET
self.error_status = False

self.device_stats: defaultdict[str, DeviceStats] = defaultdict(DeviceStats)
Expand All @@ -116,7 +112,7 @@ def results(self, value: list[TestResult]) -> None:
"""Set the list of TestResult."""
# When setting the results, we need to reset the state of the current instance
self._result_entries = []
self.status = "unset"
self.status = AntaTestStatus.UNSET
self.error_status = False

# Also reset the stats attributes
Expand All @@ -138,26 +134,24 @@ def sorted_category_stats(self) -> dict[str, CategoryStats]:
return dict(sorted(self.category_stats.items()))

@cached_property
def results_by_status(self) -> dict[TestStatus, list[TestResult]]:
def results_by_status(self) -> dict[AntaTestStatus, list[TestResult]]:
"""A cached property that returns the results grouped by status."""
return {status: [result for result in self._result_entries if result.result == status] for status in get_args(TestStatus)}
return {status: [result for result in self._result_entries if result.result == status] for status in AntaTestStatus}

def _update_status(self, test_status: TestStatus) -> None:
def _update_status(self, test_status: AntaTestStatus) -> None:
"""Update the status of the ResultManager instance based on the test status.

Parameters
----------
test_status: TestStatus to update the ResultManager status.
test_status: AntaTestStatus to update the ResultManager status.
"""
result_validator: TypeAdapter[TestStatus] = TypeAdapter(TestStatus)
result_validator.validate_python(test_status)
if test_status == "error":
self.error_status = True
return
if self.status == "unset" or self.status == "skipped" and test_status in {"success", "failure"}:
self.status = test_status
elif self.status == "success" and test_status == "failure":
self.status = "failure"
self.status = AntaTestStatus.FAILURE

def _update_stats(self, result: TestResult) -> None:
"""Update the statistics based on the test result.
Expand Down Expand Up @@ -209,14 +203,14 @@ def add(self, result: TestResult) -> None:
# Every time a new result is added, we need to clear the cached property
self.__dict__.pop("results_by_status", None)

def get_results(self, status: set[TestStatus] | None = None, sort_by: list[str] | None = None) -> list[TestResult]:
def get_results(self, status: set[AntaTestStatus] | None = None, sort_by: list[str] | None = None) -> list[TestResult]:
"""Get the results, optionally filtered by status and sorted by TestResult fields.

If no status is provided, all results are returned.

Parameters
----------
status: Optional set of TestStatus literals to filter the results.
status: Optional set of AntaTestStatus enum members to filter the results.
sort_by: Optional list of TestResult fields to sort the results.

Returns
Expand All @@ -235,14 +229,14 @@ def get_results(self, status: set[TestStatus] | None = None, sort_by: list[str]

return results

def get_total_results(self, status: set[TestStatus] | None = None) -> int:
def get_total_results(self, status: set[AntaTestStatus] | None = None) -> int:
"""Get the total number of results, optionally filtered by status.

If no status is provided, the total number of results is returned.

Parameters
----------
status: Optional set of TestStatus literals to filter the results.
status: Optional set of AntaTestStatus enum members to filter the results.

Returns
-------
Expand All @@ -259,18 +253,18 @@ def get_status(self, *, ignore_error: bool = False) -> str:
"""Return the current status including error_status if ignore_error is False."""
return "error" if self.error_status and not ignore_error else self.status

def filter(self, hide: set[TestStatus]) -> ResultManager:
def filter(self, hide: set[AntaTestStatus]) -> ResultManager:
"""Get a filtered ResultManager based on test status.

Parameters
----------
hide: set of TestStatus literals to select tests to hide based on their status.
hide: Set of AntaTestStatus enum members to select tests to hide based on their status.

Returns
-------
A filtered `ResultManager`.
"""
possible_statuses = set(get_args(TestStatus))
possible_statuses = set(AntaTestStatus)
manager = ResultManager()
manager.results = self.get_results(possible_statuses - hide)
return manager
Expand Down
44 changes: 30 additions & 14 deletions anta/result_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,48 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum

from pydantic import BaseModel

from anta.custom_types import TestStatus

class AntaTestStatus(str, Enum):
"""Test status Enum for the TestResult.

NOTE: This could be updated to StrEnum when Python 3.11 is the minimum supported version in ANTA.
"""

UNSET = "unset"
SUCCESS = "success"
FAILURE = "failure"
ERROR = "error"
SKIPPED = "skipped"

def __str__(self) -> str:
"""Override the __str__ method to return the value of the Enum, mimicking the behavior of StrEnum."""
return self.value


class TestResult(BaseModel):
"""Describe the result of a test from a single device.

Attributes
----------
name: Device name where the test has run.
test: Test name runs on the device.
categories: List of categories the TestResult belongs to, by default the AntaTest categories.
description: TestResult description, by default the AntaTest description.
result: Result of the test. Can be one of "unset", "success", "failure", "error" or "skipped".
messages: Message to report after the test if any.
custom_field: Custom field to store a string for flexibility in integrating with ANTA
name: Name of the device where the test was run.
test: Name of the test run on the device.
categories: List of categories the TestResult belongs to. Defaults to the AntaTest categories.
description: Description of the TestResult. Defaults to the AntaTest description.
result: Result of the test. Must be one of the AntaTestStatus Enum values: unset, success, failure, error, skipped.
gmuloc marked this conversation as resolved.
Show resolved Hide resolved
messages: Messages to report after the test, if any.
custom_field: Custom field to store a string for flexibility in integrating with ANTA.

"""

name: str
test: str
categories: list[str]
description: str
result: TestStatus = "unset"
result: AntaTestStatus = AntaTestStatus.UNSET
messages: list[str] = []
custom_field: str | None = None

Expand All @@ -43,7 +59,7 @@ def is_success(self, message: str | None = None) -> None:
message: Optional message related to the test

"""
self._set_status("success", message)
self._set_status(AntaTestStatus.SUCCESS, message)

def is_failure(self, message: str | None = None) -> None:
"""Set status to failure.
Expand All @@ -53,7 +69,7 @@ def is_failure(self, message: str | None = None) -> None:
message: Optional message related to the test

"""
self._set_status("failure", message)
self._set_status(AntaTestStatus.FAILURE, message)

def is_skipped(self, message: str | None = None) -> None:
"""Set status to skipped.
Expand All @@ -63,7 +79,7 @@ def is_skipped(self, message: str | None = None) -> None:
message: Optional message related to the test

"""
self._set_status("skipped", message)
self._set_status(AntaTestStatus.SKIPPED, message)

def is_error(self, message: str | None = None) -> None:
"""Set status to error.
Expand All @@ -73,9 +89,9 @@ def is_error(self, message: str | None = None) -> None:
message: Optional message related to the test

"""
self._set_status("error", message)
self._set_status(AntaTestStatus.ERROR, message)

def _set_status(self, status: TestStatus, message: str | None = None) -> None:
def _set_status(self, status: AntaTestStatus, message: str | None = None) -> None:
"""Set status and insert optional message.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion asynceapi/aio_portcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# -----------------------------------------------------------------------------


async def port_check_url(url: URL, timeout: int = 5) -> bool: # noqa: ASYNC109
async def port_check_url(url: URL, timeout: int = 5) -> bool:
"""
Open the port designated by the URL given the timeout in seconds.

Expand Down
19 changes: 9 additions & 10 deletions tests/units/reporter/test__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from anta import RICH_COLOR_PALETTE
from anta.reporter import ReportJinja, ReportTable
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from anta.custom_types import TestStatus
from anta.result_manager import ResultManager


Expand Down Expand Up @@ -73,15 +73,14 @@ def test__build_headers(self, headers: list[str]) -> None:
@pytest.mark.parametrize(
("status", "expected_status"),
[
pytest.param("unknown", "unknown", id="unknown status"),
pytest.param("unset", "[grey74]unset", id="unset status"),
pytest.param("skipped", "[bold orange4]skipped", id="skipped status"),
pytest.param("failure", "[bold red]failure", id="failure status"),
pytest.param("error", "[indian_red]error", id="error status"),
pytest.param("success", "[green4]success", id="success status"),
pytest.param(AntaTestStatus.UNSET, "[grey74]unset", id="unset status"),
pytest.param(AntaTestStatus.SKIPPED, "[bold orange4]skipped", id="skipped status"),
pytest.param(AntaTestStatus.FAILURE, "[bold red]failure", id="failure status"),
pytest.param(AntaTestStatus.ERROR, "[indian_red]error", id="error status"),
pytest.param(AntaTestStatus.SUCCESS, "[green4]success", id="success status"),
],
)
def test__color_result(self, status: TestStatus, expected_status: str) -> None:
def test__color_result(self, status: AntaTestStatus, expected_status: str) -> None:
"""Test _build_headers."""
# pylint: disable=protected-access
report = ReportTable()
Expand Down Expand Up @@ -140,7 +139,7 @@ def test_report_summary_tests(
new_results = [result.model_copy() for result in manager.results]
for result in new_results:
result.name = "test_device"
result.result = "failure"
result.result = AntaTestStatus.FAILURE

report = ReportTable()
kwargs = {"tests": [test] if test is not None else None, "title": title}
Expand Down Expand Up @@ -175,7 +174,7 @@ def test_report_summary_devices(
new_results = [result.model_copy() for result in manager.results]
for result in new_results:
result.name = dev or "test_device"
result.result = "failure"
result.result = AntaTestStatus.FAILURE
manager.results = new_results

report = ReportTable()
Expand Down
Loading
Loading