Skip to content

Commit

Permalink
Refactor: Change TestStatus to be an Enum for coding clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
gmuloc committed Jul 18, 2024
1 parent bb0b2ba commit 8ca3c2e
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 34 deletions.
6 changes: 3 additions & 3 deletions anta/cli/nrfu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from __future__ import annotations

from typing import TYPE_CHECKING, get_args
from typing import TYPE_CHECKING

import click

from anta.cli.nrfu import commands
from anta.cli.utils import AliasedGroup, catalog_options, inventory_options
from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import TestStatus

if TYPE_CHECKING:
from anta.catalog import AntaCatalog
Expand Down Expand Up @@ -49,7 +49,7 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]:
return super().parse_args(ctx, args)


HIDE_STATUS: list[str] = list(get_args(TestStatus))
HIDE_STATUS: list[str] = list(TestStatus)
HIDE_STATUS.remove("unset")


Expand Down
3 changes: 0 additions & 3 deletions anta/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,6 @@ def validate_regex(value: str) -> str:
return value


# ANTA framework
TestStatus = Literal["unset", "success", "failure", "error", "skipped"]

# AntaTest.Input types
AAAAuthMethod = Annotated[str, AfterValidator(aaa_group_prefix)]
Vlan = Annotated[int, Field(ge=0, le=4094)]
Expand Down
3 changes: 1 addition & 2 deletions anta/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
if TYPE_CHECKING:
import pathlib

from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult
from anta.result_manager.models import TestResult, TestStatus

logger = logging.getLogger(__name__)

Expand Down
8 changes: 4 additions & 4 deletions anta/result_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from pydantic import TypeAdapter

from anta.custom_types import TestStatus
from anta.result_manager.models import TestStatus

if TYPE_CHECKING:
from anta.result_manager.models import TestResult
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(self) -> None:
error_status is set to True.
"""
self._result_entries: list[TestResult] = []
self.status: TestStatus = "unset"
self.status: TestStatus = TestStatus.unset
self.error_status = False

def __len__(self) -> int:
Expand All @@ -106,7 +106,7 @@ def results(self) -> list[TestResult]:
@results.setter
def results(self, value: list[TestResult]) -> None:
self._result_entries = []
self.status = "unset"
self.status = TestStatus.unset
self.error_status = False
for e in value:
self.add(e)
Expand All @@ -133,7 +133,7 @@ def _update_status(test_status: TestStatus) -> None:
if self.status == "unset" or self.status == "skipped" and test_status in {"success", "failure"}:
self.status = test_status
elif self.status == "success" and test_status == "failure":
self.status = "failure"
self.status = TestStatus.failure

self._result_entries.append(result)
_update_status(result.result)
Expand Down
26 changes: 20 additions & 6 deletions anta/result_manager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,23 @@

from __future__ import annotations

from enum import Enum

from pydantic import BaseModel

from anta.custom_types import TestStatus

class TestStatus(str, Enum):
"""TestStatus enum."""

# This is to prevent pytest to collecting this
# TODO: find a way to ignore this only in test and not in library code.
__test__ = False

unset = "unset"
success = "success"
failure = "failure"
error = "error"
skipped = "skipped"


class TestResult(BaseModel):
Expand All @@ -29,7 +43,7 @@ class TestResult(BaseModel):
test: str
categories: list[str]
description: str
result: TestStatus = "unset"
result: TestStatus = TestStatus.unset
messages: list[str] = []
custom_field: str | None = None

Expand All @@ -41,7 +55,7 @@ def is_success(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status("success", message)
self._set_status(TestStatus.success, message)

def is_failure(self, message: str | None = None) -> None:
"""Set status to failure.
Expand All @@ -51,7 +65,7 @@ def is_failure(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status("failure", message)
self._set_status(TestStatus.failure, message)

def is_skipped(self, message: str | None = None) -> None:
"""Set status to skipped.
Expand All @@ -61,7 +75,7 @@ def is_skipped(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status("skipped", message)
self._set_status(TestStatus.skipped, message)

def is_error(self, message: str | None = None) -> None:
"""Set status to error.
Expand All @@ -71,7 +85,7 @@ def is_error(self, message: str | None = None) -> None:
message: Optional message related to the test
"""
self._set_status("error", message)
self._set_status(TestStatus.error, message)

def _set_status(self, status: TestStatus, message: str | None = None) -> None:
"""Set status and insert optional message.
Expand Down
6 changes: 3 additions & 3 deletions tests/units/reporter/test__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from anta import RICH_COLOR_PALETTE
from anta.reporter import ReportJinja, ReportTable
from anta.result_manager.models import TestStatus

if TYPE_CHECKING:
from anta.custom_types import TestStatus
from anta.result_manager import ResultManager


Expand Down Expand Up @@ -140,7 +140,7 @@ def test_report_summary_tests(
new_results = [result.model_copy() for result in manager.results]
for result in new_results:
result.name = "test_device"
result.result = "failure"
result.result = TestStatus.failure

report = ReportTable()
kwargs = {"tests": [test] if test is not None else None, "title": title}
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_report_summary_devices(
new_results = [result.model_copy() for result in manager.results]
for result in new_results:
result.name = dev or "test_device"
result.result = "failure"
result.result = TestStatus.failure
manager.results = new_results

report = ReportTable()
Expand Down
24 changes: 12 additions & 12 deletions tests/units/result_manager/test__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import pytest

from anta.result_manager import ResultManager, models
from anta.result_manager.models import TestStatus

if TYPE_CHECKING:
from anta.custom_types import TestStatus
from anta.result_manager.models import TestResult


Expand Down Expand Up @@ -55,7 +55,7 @@ def test_json(self, list_result_factory: Callable[[int], list[TestResult]]) -> N

success_list = list_result_factory(3)
for test in success_list:
test.result = "success"
test.result = TestStatus.success
result_manager.results = success_list

json_res = result_manager.json
Expand Down Expand Up @@ -177,28 +177,28 @@ def test_filter(self, test_result_factory: Callable[[], TestResult], list_result

success_list = list_result_factory(3)
for test in success_list:
test.result = "success"
test.result = TestStatus.success
result_manager.results = success_list

test = test_result_factory()
test.result = "failure"
test.result = TestStatus.failure
result_manager.add(test)

test = test_result_factory()
test.result = "error"
test.result = TestStatus.error
result_manager.add(test)

test = test_result_factory()
test.result = "skipped"
test.result = TestStatus.skipped
result_manager.add(test)

assert len(result_manager) == 6
assert len(result_manager.filter({"failure"})) == 5
assert len(result_manager.filter({"error"})) == 5
assert len(result_manager.filter({"skipped"})) == 5
assert len(result_manager.filter({"failure", "error"})) == 4
assert len(result_manager.filter({"failure", "error", "skipped"})) == 3
assert len(result_manager.filter({"success", "failure", "error", "skipped"})) == 0
assert len(result_manager.filter({TestStatus.failure})) == 5
assert len(result_manager.filter({TestStatus.error})) == 5
assert len(result_manager.filter({TestStatus.skipped})) == 5
assert len(result_manager.filter({TestStatus.failure, TestStatus.error})) == 4
assert len(result_manager.filter({TestStatus.failure, TestStatus.error, TestStatus.skipped})) == 3
assert len(result_manager.filter({TestStatus.success, TestStatus.failure, TestStatus.error, TestStatus.skipped})) == 0

def test_get_by_tests(self, test_result_factory: Callable[[], TestResult], result_manager_factory: Callable[[int], ResultManager]) -> None:
"""Test ResultManager.get_by_tests."""
Expand Down
4 changes: 3 additions & 1 deletion tests/units/result_manager/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import pytest

from anta.result_manager.models import TestStatus

# Import as Result to avoid pytest collection
from tests.data.json_data import TEST_RESULT_SET_STATUS
from tests.lib.fixture import DEVICE_NAME
Expand Down Expand Up @@ -45,7 +47,7 @@ def test__is_status_foo(self, test_result_factory: Callable[[int], Result], data
assert data["message"] in testresult.messages
# no helper for unset, testing _set_status
if data["target"] == "unset":
testresult._set_status("unset", data["message"]) # pylint: disable=W0212
testresult._set_status(TestStatus.unset, data["message"]) # pylint: disable=W0212
assert testresult.result == data["target"]
assert data["message"] in testresult.messages

Expand Down

0 comments on commit 8ca3c2e

Please sign in to comment.