Skip to content

Commit

Permalink
Merge branch 'main' into fix/issue-718
Browse files Browse the repository at this point in the history
  • Loading branch information
gmuloc authored Aug 30, 2024
2 parents c9b8085 + 7ff8043 commit 2fcc298
Show file tree
Hide file tree
Showing 28 changed files with 1,465 additions and 161 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
- name: Check and insert license on Markdown files
id: insert-license
files: .*\.md$
# exclude:
exclude: ^tests/data/.*\.md$
args:
- --license-filepath
- .github/license-short.txt
Expand All @@ -43,7 +43,7 @@ repos:
- '<!--| ~| -->'

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.1
rev: v0.6.3
hooks:
- id: ruff
name: Run Ruff linter
Expand Down Expand Up @@ -80,7 +80,7 @@ repos:
types: [text]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.1
rev: v1.11.2
hooks:
- id: mypy
name: Check typing with mypy
Expand Down
25 changes: 24 additions & 1 deletion anta/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import math
from collections import defaultdict
from inspect import isclass
from itertools import chain
from json import load as json_load
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from warnings import warn

from pydantic import BaseModel, ConfigDict, RootModel, ValidationError, ValidationInfo, field_validator, model_serializer, model_validator
from pydantic.types import ImportString
Expand Down Expand Up @@ -386,6 +388,21 @@ def from_list(data: ListAntaTestTuples) -> AntaCatalog:
raise
return AntaCatalog(tests)

@classmethod
def merge_catalogs(cls, catalogs: list[AntaCatalog]) -> AntaCatalog:
"""Merge multiple AntaCatalog instances.
Parameters
----------
catalogs: A list of AntaCatalog instances to merge.
Returns
-------
A new AntaCatalog instance containing the tests of all the input catalogs.
"""
combined_tests = list(chain(*(catalog.tests for catalog in catalogs)))
return cls(tests=combined_tests)

def merge(self, catalog: AntaCatalog) -> AntaCatalog:
"""Merge two AntaCatalog instances.
Expand All @@ -397,7 +414,13 @@ def merge(self, catalog: AntaCatalog) -> AntaCatalog:
-------
A new AntaCatalog instance containing the tests of the two instances.
"""
return AntaCatalog(tests=self.tests + catalog.tests)
# TODO: Use a decorator to deprecate this method instead. See https://github.com/aristanetworks/anta/issues/754
warn(
message="AntaCatalog.merge() is deprecated and will be removed in ANTA v2.0. Use AntaCatalog.merge_catalogs() instead.",
category=DeprecationWarning,
stacklevel=2,
)
return self.merge_catalogs([self, catalog])

def dump(self) -> AntaCatalogFile:
"""Return an AntaCatalogFile instance from this AntaCatalog instance.
Expand Down
7 changes: 4 additions & 3 deletions anta/cli/nrfu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from __future__ import annotations

from typing import TYPE_CHECKING, get_args
from typing import TYPE_CHECKING

import click

from anta.cli.nrfu import commands
from anta.cli.utils import AliasedGroup, catalog_options, inventory_options
from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import AntaTestStatus

if TYPE_CHECKING:
from anta.catalog import AntaCatalog
Expand Down Expand Up @@ -49,7 +49,7 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]:
return super().parse_args(ctx, args)


HIDE_STATUS: list[str] = list(get_args(TestStatus))
HIDE_STATUS: list[str] = list(AntaTestStatus)
HIDE_STATUS.remove("unset")


Expand Down Expand Up @@ -147,3 +147,4 @@ def nrfu(
nrfu.add_command(commands.json)
nrfu.add_command(commands.text)
nrfu.add_command(commands.tpl_report)
nrfu.add_command(commands.md_report)
26 changes: 21 additions & 5 deletions anta/cli/nrfu/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from anta.cli.utils import exit_with_code

from .utils import print_jinja, print_json, print_table, print_text, run_tests, save_to_csv
from .utils import print_jinja, print_json, print_table, print_text, run_tests, save_markdown_report, save_to_csv

logger = logging.getLogger(__name__)

Expand All @@ -28,7 +28,7 @@
required=False,
)
def table(ctx: click.Context, group_by: Literal["device", "test"] | None) -> None:
"""ANTA command to check network states with table result."""
"""ANTA command to check network state with table results."""
run_tests(ctx)
print_table(ctx, group_by=group_by)
exit_with_code(ctx)
Expand All @@ -42,10 +42,10 @@ def table(ctx: click.Context, group_by: Literal["device", "test"] | None) -> Non
type=click.Path(file_okay=True, dir_okay=False, exists=False, writable=True, path_type=pathlib.Path),
show_envvar=True,
required=False,
help="Path to save report as a file",
help="Path to save report as a JSON file",
)
def json(ctx: click.Context, output: pathlib.Path | None) -> None:
"""ANTA command to check network state with JSON result."""
"""ANTA command to check network state with JSON results."""
run_tests(ctx)
print_json(ctx, output=output)
exit_with_code(ctx)
Expand All @@ -54,7 +54,7 @@ def json(ctx: click.Context, output: pathlib.Path | None) -> None:
@click.command()
@click.pass_context
def text(ctx: click.Context) -> None:
"""ANTA command to check network states with text result."""
"""ANTA command to check network state with text results."""
run_tests(ctx)
print_text(ctx)
exit_with_code(ctx)
Expand Down Expand Up @@ -105,3 +105,19 @@ def tpl_report(ctx: click.Context, template: pathlib.Path, output: pathlib.Path
run_tests(ctx)
print_jinja(results=ctx.obj["result_manager"], template=template, output=output)
exit_with_code(ctx)


@click.command()
@click.pass_context
@click.option(
"--md-output",
type=click.Path(file_okay=True, dir_okay=False, exists=False, writable=True, path_type=pathlib.Path),
show_envvar=True,
required=True,
help="Path to save the report as a Markdown file",
)
def md_report(ctx: click.Context, md_output: pathlib.Path) -> None:
"""ANTA command to check network state with Markdown report."""
run_tests(ctx)
save_markdown_report(ctx, md_output=md_output)
exit_with_code(ctx)
38 changes: 31 additions & 7 deletions anta/cli/nrfu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from anta.models import AntaTest
from anta.reporter import ReportJinja, ReportTable
from anta.reporter.csv_reporter import ReportCsv
from anta.reporter.md_reporter import MDReportGenerator
from anta.runner import main

if TYPE_CHECKING:
Expand Down Expand Up @@ -94,14 +95,21 @@ def print_table(ctx: click.Context, group_by: Literal["device", "test"] | None =


def print_json(ctx: click.Context, output: pathlib.Path | None = None) -> None:
"""Print result in a json format."""
"""Print results as JSON. If output is provided, save to file instead."""
results = _get_result_manager(ctx)
console.print()
console.print(Panel("JSON results", style="cyan"))
rich.print_json(results.json)
if output is not None:
with output.open(mode="w", encoding="utf-8") as fout:
fout.write(results.json)

if output is None:
console.print()
console.print(Panel("JSON results", style="cyan"))
rich.print_json(results.json)
else:
try:
with output.open(mode="w", encoding="utf-8") as file:
file.write(results.json)
console.print(f"JSON results saved to {output} ✅", style="cyan")
except OSError:
console.print(f"Failed to save JSON results to {output} ❌", style="cyan")
ctx.exit(ExitCode.USAGE_ERROR)


def print_text(ctx: click.Context) -> None:
Expand Down Expand Up @@ -134,6 +142,22 @@ def save_to_csv(ctx: click.Context, csv_file: pathlib.Path) -> None:
ctx.exit(ExitCode.USAGE_ERROR)


def save_markdown_report(ctx: click.Context, md_output: pathlib.Path) -> None:
"""Save the markdown report to a file.
Parameters
----------
ctx: Click context containing the result manager.
md_output: Path to save the markdown report.
"""
try:
MDReportGenerator.generate(results=_get_result_manager(ctx), md_filename=md_output)
console.print(f"Markdown report saved to {md_output} ✅", style="cyan")
except OSError:
console.print(f"Failed to save Markdown report to {md_output} ❌", style="cyan")
ctx.exit(ExitCode.USAGE_ERROR)


# Adding our own ANTA spinner - overriding rich SPINNERS for our own
# so ignore warning for redefinition
rich.spinner.SPINNERS = { # type: ignore[attr-defined]
Expand Down
19 changes: 19 additions & 0 deletions anta/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Constants used in ANTA."""

from __future__ import annotations

ACRONYM_CATEGORIES: set[str] = {"aaa", "mlag", "snmp", "bgp", "ospf", "vxlan", "stp", "igmp", "ip", "lldp", "ntp", "bfd", "ptp", "lanz", "stun", "vlan"}
"""A set of network protocol or feature acronyms that should be represented in uppercase."""

MD_REPORT_TOC = """**Table of Contents:**
- [ANTA Report](#anta-report)
- [Test Results Summary](#test-results-summary)
- [Summary Totals](#summary-totals)
- [Summary Totals Device Under Test](#summary-totals-device-under-test)
- [Summary Totals Per Category](#summary-totals-per-category)
- [Test Results](#test-results)"""
"""Table of Contents for the Markdown report."""
3 changes: 0 additions & 3 deletions anta/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,6 @@ def validate_regex(value: str) -> str:
return value


# ANTA framework
TestStatus = Literal["unset", "success", "failure", "error", "skipped"]

# AntaTest.Input types
AAAAuthMethod = Annotated[str, AfterValidator(aaa_group_prefix)]
Vlan = Annotated[int, Field(ge=0, le=4094)]
Expand Down
49 changes: 18 additions & 31 deletions anta/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
if TYPE_CHECKING:
import pathlib

from anta.custom_types import TestStatus
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult
from anta.result_manager.models import AntaTestStatus, TestResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,19 +79,19 @@ def _build_headers(self, headers: list[str], table: Table) -> Table:
table.add_column(header, justify="left")
return table

def _color_result(self, status: TestStatus) -> str:
"""Return a colored string based on the status value.
def _color_result(self, status: AntaTestStatus) -> str:
"""Return a colored string based on an AntaTestStatus.
Parameters
----------
status (TestStatus): status value to color.
status: AntaTestStatus enum to color.
Returns
-------
str: the colored string
The colored string.
"""
color = RICH_COLOR_THEME.get(status, "")
color = RICH_COLOR_THEME.get(str(status), "")
return f"[{color}]{status}" if color != "" else str(status)

def report_all(self, manager: ResultManager, title: str = "All tests results") -> Table:
Expand Down Expand Up @@ -154,21 +153,15 @@ def report_summary_tests(
self.Headers.list_of_error_nodes,
]
table = self._build_headers(headers=headers, table=table)
for test in manager.get_tests():
for test, stats in sorted(manager.test_stats.items()):
if tests is None or test in tests:
results = manager.filter_by_tests({test}).results
nb_failure = len([result for result in results if result.result == "failure"])
nb_error = len([result for result in results if result.result == "error"])
list_failure = [result.name for result in results if result.result in ["failure", "error"]]
nb_success = len([result for result in results if result.result == "success"])
nb_skipped = len([result for result in results if result.result == "skipped"])
table.add_row(
test,
str(nb_success),
str(nb_skipped),
str(nb_failure),
str(nb_error),
str(list_failure),
str(stats.devices_success_count),
str(stats.devices_skipped_count),
str(stats.devices_failure_count),
str(stats.devices_error_count),
", ".join(stats.devices_failure),
)
return table

Expand Down Expand Up @@ -202,21 +195,15 @@ def report_summary_devices(
self.Headers.list_of_error_tests,
]
table = self._build_headers(headers=headers, table=table)
for device in manager.get_devices():
for device, stats in sorted(manager.device_stats.items()):
if devices is None or device in devices:
results = manager.filter_by_devices({device}).results
nb_failure = len([result for result in results if result.result == "failure"])
nb_error = len([result for result in results if result.result == "error"])
list_failure = [result.test for result in results if result.result in ["failure", "error"]]
nb_success = len([result for result in results if result.result == "success"])
nb_skipped = len([result for result in results if result.result == "skipped"])
table.add_row(
device,
str(nb_success),
str(nb_skipped),
str(nb_failure),
str(nb_error),
str(list_failure),
str(stats.tests_success_count),
str(stats.tests_skipped_count),
str(stats.tests_failure_count),
str(stats.tests_error_count),
", ".join(stats.tests_failure),
)
return table

Expand Down
Loading

0 comments on commit 2fcc298

Please sign in to comment.