Skip to content

Commit

Permalink
feat(anta.cli): Add Progress Bar for NRFU (#251)
Browse files Browse the repository at this point in the history
* Feat(anta.cli): Add Progress Bar for NRFU

* WIP

* CI: They be linting

* Fix: ProgresS with two 's'

* make progress an AntaTest class attribute

* lint

* black

* fix logical expression

---------

Co-authored-by: Matthieu Tâche <[email protected]>
  • Loading branch information
gmuloc and mtache authored Jul 7, 2023
1 parent f887ebe commit 6928d26
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 23 deletions.
15 changes: 13 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,19 @@ repos:
args:
- --config-file=pyproject.toml
additional_dependencies:
- pydantic~=2.0
- "aio-eapi==0.3.0"
- "click==8.1.3"
- "click-help-colors==0.9.1"
- "cvprac>=1.2.0"
- "netaddr>=0.8.0"
- "pydantic~=2.0"
- "PyYAML>=6.0"
- "requests"
- "rich>=12.5.1"
- "scp"
- "asyncssh==2.13.1"
- "Jinja2>=3.1.2"
- types-PyYAML
- types-paramiko
- types-requests
files: ^(anta|scripts|tests)/
files: ^(anta|tests)/
8 changes: 4 additions & 4 deletions anta/cli/exec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "tex
logger.error(f"Could not collect commands on device {dev.name}: {exc_to_str(c.failed)}")
return
if c.ofmt == "json":
outfile = outdir / (command + ".json")
outfile = outdir / f"{command}.json"
content = json.dumps(c.json_output, indent=2)
elif c.ofmt == "text":
outfile = outdir / (command + ".log")
outfile = outdir / f"{command}.log"
content = c.text_output
with outfile.open(mode="w", encoding="UTF-8") as f:
f.write(content)
Expand All @@ -88,7 +88,7 @@ async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "tex
if __DEBUG__:
logger.exception(message, exc_info=r)
else:
logger.error(message + f": {exc_to_str(r)}")
logger.error(f"{message}: {exc_to_str(r)}")


async def collect_scheduled_show_tech(inv: AntaInventory, root_dir: Path, configure: bool, tags: Optional[List[str]] = None, latest: Optional[int] = None) -> None:
Expand Down Expand Up @@ -150,7 +150,7 @@ async def collect(device: AntaDevice) -> None:
if __DEBUG__:
logger.exception(message)
else:
logger.error(message + f": {exc_to_str(e)}")
logger.error(f"{message}: {exc_to_str(e)}")

logger.info("Connecting to devices...")
await inv.connect_inventory()
Expand Down
15 changes: 10 additions & 5 deletions anta/cli/nrfu/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import click

from anta.cli.utils import parse_tags, return_code
from anta.models import AntaTest
from anta.result_manager import ResultManager
from anta.runner import main

from .utils import print_jinja, print_json, print_settings, print_table, print_text
from .utils import anta_progress_bar, print_jinja, print_json, print_settings, print_table, print_text

logger = logging.getLogger(__name__)

Expand All @@ -30,7 +31,8 @@ def table(ctx: click.Context, tags: Optional[List[str]], device: Optional[str],
"""ANTA command to check network states with table result"""
print_settings(ctx)
results = ResultManager()
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
with anta_progress_bar() as AntaTest.progress:
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
print_table(results=results, device=device, test=test)

# TODO make a util method to avoid repeating the same three line
Expand All @@ -54,7 +56,8 @@ def json(ctx: click.Context, tags: Optional[List[str]], output: Optional[pathlib
"""ANTA command to check network state with JSON result"""
print_settings(ctx)
results = ResultManager()
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
with anta_progress_bar() as AntaTest.progress:
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
print_json(results=results, output=output)

ignore_status = ctx.obj["ignore_status"]
Expand All @@ -71,7 +74,8 @@ def text(ctx: click.Context, tags: Optional[List[str]], search: Optional[str], s
"""ANTA command to check network states with text result"""
print_settings(ctx)
results = ResultManager()
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
with anta_progress_bar() as AntaTest.progress:
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
print_text(results=results, search=search, skip_error=skip_error)

ignore_status = ctx.obj["ignore_status"]
Expand Down Expand Up @@ -102,7 +106,8 @@ def tpl_report(ctx: click.Context, tags: Optional[List[str]], template: pathlib.
"""ANTA command to check network state with templated report"""
print_settings(ctx, template, output)
results = ResultManager()
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
with anta_progress_bar() as AntaTest.progress:
asyncio.run(main(results, ctx.obj["inventory"], ctx.obj["catalog"], tags=tags))
print_jinja(results=results, template=template, output=output)

ignore_status = ctx.obj["ignore_status"]
Expand Down
42 changes: 42 additions & 0 deletions anta/cli/nrfu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import rich
from rich.panel import Panel
from rich.pretty import pprint
from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn

from anta.cli.console import console
from anta.reporter import ReportJinja, ReportTable
Expand Down Expand Up @@ -80,3 +81,44 @@ def print_jinja(results: ResultManager, template: pathlib.Path, output: Optional
if output is not None:
with open(output, "w", encoding="utf-8") as file:
file.write(report)


# Adding our own ANTA spinner - overriding rich SPINNERS for our own
# so ignore warning for redefinition
rich.spinner.SPINNERS = { # type: ignore[attr-defined] # noqa: F811
"anta": {
"interval": 150,
"frames": [
"( 🐜)",
"( 🐜 )",
"( 🐜 )",
"( 🐜 )",
"( 🐜 )",
"(🐜 )",
"(🐌 )",
"( 🐌 )",
"( 🐌 )",
"( 🐌 )",
"( 🐌 )",
"( 🐌)",
],
}
}


def anta_progress_bar() -> Progress:
"""
Return a customized Progress for progress bar
"""
return Progress(
SpinnerColumn("anta"),
TextColumn("•"),
TextColumn("{task.description}[progress.percentage]{task.percentage:>3.0f}%"),
BarColumn(bar_width=None),
MofNCompleteColumn(),
TextColumn("•"),
TimeElapsedColumn(),
TextColumn("•"),
TimeRemainingColumn(),
expand=True,
)
7 changes: 6 additions & 1 deletion anta/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import wraps
from typing import Any, Callable, Dict, List, TypeVar, cast

from anta.models import AntaCommand
from anta.models import AntaCommand, AntaTest
from anta.result_manager.models import TestResult
from anta.tools.misc import exc_to_str

Expand Down Expand Up @@ -35,10 +35,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult:
anta_test = args[0]

if anta_test.result.result != "unset":
AntaTest.update_progress()
return anta_test.result

if anta_test.device.hw_model in platforms:
anta_test.result.is_skipped(f"{anta_test.__class__.__name__} test is not supported on {anta_test.device.hw_model}.")
AntaTest.update_progress()
return anta_test.result

return await function(*args, **kwargs)
Expand Down Expand Up @@ -71,6 +73,7 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult:
anta_test = args[0]

if anta_test.result.result != "unset":
AntaTest.update_progress()
return anta_test.result

if family == "ipv4":
Expand All @@ -92,10 +95,12 @@ async def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> TestResult:
return anta_test.result
if "vrfs" not in command.json_output:
anta_test.result.is_skipped(f"no BGP configuration for {family} on this device")
AntaTest.update_progress()
return anta_test.result
if len(bgp_vrfs := command.json_output["vrfs"]) == 0 or len(bgp_vrfs["default"]["peers"]) == 0:
# No VRF
anta_test.result.is_skipped(f"no {family} peer on this device")
AntaTest.update_progress()
return anta_test.result

return await function(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion anta/inventory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,4 @@ async def connect_inventory(self) -> None:
if __DEBUG__:
logger.exception(message, exc_info=r)
else:
logger.error(message + f": {exc_to_str(r)}")
logger.error(f"{message}: {exc_to_str(r)}")
25 changes: 18 additions & 7 deletions anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Coroutine, Dict, List, Literal, Optional, TypeVar, Union

from pydantic import BaseModel, ConfigDict, conint
from rich.progress import Progress, TaskID

from anta import __DEBUG__
from anta.result_manager.models import TestResult
Expand Down Expand Up @@ -151,6 +152,8 @@ class AntaTest(ABC):
commands: ClassVar[list[AntaCommand]]
# TODO - today we support only one template per Test
template: ClassVar[AntaTemplate]
progress: Optional[Progress] = None
nrfu_task: Optional[TaskID] = None

# Optional class attributes
test_filters: ClassVar[list[AntaTestFilter]]
Expand All @@ -164,6 +167,8 @@ def __init__(
labels: list[str] | None = None,
):
"""Class constructor"""
# Accept 6 input arguments
# pylint: disable=R0913
self.logger: logging.Logger = logging.getLogger(f"{self.__module__}.{self.__class__.__name__}")
self.device: AntaDevice = device
self.result: TestResult = TestResult(name=device.name, test=self.name, test_category=self.categories, test_description=self.description)
Expand Down Expand Up @@ -204,11 +209,7 @@ def all_data_collected(self) -> bool:

def get_failed_commands(self) -> List[AntaCommand]:
"""returns a list of all the commands that have a populated failed field"""
errors = []
for command in self.instance_commands:
if command.failed is not None:
errors.append(command)
return errors
return [command for command in self.instance_commands if command.failed is not None]

def __init_subclass__(cls) -> None:
"""
Expand All @@ -233,7 +234,7 @@ async def collect(self) -> None:
if __DEBUG__:
self.logger.exception(message)
else:
self.logger.error(message + f": {exc_to_str(e)}")
self.logger.error(f"{message}: {exc_to_str(e)}")
self.result.is_error(exc_to_str(e))

@staticmethod
Expand Down Expand Up @@ -286,12 +287,22 @@ async def wrapper(
if __DEBUG__:
self.logger.exception(message)
else:
self.logger.error(message + f": {exc_to_str(e)}")
self.logger.error(f"{message}: {exc_to_str(e)}")
self.result.is_error(exc_to_str(e))

AntaTest.update_progress()
return self.result

return wrapper

@classmethod
def update_progress(cls) -> None:
"""
Update progress bar for all AntaTest objects if it exists
"""
if cls.progress and (cls.nrfu_task is not None):
cls.progress.update(cls.nrfu_task, advance=1)

@abstractmethod
def test(self) -> Coroutine[Any, Any, TestResult]:
"""
Expand Down
10 changes: 8 additions & 2 deletions anta/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from anta import __DEBUG__
from anta.inventory import AntaInventory
from anta.models import AntaTest
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult
from anta.tools.misc import exc_to_str
Expand Down Expand Up @@ -45,6 +46,8 @@ async def main(
Returns:
any: List of results.
"""
# Accept 6 arguments here
# pylint: disable=R0913

await inventory.connect_inventory()

Expand All @@ -64,7 +67,10 @@ async def main(
if __DEBUG__:
logger.exception(message)
else:
logger.error(message + f": {exc_to_str(e)}")
logger.error(f"{message}: {exc_to_str(e)}")

if AntaTest.progress is not None:
AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coros))

logger.info("Running ANTA tests...")
res = await asyncio.gather(*coros, return_exceptions=True)
Expand All @@ -74,6 +80,6 @@ async def main(
if __DEBUG__:
logger.exception(message, exc_info=r)
else:
logger.error(message + f": {exc_to_str(r)}")
logger.error(f"{message}: {exc_to_str(r)}")
res.remove(r)
manager.add_test_results(res)
1 change: 0 additions & 1 deletion pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
disable=
invalid-name,
logging-fstring-interpolation,
logging-not-lazy,
fixme

[BASIC]
Expand Down

0 comments on commit 6928d26

Please sign in to comment.