Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(anta.cli): Add Progress Bar for NRFU #251

Merged
merged 8 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,19 @@ repos:
args:
- --config-file=pyproject.toml
additional_dependencies:
- pydantic~=2.0
- "aio-eapi==0.3.0"
- "click==8.1.3"
- "click-help-colors==0.9.1"
- "cvprac>=1.2.0"
- "netaddr>=0.8.0"
- "pydantic~=2.0"
- "PyYAML>=6.0"
- "requests"
- "rich>=12.5.1"
- "scp"
- "asyncssh==2.13.1"
- "Jinja2>=3.1.2"
- types-PyYAML
- types-paramiko
- types-requests
files: ^(anta|scripts|tests)/
files: ^(anta|tests)/
8 changes: 4 additions & 4 deletions anta/cli/exec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "tex
logger.error(f"Could not collect commands on device {dev.name}: {exc_to_str(c.failed)}")
return
if c.ofmt == "json":
outfile = outdir / (command + ".json")
outfile = outdir / f"{command}.json"
content = json.dumps(c.json_output, indent=2)
elif c.ofmt == "text":
outfile = outdir / (command + ".log")
outfile = outdir / f"{command}.log"
content = c.text_output
with outfile.open(mode="w", encoding="UTF-8") as f:
f.write(content)
Expand All @@ -88,7 +88,7 @@ async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "tex
if __DEBUG__:
logger.exception(message, exc_info=r)
else:
logger.error(message + f": {exc_to_str(r)}")
logger.error(f"{message}: {exc_to_str(r)}")


async def collect_scheduled_show_tech(inv: AntaInventory, root_dir: Path, configure: bool, tags: Optional[List[str]] = None, latest: Optional[int] = None) -> None:
Expand Down Expand Up @@ -150,7 +150,7 @@ async def collect(device: AntaDevice) -> None:
if __DEBUG__:
logger.exception(message)
else:
logger.error(message + f": {exc_to_str(e)}")
logger.error(f"{message}: {exc_to_str(e)}")

logger.info("Connecting to devices...")
await inv.connect_inventory()
Expand Down
14 changes: 9 additions & 5 deletions anta/cli/nrfu/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from anta.result_manager import ResultManager
from anta.runner import main

from .utils import print_jinja, print_json, print_settings, print_table, print_text
from .utils import anta_progress_bar, print_jinja, print_json, print_settings, print_table, print_text

logger = logging.getLogger(__name__)

Expand All @@ -30,7 +30,8 @@ def table(ctx: click.Context, tags: Optional[List[str]], device: Optional[str],
"""ANTA command to check network states with table result"""
print_settings(ctx)
results = ResultManager()
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
with anta_progress_bar() as progress:
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress))
print_table(results=results, device=device, test=test)

# TODO make a util method to avoid repeating the same three line
Expand All @@ -54,7 +55,8 @@ def json(ctx: click.Context, tags: Optional[List[str]], output: Optional[pathlib
"""ANTA command to check network state with JSON result"""
print_settings(ctx)
results = ResultManager()
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
with anta_progress_bar() as progress:
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress))
print_json(results=results, output=output)

ignore_status = ctx.obj["ignore_status"]
Expand All @@ -71,7 +73,8 @@ def text(ctx: click.Context, tags: Optional[List[str]], search: Optional[str], s
"""ANTA command to check network states with text result"""
print_settings(ctx)
results = ResultManager()
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
with anta_progress_bar() as progress:
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress))
print_text(results=results, search=search, skip_error=skip_error)

ignore_status = ctx.obj["ignore_status"]
Expand Down Expand Up @@ -102,7 +105,8 @@ def tpl_report(ctx: click.Context, tags: Optional[List[str]], template: pathlib.
"""ANTA command to check network state with templated report"""
print_settings(ctx, template, output)
results = ResultManager()
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
with anta_progress_bar() as progress:
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags, progress=progress))
print_jinja(results=results, template=template, output=output)

ignore_status = ctx.obj["ignore_status"]
Expand Down
42 changes: 42 additions & 0 deletions anta/cli/nrfu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import rich
from rich.panel import Panel
from rich.pretty import pprint
from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn

from anta.cli.console import console
from anta.reporter import ReportJinja, ReportTable
Expand Down Expand Up @@ -80,3 +81,44 @@ def print_jinja(results: ResultManager, template: pathlib.Path, output: Optional
if output is not None:
with open(output, "w", encoding="utf-8") as file:
file.write(report)


# Adding our own ANTA spinner - overriding rich SPINNERS for our own
# so ignore warning for redefinition
rich.spinner.SPINNERS = { # type: ignore[attr-defined] # noqa: F811
"anta": {
"interval": 150,
"frames": [
"( 🐜)",
"( 🐜 )",
"( 🐜 )",
"( 🐜 )",
"( 🐜 )",
"(🐜 )",
"(🐌 )",
"( 🐌 )",
"( 🐌 )",
"( 🐌 )",
"( 🐌 )",
"( 🐌)",
],
}
}


def anta_progress_bar() -> Progress:
"""
Return a customized Progress for progress bar
"""
return Progress(
SpinnerColumn("anta"),
TextColumn("•"),
TextColumn("{task.description}[progress.percentage]{task.percentage:>3.0f}%"),
BarColumn(bar_width=None),
MofNCompleteColumn(),
TextColumn("•"),
TimeElapsedColumn(),
TextColumn("•"),
TimeRemainingColumn(),
expand=True,
)
5 changes: 5 additions & 0 deletions anta/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult:
anta_test = args[0]

if anta_test.result.result != "unset":
anta_test.update_progres()
mtache marked this conversation as resolved.
Show resolved Hide resolved
return anta_test.result

if anta_test.device.hw_model in platforms:
anta_test.result.is_skipped(f"{anta_test.__class__.__name__} test is not supported on {anta_test.device.hw_model}.")
anta_test.update_progres()
mtache marked this conversation as resolved.
Show resolved Hide resolved
return anta_test.result

return await function(*args, **kwargs)
Expand Down Expand Up @@ -71,6 +73,7 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult:
anta_test = args[0]

if anta_test.result.result != "unset":
anta_test.update_progres()
mtache marked this conversation as resolved.
Show resolved Hide resolved
return anta_test.result

if family == "ipv4":
Expand All @@ -92,10 +95,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult:
return anta_test.result
if "vrfs" not in command.json_output:
anta_test.result.is_skipped(f"no BGP configuration for {family} on this device")
anta_test.update_progres()
mtache marked this conversation as resolved.
Show resolved Hide resolved
return anta_test.result
if len(bgp_vrfs := command.json_output["vrfs"]) == 0 or len(bgp_vrfs["default"]["peers"]) == 0:
# No VRF
anta_test.result.is_skipped(f"no {family} peer on this device")
anta_test.update_progres()
mtache marked this conversation as resolved.
Show resolved Hide resolved
return anta_test.result

return await function(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion anta/inventory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,4 @@ async def connect_inventory(self) -> None:
if __DEBUG__:
logger.exception(message, exc_info=r)
else:
logger.error(message + f": {exc_to_str(r)}")
logger.error(f"{message}: {exc_to_str(r)}")
29 changes: 21 additions & 8 deletions anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Coroutine, Dict, List, Literal, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Coroutine, Dict, List, Literal, Optional, TypeVar, Union, cast

from pydantic import BaseModel, ConfigDict, conint
from rich.progress import Progress, TaskID

from anta import __DEBUG__
from anta.result_manager.models import TestResult
Expand Down Expand Up @@ -162,13 +163,17 @@ def __init__(
# TODO document very well the order of eos_data
eos_data: list[dict[Any, Any] | str] | None = None,
labels: list[str] | None = None,
progress: Optional[Progress] = None,
):
"""Class constructor"""
# Accept 6 input arguments
# pylint: disable=R0913
self.logger: logging.Logger = logging.getLogger(f"{self.__module__}.{self.__class__.__name__}")
self.device: AntaDevice = device
self.result: TestResult = TestResult(name=device.name, test=self.name, test_category=self.categories, test_description=self.description)
self.labels: List[str] = labels or []
self.instance_commands: List[AntaCommand] = []
self.progress = progress

# TODO - check optimization for deepcopy
# Generating instance_commands from list of commands and template
Expand Down Expand Up @@ -204,11 +209,7 @@ def all_data_collected(self) -> bool:

def get_failed_commands(self) -> List[AntaCommand]:
"""returns a list of all the commands that have a populated failed field"""
errors = []
for command in self.instance_commands:
if command.failed is not None:
errors.append(command)
return errors
return [command for command in self.instance_commands if command.failed is not None]

def __init_subclass__(cls) -> None:
"""
Expand All @@ -233,7 +234,7 @@ async def collect(self) -> None:
if __DEBUG__:
self.logger.exception(message)
else:
self.logger.error(message + f": {exc_to_str(e)}")
self.logger.error(f"{message}: {exc_to_str(e)}")
self.result.is_error(exc_to_str(e))

@staticmethod
Expand Down Expand Up @@ -286,12 +287,24 @@ async def wrapper(
if __DEBUG__:
self.logger.exception(message)
else:
self.logger.error(message + f": {exc_to_str(e)}")
self.logger.error(f"{message}: {exc_to_str(e)}")
self.result.is_error(exc_to_str(e))

self.update_progress()
return self.result

return wrapper

def update_progress(self) -> None:
"""
Update progress bar if it exists
"""
if self.progress:
# TODO this is hacky because we only have one task..
# Should be id 0 - casting for mypy
nrfu_task: TaskID = cast(TaskID, 0)
self.progress.update(nrfu_task, advance=1)

@abstractmethod
def test(self) -> Coroutine[Any, Any, TestResult]:
"""
Expand Down
14 changes: 11 additions & 3 deletions anta/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple

from rich.progress import Progress

from anta import __DEBUG__
from anta.inventory import AntaInventory
from anta.result_manager import ResultManager
Expand All @@ -25,6 +27,7 @@ async def main(
tests: List[Tuple[Callable[..., TestResult], Dict[Any, Any]]],
tags: Optional[List[str]] = None,
established_only: bool = True,
progress: Optional[Progress] = None,
) -> None:
"""
Main coroutine to run ANTA.
Expand All @@ -45,6 +48,8 @@ async def main(
Returns:
any: List of results.
"""
# Accept 6 arguments here
# pylint: disable=R0913

await inventory.connect_inventory()

Expand All @@ -57,14 +62,17 @@ async def main(
template_params = test[1].get(TEST_TPL_PARAMS)
try:
# Instantiate AntaTest object
test_instance = test[0](device=device, template_params=template_params)
test_instance = test[0](device=device, template_params=template_params, progress=progress)
coros.append(test_instance.test(eos_data=None, **test_params))
except Exception as e: # pylint: disable=broad-exception-caught
message = "Error when creating ANTA tests"
if __DEBUG__:
logger.exception(message)
else:
logger.error(message + f": {exc_to_str(e)}")
logger.error(f"{message}: {exc_to_str(e)}")

if progress is not None:
progress.add_task("Running NRFU Tests...", total=len(coros))

logger.info("Running ANTA tests...")
res = await asyncio.gather(*coros, return_exceptions=True)
Expand All @@ -74,6 +82,6 @@ async def main(
if __DEBUG__:
logger.exception(message, exc_info=r)
else:
logger.error(message + f": {exc_to_str(r)}")
logger.error(f"{message}: {exc_to_str(r)}")
res.remove(r)
manager.add_test_results(res)
1 change: 0 additions & 1 deletion pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
disable=
invalid-name,
logging-fstring-interpolation,
logging-not-lazy,
fixme

[BASIC]
Expand Down