Skip to content

Commit

Permalink
Merge branch 'main' into issue_786
Browse files Browse the repository at this point in the history
  • Loading branch information
vitthalmagadum authored Aug 30, 2024
2 parents 7b0ecf6 + aa1fde8 commit c2f80c6
Show file tree
Hide file tree
Showing 21 changed files with 350 additions and 122 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ repos:
- '<!--| ~| -->'

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.2
rev: v0.6.3
hooks:
- id: ruff
name: Run Ruff linter
Expand Down
7 changes: 5 additions & 2 deletions anta/cli/debug/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,16 @@ def run_template(
revision: int,
) -> None:
# pylint: disable=too-many-arguments
# Using \b for click
# ruff: noqa: D301
"""Run arbitrary templated command to an ANTA device.
Takes a list of arguments (keys followed by a value) to build a dictionary used as template parameters.
Example:
\b
Example
-------
anta debug run-template -d leaf1a -t 'show vlan {vlan_id}' vlan_id 1
anta debug run-template -d leaf1a -t 'show vlan {vlan_id}' vlan_id 1
"""
template_params = dict(zip(params[::2], params[1::2]))
Expand Down
5 changes: 2 additions & 3 deletions anta/cli/debug/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import click

from anta.cli.utils import ExitCode, inventory_options
from anta.cli.utils import ExitCode, core_options

if TYPE_CHECKING:
from anta.inventory import AntaInventory
Expand All @@ -22,7 +22,7 @@
def debug_options(f: Callable[..., Any]) -> Callable[..., Any]:
"""Click common options required to execute a command on a specific device."""

@inventory_options
@core_options
@click.option(
"--ofmt",
type=click.Choice(["json", "text"]),
Expand All @@ -44,7 +44,6 @@ def wrapper(
ctx: click.Context,
*args: tuple[Any],
inventory: AntaInventory,
tags: set[str] | None,
device: str,
**kwargs: Any,
) -> Any:
Expand Down
6 changes: 3 additions & 3 deletions anta/cli/nrfu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from __future__ import annotations

from typing import TYPE_CHECKING, get_args
from typing import TYPE_CHECKING

import click

from anta.cli.nrfu import commands
from anta.cli.utils import AliasedGroup, catalog_options, inventory_options
from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from anta.catalog import AntaCatalog
Expand Down Expand Up @@ -49,7 +49,7 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]:
return super().parse_args(ctx, args)


HIDE_STATUS: list[str] = list(get_args(TestStatus))
HIDE_STATUS: list[str] = list(AntaTestStatus)
HIDE_STATUS.remove("unset")


Expand Down
46 changes: 33 additions & 13 deletions anta/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def resolve_command(self, ctx: click.Context, args: Any) -> Any:
return cmd.name, cmd, args


def inventory_options(f: Callable[..., Any]) -> Callable[..., Any]:
def core_options(f: Callable[..., Any]) -> Callable[..., Any]:
"""Click common options when requiring an inventory to interact with devices."""

@click.option(
Expand Down Expand Up @@ -190,22 +190,12 @@ def inventory_options(f: Callable[..., Any]) -> Callable[..., Any]:
required=True,
type=click.Path(file_okay=True, dir_okay=False, exists=True, readable=True, path_type=Path),
)
@click.option(
"--tags",
help="List of tags using comma as separator: tag1,tag2,tag3.",
show_envvar=True,
envvar="ANTA_TAGS",
type=str,
required=False,
callback=parse_tags,
)
@click.pass_context
@functools.wraps(f)
def wrapper(
ctx: click.Context,
*args: tuple[Any],
inventory: Path,
tags: set[str] | None,
username: str,
password: str | None,
enable_password: str | None,
Expand All @@ -219,7 +209,7 @@ def wrapper(
# pylint: disable=too-many-arguments
# If help is invoke somewhere, do not parse inventory
if ctx.obj.get("_anta_help"):
return f(*args, inventory=None, tags=tags, **kwargs)
return f(*args, inventory=None, **kwargs)
if prompt:
# User asked for a password prompt
if password is None:
Expand Down Expand Up @@ -255,7 +245,37 @@ def wrapper(
)
except (TypeError, ValueError, YAMLError, OSError, InventoryIncorrectSchemaError, InventoryRootKeyError):
ctx.exit(ExitCode.USAGE_ERROR)
return f(*args, inventory=i, tags=tags, **kwargs)
return f(*args, inventory=i, **kwargs)

return wrapper


def inventory_options(f: Callable[..., Any]) -> Callable[..., Any]:
"""Click common options when requiring an inventory to interact with devices."""

@core_options
@click.option(
"--tags",
help="List of tags using comma as separator: tag1,tag2,tag3.",
show_envvar=True,
envvar="ANTA_TAGS",
type=str,
required=False,
callback=parse_tags,
)
@click.pass_context
@functools.wraps(f)
def wrapper(
ctx: click.Context,
*args: tuple[Any],
tags: set[str] | None,
**kwargs: dict[str, Any],
) -> Any:
# pylint: disable=too-many-arguments
# If help is invoke somewhere, do not parse inventory
if ctx.obj.get("_anta_help"):
return f(*args, tags=tags, **kwargs)
return f(*args, tags=tags, **kwargs)

return wrapper

Expand Down
3 changes: 0 additions & 3 deletions anta/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,6 @@ def validate_regex(value: str) -> str:
return value


# ANTA framework
TestStatus = Literal["unset", "success", "failure", "error", "skipped"]

# AntaTest.Input types
AAAAuthMethod = Annotated[str, AfterValidator(aaa_group_prefix)]
Vlan = Annotated[int, Field(ge=0, le=4094)]
Expand Down
13 changes: 6 additions & 7 deletions anta/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
if TYPE_CHECKING:
import pathlib

from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult
from anta.result_manager.models import AntaTestStatus, TestResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,19 +79,19 @@ def _build_headers(self, headers: list[str], table: Table) -> Table:
table.add_column(header, justify="left")
return table

def _color_result(self, status: TestStatus) -> str:
"""Return a colored string based on the status value.
def _color_result(self, status: AntaTestStatus) -> str:
"""Return a colored string based on an AntaTestStatus.
Parameters
----------
status (TestStatus): status value to color.
status: AntaTestStatus enum to color.
Returns
-------
str: the colored string
The colored string.
"""
color = RICH_COLOR_THEME.get(status, "")
color = RICH_COLOR_THEME.get(str(status), "")
return f"[{color}]{status}" if color != "" else str(status)

def report_all(self, manager: ResultManager, title: str = "All tests results") -> Table:
Expand Down
9 changes: 5 additions & 4 deletions anta/reporter/md_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from anta.constants import MD_REPORT_TOC
from anta.logger import anta_log_exception
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -203,10 +204,10 @@ def generate_rows(self) -> Generator[str, None, None]:
"""Generate the rows of the summary totals table."""
yield (
f"| {self.results.get_total_results()} "
f"| {self.results.get_total_results({'success'})} "
f"| {self.results.get_total_results({'skipped'})} "
f"| {self.results.get_total_results({'failure'})} "
f"| {self.results.get_total_results({'error'})} |\n"
f"| {self.results.get_total_results({AntaTestStatus.SUCCESS})} "
f"| {self.results.get_total_results({AntaTestStatus.SKIPPED})} "
f"| {self.results.get_total_results({AntaTestStatus.FAILURE})} "
f"| {self.results.get_total_results({AntaTestStatus.ERROR})} |\n"
)

def generate_section(self) -> None:
Expand Down
36 changes: 15 additions & 21 deletions anta/result_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@
from collections import defaultdict
from functools import cached_property
from itertools import chain
from typing import get_args

from pydantic import TypeAdapter

from anta.constants import ACRONYM_CATEGORIES
from anta.custom_types import TestStatus
from anta.result_manager.models import TestResult
from anta.result_manager.models import AntaTestStatus, TestResult

from .models import CategoryStats, DeviceStats, TestStats

Expand Down Expand Up @@ -95,7 +91,7 @@ def __init__(self) -> None:
error_status is set to True.
"""
self._result_entries: list[TestResult] = []
self.status: TestStatus = "unset"
self.status: AntaTestStatus = AntaTestStatus.UNSET
self.error_status = False

self.device_stats: defaultdict[str, DeviceStats] = defaultdict(DeviceStats)
Expand All @@ -116,7 +112,7 @@ def results(self, value: list[TestResult]) -> None:
"""Set the list of TestResult."""
# When setting the results, we need to reset the state of the current instance
self._result_entries = []
self.status = "unset"
self.status = AntaTestStatus.UNSET
self.error_status = False

# Also reset the stats attributes
Expand All @@ -138,26 +134,24 @@ def sorted_category_stats(self) -> dict[str, CategoryStats]:
return dict(sorted(self.category_stats.items()))

@cached_property
def results_by_status(self) -> dict[TestStatus, list[TestResult]]:
def results_by_status(self) -> dict[AntaTestStatus, list[TestResult]]:
"""A cached property that returns the results grouped by status."""
return {status: [result for result in self._result_entries if result.result == status] for status in get_args(TestStatus)}
return {status: [result for result in self._result_entries if result.result == status] for status in AntaTestStatus}

def _update_status(self, test_status: TestStatus) -> None:
def _update_status(self, test_status: AntaTestStatus) -> None:
"""Update the status of the ResultManager instance based on the test status.
Parameters
----------
test_status: TestStatus to update the ResultManager status.
test_status: AntaTestStatus to update the ResultManager status.
"""
result_validator: TypeAdapter[TestStatus] = TypeAdapter(TestStatus)
result_validator.validate_python(test_status)
if test_status == "error":
self.error_status = True
return
if self.status == "unset" or self.status == "skipped" and test_status in {"success", "failure"}:
self.status = test_status
elif self.status == "success" and test_status == "failure":
self.status = "failure"
self.status = AntaTestStatus.FAILURE

def _update_stats(self, result: TestResult) -> None:
"""Update the statistics based on the test result.
Expand Down Expand Up @@ -209,14 +203,14 @@ def add(self, result: TestResult) -> None:
# Every time a new result is added, we need to clear the cached property
self.__dict__.pop("results_by_status", None)

def get_results(self, status: set[TestStatus] | None = None, sort_by: list[str] | None = None) -> list[TestResult]:
def get_results(self, status: set[AntaTestStatus] | None = None, sort_by: list[str] | None = None) -> list[TestResult]:
"""Get the results, optionally filtered by status and sorted by TestResult fields.
If no status is provided, all results are returned.
Parameters
----------
status: Optional set of TestStatus literals to filter the results.
status: Optional set of AntaTestStatus enum members to filter the results.
sort_by: Optional list of TestResult fields to sort the results.
Returns
Expand All @@ -235,14 +229,14 @@ def get_results(self, status: set[TestStatus] | None = None, sort_by: list[str]

return results

def get_total_results(self, status: set[TestStatus] | None = None) -> int:
def get_total_results(self, status: set[AntaTestStatus] | None = None) -> int:
"""Get the total number of results, optionally filtered by status.
If no status is provided, the total number of results is returned.
Parameters
----------
status: Optional set of TestStatus literals to filter the results.
status: Optional set of AntaTestStatus enum members to filter the results.
Returns
-------
Expand All @@ -259,18 +253,18 @@ def get_status(self, *, ignore_error: bool = False) -> str:
"""Return the current status including error_status if ignore_error is False."""
return "error" if self.error_status and not ignore_error else self.status

def filter(self, hide: set[TestStatus]) -> ResultManager:
def filter(self, hide: set[AntaTestStatus]) -> ResultManager:
"""Get a filtered ResultManager based on test status.
Parameters
----------
hide: set of TestStatus literals to select tests to hide based on their status.
hide: Set of AntaTestStatus enum members to select tests to hide based on their status.
Returns
-------
A filtered `ResultManager`.
"""
possible_statuses = set(get_args(TestStatus))
possible_statuses = set(AntaTestStatus)
manager = ResultManager()
manager.results = self.get_results(possible_statuses - hide)
return manager
Expand Down
Loading

0 comments on commit c2f80c6

Please sign in to comment.